From 24b037da94e5991a2ba29471d42c6798c5eab39d Mon Sep 17 00:00:00 2001 From: "E.S" Date: Sat, 24 Apr 2021 16:49:33 +0300 Subject: [PATCH 001/239] Add extensions field to ExecutionResult (#188) (#190) * Add extensions field to ExecutionResult * Update graphql-core min version to 3.1.4 --- gql/transport/aiohttp.py | 6 ++- gql/transport/phoenix_channel_websockets.py | 4 +- gql/transport/requests.py | 6 ++- gql/transport/websockets.py | 4 +- setup.py | 2 +- tests/test_aiohttp.py | 32 +++++++++++++++ tests/test_requests.py | 44 +++++++++++++++++++++ tests/test_websocket_query.py | 30 ++++++++++++++ 8 files changed, 123 insertions(+), 5 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index b1a33ad2..cdc4571b 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -236,7 +236,11 @@ async def execute( f"{result_text}" ) - return ExecutionResult(errors=result.get("errors"), data=result.get("data")) + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) def subscribe( self, diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index aaa6686a..557636db 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -189,7 +189,9 @@ def _parse_answer( answer_type = "data" execution_result = ExecutionResult( - errors=payload.get("errors"), data=result.get("data") + errors=payload.get("errors"), + data=result.get("data"), + extensions=payload.get("extensions"), ) elif event == "phx_reply": diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 5eb2a36c..c7d03adb 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -171,7 +171,11 @@ def execute( # type: ignore if "errors" not in result and "data" not in result: raise TransportProtocolError("Server did not return a GraphQL result") - return 
ExecutionResult(errors=result.get("errors"), data=result.get("data")) + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) def close(self): """Closing the transport by closing the inner session""" diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 76a234bd..e7eb4e8f 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -303,7 +303,9 @@ def _parse_answer( ) execution_result = ExecutionResult( - errors=payload.get("errors"), data=payload.get("data") + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), ) elif answer_type == "error": diff --git a/setup.py b/setup.py index fdcaccf0..496e7f3f 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.1,<3.2", + "graphql-core>=3.1.4,<3.2", "yarl>=1.6,<2.0", ] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 0bf8c1ba..815b4904 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -948,3 +948,35 @@ async def handler(request): expected_error = "Syntax Error: Unexpected Name 'BLAHBLAH'" assert expected_error in captured_err + + +query1_server_answer_with_extensions = ( + f'{{"data":{query1_server_answer_data}, "extensions":{{"key1": "val1"}}}}' +) + + +@pytest.mark.asyncio +async def test_aiohttp_query_with_extensions(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + 
execution_result = await session._execute(query) + + assert execution_result.extensions["key1"] == "val1" diff --git a/tests/test_requests.py b/tests/test_requests.py index 99d40bf1..a0f8ca27 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -216,3 +216,47 @@ def test_code(): sample_transport.execute(query) await run_sync_test(event_loop, server, test_code) + + +query1_server_answer_with_extensions = ( + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}" +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query_with_extensions( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + execution_result = session._execute(query) + + assert execution_result.extensions["key1"] == "val1" + + await run_sync_test(event_loop, server, test_code) diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index fc89dc80..e825c637 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -569,3 +569,33 @@ async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): received_answer = json.loads(captured_out) assert received_answer == expected_answer + + +query1_server_answer_with_extensions = ( + 
'{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}},' + '"extensions": {{"key1": "val1"}}}}}}' +) + +server1_answers_with_extensions = [ + query1_server_answer_with_extensions, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers_with_extensions], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_simple_query_with_extensions( + event_loop, client_and_server, query_str +): + + session, server = client_and_server + + query = gql(query_str) + + execution_result = await session._execute(query) + + assert execution_result.extensions["key1"] == "val1" From c4f2dc2868605df1cbdaedb36818016466dc20cb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 24 Apr 2021 16:58:40 +0200 Subject: [PATCH 002/239] Adding docs and test about cookies (#202) --- docs/transports/aiohttp.rst | 35 +++++++++++++++++++++++++++++++++++ tests/test_aiohttp.py | 33 +++++++++++++++++++++++++++++++++ tests/test_requests.py | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+) diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index a54809cc..4b792232 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -12,4 +12,39 @@ This transport uses the `aiohttp`_ library and allows you to send GraphQL querie .. literalinclude:: ../code_examples/aiohttp_async.py +Authentication +-------------- + +There are multiple ways to authenticate depending on the server configuration. + +1. Using HTTP Headers + +.. code-block:: python + + transport = AIOHTTPTransport( + url='https://SERVER_URL:SERVER_PORT/graphql', + headers={'Authorization': 'token'} + ) + +2. 
Using HTTP Cookies + +You can manually set the cookies which will be sent with each connection: + +.. code-block:: python + + transport = AIOHTTPTransport(url=url, cookies={"cookie1": "val1"}) + +Or you can use a cookie jar to save cookies set from the backend and reuse them later. + +In some cases, the server will set some connection cookies after a successful login mutation +and you can save these cookies in a cookie jar to reuse them in a following connection +(See `issue 197`_): + +.. code-block:: python + + jar = aiohttp.CookieJar() + transport = AIOHTTPTransport(url=url, client_session_args={'cookie_jar': jar}) + + .. _aiohttp: https://docs.aiohttp.org +.. _issue 197: https://github.com/graphql-python/gql/issues/197 diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 815b4904..5c135b29 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -69,6 +69,39 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.asyncio +async def test_aiohttp_cookies(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + assert "COOKIE" in request.headers + assert "cookie1=val1" == request.headers["COOKIE"] + + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, cookies={"cookie1": "val1"}) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + @pytest.mark.asyncio async def test_aiohttp_error_code_500(event_loop, aiohttp_server): from aiohttp 
import web diff --git a/tests/test_requests.py b/tests/test_requests.py index a0f8ca27..2afbd84a 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -64,6 +64,43 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + assert "COOKIE" in request.headers + assert "cookie1=val1" == request.headers["COOKIE"] + + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url, cookies={"cookie1": "val1"}) + + with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): From 73d0ba50796a7fa159cc0214e51f3230eb5af07b Mon Sep 17 00:00:00 2001 From: Abhishek Shekhar Date: Sun, 25 Apr 2021 01:17:27 +0530 Subject: [PATCH 003/239] Fix 4xx error handling in transports (#195) --- gql/transport/aiohttp.py | 31 ++++++++++++++++--------------- gql/transport/exceptions.py | 4 ++++ gql/transport/requests.py | 33 +++++++++++++++++++++------------ tests/test_aiohttp.py | 37 ++++++++++++++++++++++++++++++++++--- tests/test_client.py | 3 ++- tests/test_requests.py | 35 +++++++++++++++++++++++++++++++++++ 6 files changed, 112 insertions(+), 31 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 
cdc4571b..780027b8 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -206,35 +206,36 @@ async def execute( raise TransportClosed("Transport is not connected") async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: - try: - result = await resp.json() - if log.isEnabledFor(logging.INFO): - result_text = await resp.text() - log.info("<<< %s", result_text) - except Exception: + async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases try: # Raise a ClientResponseError if response status is 400 or higher resp.raise_for_status() - except ClientResponseError as e: - raise TransportServerError(str(e)) from e + raise TransportServerError(str(e), e.status) from e result_text = await resp.text() raise TransportProtocolError( - f"Server did not return a GraphQL result: {result_text}" + f"Server did not return a GraphQL result: " + f"{reason}: " + f"{result_text}" ) + try: + result = await resp.json() + + if log.isEnabledFor(logging.INFO): + result_text = await resp.text() + log.info("<<< %s", result_text) + + except Exception: + await raise_response_error(resp, "Not a JSON answer") + if "errors" not in result and "data" not in result: - result_text = await resp.text() - raise TransportProtocolError( - "Server did not return a GraphQL result: " - 'No "data" or "error" keys in answer: ' - f"{result_text}" - ) + await raise_response_error(resp, 'No "data" or "errors" keys in answer') return ExecutionResult( errors=result.get("errors"), diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 4df2ec43..899d5d66 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -18,6 +18,10 @@ class TransportServerError(TransportError): This exception will close the transport connection. 
""" + def __init__(self, message=None, code=None): + super(TransportServerError, self).__init__(message) + self.code = code + class TransportQueryError(Exception): """The server returned an error for a specific query. diff --git a/gql/transport/requests.py b/gql/transport/requests.py index c7d03adb..d0bc1467 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -38,7 +38,7 @@ def __init__( verify: bool = True, retries: int = 0, method: str = "POST", - **kwargs: Any + **kwargs: Any, ): """Initialize the transport with the given request parameters. @@ -150,26 +150,35 @@ def execute( # type: ignore response = self.session.request( self.method, self.url, **post_args # type: ignore ) - try: - result = response.json() - if log.isEnabledFor(logging.INFO): - log.info("<<< %s", response.text) - except Exception: + def raise_response_error(resp: requests.Response, reason: str): # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases try: - # Raise a requests.HTTPerror if response status is 400 or higher - response.raise_for_status() - + # Raise a HTTPError if response status is 400 or higher + resp.raise_for_status() except requests.HTTPError as e: - raise TransportServerError(str(e)) + raise TransportServerError(str(e), e.response.status_code) from e + + result_text = resp.text + raise TransportProtocolError( + f"Server did not return a GraphQL result: " + f"{reason}: " + f"{result_text}" + ) - raise TransportProtocolError("Server did not return a GraphQL result") + try: + result = response.json() + + if log.isEnabledFor(logging.INFO): + log.info("<<< %s", response.text) + + except Exception: + raise_response_error(response, "Not a JSON answer") if "errors" not in result and "data" not in result: - raise TransportProtocolError("Server did not return a GraphQL result") + raise_response_error(response, 'No "data" or "errors" keys in answer') return ExecutionResult( 
errors=result.get("errors"), diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 5c135b29..3fb85cd0 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -102,6 +102,37 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.asyncio +async def test_aiohttp_error_code_401(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "401, message='Unauthorized'" in str(exc_info.value) + + @pytest.mark.asyncio async def test_aiohttp_error_code_500(event_loop, aiohttp_server): from aiohttp import web @@ -163,20 +194,20 @@ async def handler(request): "response": "{}", "expected_exception": ( "Server did not return a GraphQL result: " - 'No "data" or "error" keys in answer: {}' + 'No "data" or "errors" keys in answer: {}' ), }, { "response": "qlsjfqsdlkj", "expected_exception": ( - "Server did not return a GraphQL result: " "qlsjfqsdlkj" + "Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj" ), }, { "response": '{"not_data_or_errors": 35}', "expected_exception": ( "Server did not return a GraphQL result: " - 'No "data" or "error" keys in answer: {"not_data_or_errors": 35}' + 'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}' ), }, ] diff --git a/tests/test_client.py b/tests/test_client.py index f2a7ecf8..1521eac7 100644 --- a/tests/test_client.py +++ 
b/tests/test_client.py @@ -7,6 +7,7 @@ from gql import Client, gql from gql.transport import Transport +from gql.transport.exceptions import TransportQueryError with suppress(ModuleNotFoundError): from urllib3.exceptions import NewConnectionError @@ -105,7 +106,7 @@ def test_execute_result_error(): """ ) - with pytest.raises(Exception) as exc_info: + with pytest.raises(TransportQueryError) as exc_info: client.execute(failing_query) assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value) diff --git a/tests/test_requests.py b/tests/test_requests.py index 2afbd84a..e18875a2 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -101,6 +101,41 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + session.execute(query) + + assert "401 Client Error: Unauthorized" in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): From beab41b028f614f2a4c9dd12b42a7fa3fcf558a7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 22 May 2021 12:51:22 +0200 Subject: [PATCH 004/239] Fix tests with 
graphql-core 3.1.5 (#211) * Fix tests to allow graphql-core>=3.1.5 to break arguments over multiple lines * Stop checking exact error message when passing an int to gql --- tests/starwars/test_dsl.py | 30 ++++++++++++++++++++++++++++-- tests/starwars/test_validation.py | 18 ++++++++++++++++-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 5807e87f..b105cfa3 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -264,9 +264,15 @@ def test_multiple_operations(ds): ), ) + """ + From graphql-core version 3.1.5, print_ast() break arguments over multiple lines + Accepting both cases here + """ + assert ( - print_ast(query) - == """query GetHeroName { + ( + print_ast(query) + == """query GetHeroName { hero { name } @@ -280,6 +286,26 @@ def test_multiple_operations(ds): } } """ + ) + or ( + print_ast(query) + == """query GetHeroName { + hero { + name + } +} + +mutation CreateReviewMutation { + createReview( + episode: JEDI + review: {stars: 5, commentary: "This is a great movie!"} + ) { + stars + commentary + } +} +""" + ) ) diff --git a/tests/starwars/test_validation.py b/tests/starwars/test_validation.py index 00384c99..468bb553 100644 --- a/tests/starwars/test_validation.py +++ b/tests/starwars/test_validation.py @@ -75,9 +75,23 @@ def validation_errors(client, query): def test_incompatible_request_gql(client): - with pytest.raises(TypeError) as exc_info: + with pytest.raises(TypeError): gql(123) - assert "body must be a string" in str(exc_info.value) + + """ + The error generated depends on graphql-core version + < 3.1.5: "body must be a string" + >= 3.1.5: some variation of "object of type 'int' has no len()" + depending on the python environment + + So we are not going to check the exact error message here anymore. 
+ """ + + """ + assert ("body must be a string" in str(exc_info.value)) or ( + "object of type 'int' has no len()" in str(exc_info.value) + ) + """ def test_nested_query_with_fragment(client): From 5209232682e81896c86b3aae5209321c709ee371 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 24 May 2021 18:23:37 +0200 Subject: [PATCH 005/239] Bump graphql-core version to 3.1.5 (#213) --- setup.py | 2 +- tests/starwars/test_dsl.py | 28 ++-------------------------- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index 496e7f3f..e0ca29f6 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.1.4,<3.2", + "graphql-core>=3.1.5,<3.2", "yarl>=1.6,<2.0", ] diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index b105cfa3..95a92989 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -264,32 +264,9 @@ def test_multiple_operations(ds): ), ) - """ - From graphql-core version 3.1.5, print_ast() break arguments over multiple lines - Accepting both cases here - """ - assert ( - ( - print_ast(query) - == """query GetHeroName { - hero { - name - } -} - -mutation CreateReviewMutation { - createReview(episode: JEDI, review: {stars: 5, \ -commentary: "This is a great movie!"}) { - stars - commentary - } -} -""" - ) - or ( - print_ast(query) - == """query GetHeroName { + print_ast(query) + == """query GetHeroName { hero { name } @@ -305,7 +282,6 @@ def test_multiple_operations(ds): } } """ - ) ) From 4528977734e60b1ee439d61feb97a715d7da495f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 24 May 2021 19:02:31 +0200 Subject: [PATCH 006/239] Fix incorrect typing of variable_values (#215) --- gql/transport/aiohttp.py | 4 ++-- gql/transport/async_transport.py | 6 +++--- gql/transport/phoenix_channel_websockets.py | 4 ++-- gql/transport/websockets.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git 
a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 780027b8..77a0c0c2 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -114,7 +114,7 @@ async def close(self) -> None: async def execute( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, extra_args: Dict[str, Any] = None, upload_files: bool = False, @@ -246,7 +246,7 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): def subscribe( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> AsyncGenerator[ExecutionResult, None]: """Subscribe is not supported on HTTP. diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 73cb46d7..7de24015 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,5 +1,5 @@ import abc -from typing import AsyncGenerator, Dict, Optional +from typing import Any, AsyncGenerator, Dict, Optional from graphql import DocumentNode, ExecutionResult @@ -25,7 +25,7 @@ async def close(self): async def execute( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> ExecutionResult: """Execute the provided document AST for either a remote or local GraphQL Schema. 
@@ -38,7 +38,7 @@ async def execute( def subscribe( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using an async generator diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 557636db..27e58f2a 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,6 +1,6 @@ import asyncio import json -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple from graphql import DocumentNode, ExecutionResult, print_ast from websockets.exceptions import ConnectionClosed @@ -116,7 +116,7 @@ async def _send_connection_terminate_message(self) -> None: async def _send_query( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> int: """Send a query to the provided websocket connection. diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index e7eb4e8f..d1656de2 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -237,7 +237,7 @@ async def _send_connection_terminate_message(self) -> None: async def _send_query( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> int: """Send a query to the provided websocket connection. 
@@ -394,7 +394,7 @@ async def _handle_answer( async def subscribe( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, send_stop: Optional[bool] = True, ) -> AsyncGenerator[ExecutionResult, None]: @@ -452,7 +452,7 @@ async def subscribe( async def execute( self, document: DocumentNode, - variable_values: Optional[Dict[str, str]] = None, + variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> ExecutionResult: """Execute the provided document AST against the configured remote server From 35203e89dce6d299c4008d1dbd096bb0208b31a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Ram=C3=A9?= <8195958+sneko@users.noreply.github.com> Date: Sat, 5 Jun 2021 22:01:57 +0200 Subject: [PATCH 007/239] Handle keep-alive behavior to close the connection (#201) --- gql/transport/websockets.py | 65 ++++++++++++++++++++++++++-- tests/test_websocket_subscription.py | 62 ++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 3 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index d1656de2..701b6de6 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,6 +1,7 @@ import asyncio import json import logging +from contextlib import suppress from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast @@ -94,6 +95,7 @@ def __init__( connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, + keep_alive_timeout: Optional[int] = None, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -107,6 +109,8 @@ def __init__( :param close_timeout: Timeout in seconds for the close. :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. 
+ :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. :param connect_args: Other parameters forwarded to websockets.connect """ self.url: str = url @@ -117,6 +121,7 @@ def __init__( self.connect_timeout: int = connect_timeout self.close_timeout: int = close_timeout self.ack_timeout: int = ack_timeout + self.keep_alive_timeout: Optional[int] = keep_alive_timeout self.connect_args = connect_args @@ -125,6 +130,7 @@ def __init__( self.listeners: Dict[int, ListenerQueue] = {} self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None self.close_task: Optional[asyncio.Future] = None # We need to set an event loop here if there is none @@ -141,6 +147,10 @@ def __init__( self._no_more_listeners: asyncio.Event = asyncio.Event() self._no_more_listeners.set() + if self.keep_alive_timeout is not None: + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + self._connecting: bool = False self.close_exception: Optional[Exception] = None @@ -315,8 +325,9 @@ def _parse_answer( ) elif answer_type == "ka": - # KeepAlive message - pass + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() elif answer_type == "connection_ack": pass elif answer_type == "connection_error": @@ -332,8 +343,41 @@ def _parse_answer( return answer_type, answer_id, execution_result - async def _receive_data_loop(self) -> None: + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages + """ + + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) + # Reset for the next iteration + self._next_keep_alive_message.clear() + + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify 
the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) + + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _receive_data_loop(self) -> None: try: while True: @@ -549,6 +593,13 @@ async def connect(self) -> None: await self._fail(e, clean_close=False) raise e + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + # Create a task to listen to the incoming websocket messages self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) @@ -597,6 +648,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # We should always have an active websocket connection here assert self.websocket is not None + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + # Saving exception to raise it later if trying to use the transport # after it has already closed. 
self.close_exception = e @@ -629,6 +687,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: self.websocket = None self.close_task = None + self.check_keep_alive_task = None self._wait_closed.set() diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 7d80c8eb..fcd176b5 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -7,6 +7,7 @@ from parse import search from gql import Client, gql +from gql.transport.exceptions import TransportServerError from .conftest import MS, WebSocketServerHelper @@ -378,6 +379,67 @@ async def test_websocket_subscription_with_keepalive( assert count == -1 +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_keepalive_with_timeout_ok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(500 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_keepalive_with_timeout_nok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = 
WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_websocket_subscription_sync(server, subscription_str): From ae35c785d80685abfb18f9990db8f8c3b6bf1f9a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 5 Jun 2021 22:10:01 +0200 Subject: [PATCH 008/239] Bump websockets to >=9 (#214) --- gql/transport/websockets.py | 5 +++-- setup.py | 2 +- tests/conftest.py | 8 +++----- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 701b6de6..7e26f31c 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -8,8 +8,8 @@ import websockets from graphql import DocumentNode, ExecutionResult, print_ast from websockets.client import WebSocketClientProtocol +from websockets.datastructures import HeadersLike from websockets.exceptions import ConnectionClosed -from websockets.http import HeadersLike from websockets.typing import Data, Subprotocol from .async_transport import AsyncTransport @@ -573,7 +573,8 @@ async def connect(self) -> None: # Set the _connecting flag to False after in all cases try: self.websocket = await asyncio.wait_for( - websockets.connect(self.url, **connect_args,), self.connect_timeout, + websockets.client.connect(self.url, **connect_args,), + self.connect_timeout, ) finally: self._connecting = False diff --git a/setup.py b/setup.py index e0ca29f6..248099ab 
100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ ] install_websockets_requires = [ - "websockets>=8.1,<9", + "websockets>=9,<10", ] install_all_requires = ( diff --git a/tests/conftest.py b/tests/conftest.py index 1865152e..44973ae1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,7 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests -for name in ["websockets.server", "gql.transport.websockets", "gql.dsl"]: +for name in ["websockets.legacy.server", "gql.transport.websockets", "gql.dsl"]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) @@ -125,7 +125,7 @@ def __init__(self, with_ssl: bool = False): async def start(self, handler): - import websockets + from websockets.legacy import server print("Starting server") @@ -149,9 +149,7 @@ async def start(self, handler): extra_serve_args["ssl"] = ssl_context # Start a server with a random open port - self.start_server = websockets.server.serve( - handler, "127.0.0.1", 0, **extra_serve_args - ) + self.start_server = server.serve(handler, "127.0.0.1", 0, **extra_serve_args) # Wait that the server is started self.server = await self.start_server From 9605a4fc910e3006acc6e1e496d80ad242e5b219 Mon Sep 17 00:00:00 2001 From: "Walther E. 
Lee" Date: Sat, 5 Jun 2021 13:14:48 -0700 Subject: [PATCH 009/239] Add support for variable definitions in dsl (#210) --- docs/advanced/dsl_module.rst | 41 +++++++++ gql/dsl.py | 159 ++++++++++++++++++++++++++++++++++- tests/starwars/test_dsl.py | 135 ++++++++++++++++++++++++++++- 3 files changed, 331 insertions(+), 4 deletions(-) diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index 2e60f045..2ec544b7 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -159,6 +159,47 @@ then you need to create the GraphQL operation using the class ) ) +Variable arguments +^^^^^^^^^^^^^^^^^^ + +To provide variables instead of argument values directly for an operation, you have to: + +* Instanciate a :class:`DSLVariableDefinitions `:: + + var = DSLVariableDefinitions() + +* From this instance you can generate :class:`DSLVariable ` instances + and provide them as the value of the arguments:: + + ds.Mutation.createReview.args(review=var.review, episode=var.episode) + +* Once the operation has been defined, you have to save the variable definitions used + in it:: + + operation.variable_definitions = var + +The following code: + +.. 
code-block:: python + + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args(review=var.review, episode=var.episode).select( + ds.Review.stars, ds.Review.commentary + ) + ) + op.variable_definitions = var + query = dsl_gql(op) + +will generate a query equivalent to:: + + mutation ($review: ReviewInput, $episode: Episode) { + createReview(review: $review, episode: $episode) { + stars + commentary + } + } + Subscriptions ^^^^^^^^^^^^^ diff --git a/gql/dsl.py b/gql/dsl.py index 72abfcb9..6542d6a6 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,6 +1,6 @@ import logging from abc import ABC -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from graphql import ( ArgumentNode, @@ -8,24 +8,106 @@ FieldNode, GraphQLArgument, GraphQLField, + GraphQLInputObjectType, + GraphQLInputType, GraphQLInterfaceType, + GraphQLList, GraphQLNamedType, + GraphQLNonNull, GraphQLObjectType, GraphQLSchema, + GraphQLWrappingType, + ListTypeNode, + ListValueNode, + NamedTypeNode, NameNode, + NonNullTypeNode, + NullValueNode, + ObjectFieldNode, + ObjectValueNode, OperationDefinitionNode, OperationType, SelectionSetNode, - ast_from_value, + TypeNode, + Undefined, + ValueNode, + VariableDefinitionNode, + VariableNode, + assert_named_type, + is_input_object_type, + is_list_type, + is_non_null_type, + is_wrapping_type, print_ast, ) from graphql.pyutils import FrozenList +from graphql.utilities import ast_from_value as default_ast_from_value from .utils import to_camel_case log = logging.getLogger(__name__) +def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: + """ + This is a partial copy paste of the ast_from_value function in + graphql-core utilities/ast_from_value.py + + Overwrite the if blocks that use recursion and add a new case to return a + VariableNode when value is a DSLVariable + + Produce a GraphQL Value AST given a Python object. 
+ """ + if isinstance(value, DSLVariable): + return value.set_type(type_).ast_variable + + if is_non_null_type(type_): + type_ = cast(GraphQLNonNull, type_) + ast_value = ast_from_value(value, type_.of_type) + if isinstance(ast_value, NullValueNode): + return None + return ast_value + + # only explicit None, not Undefined or NaN + if value is None: + return NullValueNode() + + # undefined + if value is Undefined: + return None + + # Convert Python list to GraphQL list. If the GraphQLType is a list, but the value + # is not a list, convert the value using the list's item type. + if is_list_type(type_): + type_ = cast(GraphQLList, type_) + item_type = type_.of_type + if isinstance(value, Iterable) and not isinstance(value, str): + maybe_value_nodes = (ast_from_value(item, item_type) for item in value) + value_nodes = filter(None, maybe_value_nodes) + return ListValueNode(values=FrozenList(value_nodes)) + return ast_from_value(value, item_type) + + # Populate the fields of the input object by creating ASTs from each value in the + # Python dict according to the fields in the input type. 
+ if is_input_object_type(type_): + if value is None or not isinstance(value, Mapping): + return None + type_ = cast(GraphQLInputObjectType, type_) + field_items = ( + (field_name, ast_from_value(value[field_name], field.type)) + for field_name, field in type_.fields.items() + if field_name in value + ) + field_nodes = ( + ObjectFieldNode(name=NameNode(value=field_name), value=field_value) + for field_name, field_value in field_items + if field_value + ) + return ObjectValueNode(fields=FrozenList(field_nodes)) + + return default_ast_from_value(value, type_) + + def dsl_gql( *operations: "DSLOperation", **operations_with_name: "DSLOperation" ) -> DocumentNode: @@ -77,6 +159,9 @@ def dsl_gql( OperationDefinitionNode( operation=OperationType(operation.operation_type), selection_set=operation.selection_set, + variable_definitions=FrozenList( + operation.variable_definitions.get_ast_definitions() + ), **({"name": NameNode(value=operation.name)} if operation.name else {}), ) for operation in all_operations @@ -156,6 +241,7 @@ def __init__( """ self.name: Optional[str] = None + self.variable_definitions: DSLVariableDefinitions = DSLVariableDefinitions() # Concatenate fields without and with alias all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields( @@ -194,6 +280,75 @@ class DSLSubscription(DSLOperation): operation_type = OperationType.SUBSCRIPTION +class DSLVariable: + """The DSLVariable represents a single variable defined in a GraphQL operation + + Instances of this class are generated for you automatically as attributes + of the :class:`DSLVariableDefinitions` + + The type of the variable is set by the :class:`DSLField` instance that receives it + in the `args` method. 
+ """ + + def __init__(self, name: str): + self.type: Optional[TypeNode] = None + self.name = name + self.ast_variable = VariableNode(name=NameNode(value=self.name)) + + def to_ast_type( + self, type_: Union[GraphQLWrappingType, GraphQLNamedType] + ) -> TypeNode: + if is_wrapping_type(type_): + if isinstance(type_, GraphQLList): + return ListTypeNode(type=self.to_ast_type(type_.of_type)) + elif isinstance(type_, GraphQLNonNull): + return NonNullTypeNode(type=self.to_ast_type(type_.of_type)) + + type_ = assert_named_type(type_) + return NamedTypeNode(name=NameNode(value=type_.name)) + + def set_type( + self, type_: Union[GraphQLWrappingType, GraphQLNamedType] + ) -> "DSLVariable": + self.type = self.to_ast_type(type_) + return self + + +class DSLVariableDefinitions: + """The DSLVariableDefinitions represents variable definitions in a GraphQL operation + + Instances of this class have to be created and set as the `variable_definitions` + attribute of a DSLOperation instance + + Attributes of the DSLVariableDefinitions class are generated automatically + with the `__getattr__` dunder method in order to generate + instances of :class:`DSLVariable`, that can then be used as values in the + `DSLField.args` method + """ + + def __init__(self): + self.variables: Dict[str, DSLVariable] = {} + + def __getattr__(self, name: str) -> "DSLVariable": + if name not in self.variables: + self.variables[name] = DSLVariable(name) + return self.variables[name] + + def get_ast_definitions(self) -> List[VariableDefinitionNode]: + """ + :meta private: + + Return a list of VariableDefinitionNodes for each variable with a type + """ + return [ + VariableDefinitionNode( + type=var.type, variable=var.ast_variable, default_value=None, + ) + for var in self.variables.values() + if var.type is not None # only variables used + ] + + class DSLType: """The DSLType represents a GraphQL type for the DSL code. 
diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 95a92989..8fdaf426 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,8 +1,29 @@ import pytest -from graphql import print_ast +from graphql import ( + GraphQLInt, + GraphQLList, + GraphQLNonNull, + IntValueNode, + ListTypeNode, + NamedTypeNode, + NameNode, + NonNullTypeNode, + NullValueNode, + Undefined, + print_ast, +) from gql import Client -from gql.dsl import DSLMutation, DSLQuery, DSLSchema, DSLSubscription, dsl_gql +from gql.dsl import ( + DSLMutation, + DSLQuery, + DSLSchema, + DSLSubscription, + DSLVariable, + DSLVariableDefinitions, + ast_from_value, + dsl_gql, +) from .schema import StarWarsSchema @@ -17,6 +38,116 @@ def client(): return Client(schema=StarWarsSchema) +def test_ast_from_value_with_input_type_and_not_mapping_value(): + obj_type = StarWarsSchema.get_type("ReviewInput") + assert ast_from_value(8, obj_type) is None + + +def test_ast_from_value_with_list_type_and_non_iterable_value(): + assert ast_from_value(5, GraphQLList(GraphQLInt)) == IntValueNode(value="5") + + +def test_ast_from_value_with_none(): + assert ast_from_value(None, GraphQLInt) == NullValueNode() + + +def test_ast_from_value_with_undefined(): + assert ast_from_value(Undefined, GraphQLInt) is None + + +def test_ast_from_value_with_non_null_type_and_none(): + typ = GraphQLNonNull(GraphQLInt) + assert ast_from_value(None, typ) is None + + +def test_variable_to_ast_type_passing_wrapping_type(): + wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("Droid"))) + variable = DSLVariable("droids") + ast = variable.to_ast_type(wrapping_type) + assert ast == NonNullTypeNode( + type=ListTypeNode(type=NamedTypeNode(name=NameNode(value="Droid"))) + ) + + +def test_use_variable_definition_multiple_times(ds): + var = DSLVariableDefinitions() + + # `episode` variable is used in both fields + op = DSLMutation( + ds.Mutation.createReview.alias("badReview") + 
.args(review=var.badReview, episode=var.episode) + .select(ds.Review.stars, ds.Review.commentary), + ds.Mutation.createReview.alias("goodReview") + .args(review=var.goodReview, episode=var.episode) + .select(ds.Review.stars, ds.Review.commentary), + ) + op.variable_definitions = var + query = dsl_gql(op) + + assert ( + print_ast(query) + == """mutation ($badReview: ReviewInput, $episode: Episode, $goodReview: ReviewInput) { + badReview: createReview(review: $badReview, episode: $episode) { + stars + commentary + } + goodReview: createReview(review: $goodReview, episode: $episode) { + stars + commentary + } +} +""" + ) + + +def test_add_variable_definitions(ds): + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args(review=var.review, episode=var.episode).select( + ds.Review.stars, ds.Review.commentary + ) + ) + op.variable_definitions = var + query = dsl_gql(op) + + assert ( + print_ast(query) + == """mutation ($review: ReviewInput, $episode: Episode) { + createReview(review: $review, episode: $episode) { + stars + commentary + } +} +""" + ) + + +def test_add_variable_definitions_in_input_object(ds): + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args( + review={"stars": var.stars, "commentary": var.commentary}, + episode=var.episode, + ).select(ds.Review.stars, ds.Review.commentary) + ) + op.variable_definitions = var + query = dsl_gql(op) + + assert ( + print_ast(query) + == """mutation ($stars: Int, $commentary: String, $episode: Episode) { + createReview( + review: {stars: $stars, commentary: $commentary} + episode: $episode + ) { + stars + commentary + } +} +""" + ) + + def test_invalid_field_on_type_query(ds): with pytest.raises(AttributeError) as exc_info: ds.Query.extras.select(ds.Character.name) From cd00427216be36fc1649e0619c2c7eaded2b4e84 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 5 Jun 2021 23:16:07 +0200 Subject: [PATCH 010/239] Fix websockets 9 import (#217) --- 
tests/conftest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 44973ae1..62f107ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -125,7 +125,7 @@ def __init__(self, with_ssl: bool = False): async def start(self, handler): - from websockets.legacy import server + import websockets.server print("Starting server") @@ -149,7 +149,9 @@ async def start(self, handler): extra_serve_args["ssl"] = ssl_context # Start a server with a random open port - self.start_server = server.serve(handler, "127.0.0.1", 0, **extra_serve_args) + self.start_server = websockets.server.serve( + handler, "127.0.0.1", 0, **extra_serve_args + ) # Wait that the server is started self.server = await self.start_server From c98d8e266fd2fc3127b65862942b6597f72ae69b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 9 Jun 2021 22:08:42 +0200 Subject: [PATCH 011/239] Bump version number --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index a9b5bf3c..c28a7154 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.0.0a5" +__version__ = "3.0.0a6" From 4e08f09f278414477e5d7e458b866a76476bcbce Mon Sep 17 00:00:00 2001 From: Wilberto Morales Date: Tue, 10 Aug 2021 04:39:23 -0500 Subject: [PATCH 012/239] Fix variable_values spelling in docs (#226) --- docs/usage/file_upload.rst | 2 +- gql/transport/aiohttp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index d5f07c50..18718e75 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -92,7 +92,7 @@ See `Streaming uploads on aiohttp docs`_. 
In order to stream local files, instead of providing opened files to the -`variables_values` argument of `execute`, you need to provide an async generator +`variable_values` argument of `execute`, you need to provide an async generator which will provide parts of the files. You can use `aiofiles`_ diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 77a0c0c2..84679365 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -128,7 +128,7 @@ async def execute( :code:`execute` on a client or a session. :param document: the parsed GraphQL request - :param variables_values: An optional Dict of variable values + :param variable_values: An optional Dict of variable values :param operation_name: An optional Operation name for the request :param extra_args: additional arguments to send to the aiohttp post method :param upload_files: Set to True if you want to put files in the variable values From 66174a6aefe36d8736e624b6ad82a86b628c0dc4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 17 Aug 2021 01:38:45 +0200 Subject: [PATCH 013/239] Async generators always ensure that inner generator are closed properly (#230) Should close a generator directly after a break if you don't keep any references of the generator --- gql/client.py | 39 ++++++++++++++++------------ tests/test_websocket_subscription.py | 3 ++- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/gql/client.py b/gql/client.py index 30399eb8..e750c63c 100644 --- a/gql/client.py +++ b/gql/client.py @@ -356,13 +356,11 @@ async def _subscribe( # before a break if python version is too old (pypy3 py 3.6.1) self._generator = inner_generator - async for result in inner_generator: - if result.errors: - # Note: we need to run generator.aclose() here or the finally block in - # transport.subscribe will not be reached in pypy3 (py 3.6.1) - await inner_generator.aclose() - - yield result + try: + async for result in inner_generator: + yield result + finally: + await 
inner_generator.aclose() async def subscribe( self, document: DocumentNode, *args, **kwargs @@ -372,17 +370,24 @@ async def subscribe( The extra arguments are passed to the transport subscribe method.""" - # Validate and subscribe on the transport - async for result in self._subscribe(document, *args, **kwargs): - - # Raise an error if an error is returned in the ExecutionResult object - if result.errors: - raise TransportQueryError( - str(result.errors[0]), errors=result.errors, data=result.data - ) + inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( + document, *args, **kwargs + ) - elif result.data is not None: - yield result.data + try: + # Validate and subscribe on the transport + async for result in inner_generator: + + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError( + str(result.errors[0]), errors=result.errors, data=result.data + ) + + elif result.data is not None: + yield result.data + finally: + await inner_generator.aclose() async def _execute( self, document: DocumentNode, *args, **kwargs diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index fcd176b5..7d87ee81 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -163,7 +163,8 @@ async def test_websocket_subscription_break( if count <= 5: # Note: the following line is only necessary for pypy3 v3.6.1 - await session._generator.aclose() + if sys.version_info < (3, 7): + await session._generator.aclose() break count -= 1 From 20ae2e24b12e415e049264730df321bebd87290f Mon Sep 17 00:00:00 2001 From: Peter Zingg Date: Sun, 22 Aug 2021 09:40:43 -0700 Subject: [PATCH 014/239] Handle Absinthe unsubscriptions (#228) --- gql/transport/phoenix_channel_websockets.py | 233 ++++++++--- gql/transport/websockets.py | 29 +- tests/conftest.py | 9 +- tests/test_phoenix_channel_exceptions.py | 403 +++++++++++++++++--- 
tests/test_phoenix_channel_query.py | 103 ++++- tests/test_phoenix_channel_subscription.py | 306 ++++++++++++--- 6 files changed, 918 insertions(+), 165 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 27e58f2a..56d35f8b 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,5 +1,6 @@ import asyncio import json +import logging from typing import Any, Dict, Optional, Tuple from graphql import DocumentNode, ExecutionResult, print_ast @@ -12,6 +13,16 @@ ) from .websockets import WebsocketsTransport +log = logging.getLogger(__name__) + + +class Subscription: + """Records listener_id and unsubscribe query_id for a subscription.""" + + def __init__(self, query_id: int) -> None: + self.listener_id: int = query_id + self.unsubscribe_id: Optional[int] = None + class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport @@ -24,17 +35,23 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport): """ def __init__( - self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs + self, + channel_name: str = "__absinthe__:control", + heartbeat_interval: float = 30, + *args, + **kwargs, ) -> None: """Initialize the transport with the given parameters. - :param channel_name: Channel on the server this transport will join + :param channel_name: Channel on the server this transport will join. 
+ The default for Absinthe servers is "__absinthe__:control" :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client """ - self.channel_name = channel_name - self.heartbeat_interval = heartbeat_interval - self.subscription_ids_to_query_ids: Dict[str, int] = {} + self.channel_name: str = channel_name + self.heartbeat_interval: float = heartbeat_interval + self.heartbeat_task: Optional[asyncio.Future] = None + self.subscriptions: Dict[str, Subscription] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) async def _send_init_message_and_wait_ack(self) -> None: @@ -90,14 +107,32 @@ async def heartbeat_coro(): self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) async def _send_stop_message(self, query_id: int) -> None: - try: - await self.listeners[query_id].put(("complete", None)) - except KeyError: # pragma: no cover - pass + """Send an 'unsubscribe' message to the Phoenix Channel referencing + the listener's query_id, saving the query_id of the message. - async def _send_connection_terminate_message(self) -> None: - """Send a phx_leave message to disconnect from the provided channel. + The server should afterwards return a 'phx_reply' message with + the same query_id and subscription_id of the 'unsubscribe' request. 
""" + subscription_id = self._find_existing_subscription(query_id) + + unsubscribe_query_id = self.next_query_id + self.next_query_id += 1 + + # Save the ref so it can be matched in the reply + self.subscriptions[subscription_id].unsubscribe_id = unsubscribe_query_id + unsubscribe_message = json.dumps( + { + "topic": self.channel_name, + "event": "unsubscribe", + "payload": {"subscriptionId": subscription_id}, + "ref": unsubscribe_query_id, + } + ) + + await self._send(unsubscribe_message) + + async def _send_connection_terminate_message(self) -> None: + """Send a phx_leave message to disconnect from the provided channel.""" query_id = self.next_query_id self.next_query_id += 1 @@ -152,7 +187,7 @@ def _parse_answer( Returns a list consisting of: - the answer_type (between: - 'heartbeat', 'data', 'reply', 'error', 'close') + 'data', 'reply', 'complete', 'close') - the answer id (Integer) if received or None - an execution Result if the answer_type is 'data' or None """ @@ -161,56 +196,129 @@ def _parse_answer( answer_id: Optional[int] = None answer_type: str = "" execution_result: Optional[ExecutionResult] = None + subscription_id: Optional[str] = None + + def _get_value(d: Any, key: str, label: str) -> Any: + if not isinstance(d, dict): + raise ValueError(f"{label} is not a dict") + + return d.get(key) + + def _required_value(d: Any, key: str, label: str) -> Any: + value = _get_value(d, key, label) + if value is None: + raise ValueError(f"null {key} in {label}") + + return value + + def _required_subscription_id( + d: Any, label: str, must_exist: bool = False, must_not_exist=False + ) -> str: + subscription_id = str(_required_value(d, "subscriptionId", label)) + if must_exist and (subscription_id not in self.subscriptions): + raise ValueError("unregistered subscriptionId") + if must_not_exist and (subscription_id in self.subscriptions): + raise ValueError("previously registered subscriptionId") + + return subscription_id + + def _validate_data_response(d: Any, 
label: str) -> dict: + """Make sure query, mutation or subscription answer conforms. + The GraphQL spec says only three keys are permitted. + """ + if not isinstance(d, dict): + raise ValueError(f"{label} is not a dict") + + keys = set(d.keys()) + invalid = keys - {"data", "errors", "extensions"} + if len(invalid) > 0: + raise ValueError( + f"{label} contains invalid items: " + ", ".join(invalid) + ) + return d try: json_answer = json.loads(answer) - event = str(json_answer.get("event")) + event = str(_required_value(json_answer, "event", "answer")) if event == "subscription:data": - payload = json_answer.get("payload") + payload = _required_value(json_answer, "payload", "answer") - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - subscription_id = str(payload.get("subscriptionId")) - try: - answer_id = self.subscription_ids_to_query_ids[subscription_id] - except KeyError: - raise ValueError( - f"subscription '{subscription_id}' has not been registerd" - ) - - result = payload.get("result") + subscription_id = _required_subscription_id( + payload, "payload", must_exist=True + ) - if not isinstance(result, dict): - raise ValueError("result is not a dict") + result = _validate_data_response(payload.get("result"), "result") answer_type = "data" + subscription = self.subscriptions[subscription_id] + answer_id = subscription.listener_id + execution_result = ExecutionResult( - errors=payload.get("errors"), data=result.get("data"), - extensions=payload.get("extensions"), + errors=result.get("errors"), + extensions=result.get("extensions"), ) elif event == "phx_reply": - answer_id = int(json_answer.get("ref")) - payload = json_answer.get("payload") - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") + # Will generate a ValueError if 'ref' is not there + # or if it is not an integer + answer_id = int(_required_value(json_answer, "ref", "answer")) - status = str(payload.get("status")) + payload = 
_required_value(json_answer, "payload", "answer") - if status == "ok": + status = _get_value(payload, "status", "payload") + if status == "ok": answer_type = "reply" - response = payload.get("response") - if isinstance(response, dict) and "subscriptionId" in response: - subscription_id = str(response.get("subscriptionId")) - self.subscription_ids_to_query_ids[subscription_id] = answer_id + if answer_id in self.listeners: + response = _required_value(payload, "response", "payload") + + if isinstance(response, dict) and "subscriptionId" in response: + + # Subscription answer + subscription_id = _required_subscription_id( + response, "response", must_not_exist=True + ) + + self.subscriptions[subscription_id] = Subscription( + answer_id + ) + + else: + # Query or mutation answer + # GraphQL spec says only three keys are permitted + response = _validate_data_response(response, "response") + + answer_type = "data" + + execution_result = ExecutionResult( + data=response.get("data"), + errors=response.get("errors"), + extensions=response.get("extensions"), + ) + else: + ( + registered_subscription_id, + listener_id, + ) = self._find_subscription(answer_id) + if registered_subscription_id is not None: + # Unsubscription answer + response = _required_value(payload, "response", "payload") + subscription_id = _required_subscription_id( + response, "response" + ) + + if subscription_id != registered_subscription_id: + raise ValueError("subscription id does not match") + + answer_type = "complete" + + answer_id = listener_id elif status == "error": response = payload.get("response") @@ -224,21 +332,28 @@ def _parse_answer( raise TransportQueryError( str(response.get("reason")), query_id=answer_id ) - raise ValueError("reply error") + raise TransportQueryError("reply error", query_id=answer_id) elif status == "timeout": raise TransportQueryError("reply timeout", query_id=answer_id) + else: + # missing or unrecognized status, just continue + pass elif event == "phx_error": + # 
Sent if the channel has crashed + # answer_id will be the "join_ref" for the channel + # answer_id = int(json_answer.get("ref")) raise TransportServerError("Server error") elif event == "phx_close": answer_type = "close" else: - raise ValueError + raise ValueError("unrecognized event") except ValueError as e: + log.error(f"Error parsing answer '{answer}': {e!r}") raise TransportProtocolError( - "Server did not return a GraphQL result" + f"Server did not return a GraphQL result: {e!s}" ) from e return answer_type, answer_id, execution_result @@ -254,6 +369,38 @@ async def _handle_answer( else: await super()._handle_answer(answer_type, answer_id, execution_result) + def _remove_listener(self, query_id: int) -> None: + """If the listener was a subscription, remove that information.""" + try: + subscription_id = self._find_existing_subscription(query_id) + del self.subscriptions[subscription_id] + except Exception: + pass + super()._remove_listener(query_id) + + def _find_subscription(self, query_id: int) -> Tuple[Optional[str], int]: + """Perform a reverse lookup to find the subscription id matching + a listener's query_id. + """ + for subscription_id, subscription in self.subscriptions.items(): + if query_id == subscription.listener_id: + return subscription_id, query_id + if query_id == subscription.unsubscribe_id: + return subscription_id, subscription.listener_id + return None, query_id + + def _find_existing_subscription(self, query_id: int) -> str: + """Perform a reverse lookup to find the subscription id matching + a listener's query_id. 
+ """ + subscription_id, _listener_id = self._find_subscription(query_id) + + if subscription_id is None: + raise TransportProtocolError( + f"No subscription registered for listener {query_id}" + ) + return subscription_id + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: if self.heartbeat_task is not None: self.heartbeat_task.cancel() diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 7e26f31c..50eeb6b0 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -175,8 +175,9 @@ async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer """ - # We should always have an active websocket connection here - assert self.websocket is not None + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") # Wait for the next websocket frame. Can raise ConnectionClosed data: Data = await self.websocket.recv() @@ -387,6 +388,8 @@ async def _receive_data_loop(self) -> None: except (ConnectionClosed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break + except TransportClosed: + break # Parse the answer try: @@ -483,15 +486,14 @@ async def subscribe( break except (asyncio.CancelledError, GeneratorExit) as e: - log.debug("Exception in subscribe: " + repr(e)) + log.debug(f"Exception in subscribe: {e!r}") if listener.send_stop: await self._send_stop_message(query_id) listener.send_stop = False finally: - del self.listeners[query_id] - if len(self.listeners) == 0: - self._no_more_listeners.set() + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) async def execute( self, @@ -609,6 +611,19 @@ async def connect(self) -> None: log.debug("connect: done") + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last 
listener for the client. + """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + async def _clean_close(self, e: Exception) -> None: """Coroutine which will: @@ -627,7 +642,7 @@ async def _clean_close(self, e: Exception) -> None: try: await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) except asyncio.TimeoutError: # pragma: no cover - pass + log.debug("Timer close_timeout fired") # Finally send the 'connection_terminate' message await self._send_connection_terminate_message() diff --git a/tests/conftest.py b/tests/conftest.py index 62f107ac..df69c121 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,12 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests -for name in ["websockets.legacy.server", "gql.transport.websockets", "gql.dsl"]: +for name in [ + "websockets.legacy.server", + "gql.transport.websockets", + "gql.transport.phoenix_channel_websockets", + "gql.dsl", +]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) @@ -170,7 +175,7 @@ async def stop(self): self.server.close() try: - await asyncio.wait_for(self.server.wait_closed(), timeout=1) + await asyncio.wait_for(self.server.wait_closed(), timeout=5) except asyncio.TimeoutError: # pragma: no cover assert False, "Server failed to stop" diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 6f066325..1711d25a 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from gql import Client, gql @@ -7,9 +9,22 @@ TransportServerError, ) +from .conftest import MS + # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets + +def ensure_list(s): + return ( + s + if s 
is None or isinstance(s, list) + else list(s) + if isinstance(s, tuple) + else [s] + ) + + query1_str = """ query getContinents { continents { @@ -19,21 +34,48 @@ } """ -default_subscription_server_answer = ( +default_query_server_answer = ( '{"event":"phx_reply",' '"payload":' '{"response":' - '{"subscriptionId":"test_subscription"},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' '"status":"ok"},' '"ref":2,' '"topic":"test_topic"}' ) + +# other protocol exceptions + +reply_ref_null_answer = ( + '{"event":"phx_reply","payload":{}', + '"ref":null,' '"topic":"test_topic"}', +) + +reply_ref_zero_answer = ( + '{"event":"phx_reply","payload":{}', + '"ref":0,' '"topic":"test_topic"}', +) + + +# "status":"error" responses + +generic_error_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + error_with_reason_server_answer = ( '{"event":"phx_reply",' '"payload":' - '{"response":' - '{"reason":"internal error"},' + '{"response":{"reason":"internal error"},' '"status":"error"},' '"ref":2,' '"topic":"test_topic"}' @@ -42,8 +84,7 @@ multiple_errors_server_answer = ( '{"event":"phx_reply",' '"payload":' - '{"response":' - '{"errors": ["error 1", "error 2"]},' + '{"response":{"errors": ["error 1", "error 2"]},' '"status":"error"},' '"ref":2,' '"topic":"test_topic"}' @@ -57,31 +98,95 @@ '"topic":"test_topic"}' ) +invalid_payload_data_answer = ( + '{"event":"phx_reply",' '"payload":"INVALID",' '"ref":2,' '"topic":"test_topic"}' +) + +# "status":"ok" exceptions -def server( - query_server_answer, subscription_server_answer=default_subscription_server_answer, -): +invalid_response_server_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":"INVALID",' + '"status":"ok"}' + 
'"ref":2,' + '"topic":"test_topic"}' +) + +invalid_response_keys_server_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":' + '{"data":{"continents":null},"invalid":null}",' + '"status":"ok"}' + '"ref":2,' + '"topic":"test_topic"}' +) + +invalid_event_server_answer = '{"event":"unknown"}' + + +def query_server(server_answers=default_query_server_answer): from .conftest import PhoenixChannelServerHelper async def phoenix_server(ws, path): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() - await ws.send(subscription_server_answer) - if query_server_answer is not None: - await ws.send(query_server_answer) + for server_answer in ensure_list(server_answers): + await ws.send(server_answer) await PhoenixChannelServerHelper.send_close(ws) await ws.wait_closed() return phoenix_server +async def no_connection_ack_phoenix_server(ws, path): + from .conftest import PhoenixChannelServerHelper + + await ws.recv() + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + @pytest.mark.asyncio @pytest.mark.parametrize( "server", [ - server(error_with_reason_server_answer), - server(multiple_errors_server_answer), - server(timeout_server_answer), + query_server(reply_ref_null_answer), + query_server(reply_ref_zero_answer), + query_server(invalid_payload_data_answer), + query_server(invalid_response_server_answer), + query_server(invalid_response_keys_server_answer), + no_connection_ack_phoenix_server, + query_server(invalid_event_server_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query_protocol_error(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with 
pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + await session.execute(query) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + query_server(generic_error_server_answer), + query_server(error_with_reason_server_answer), + query_server(multiple_errors_server_answer), + query_server(timeout_server_answer), ], indirect=True, ) @@ -104,71 +209,207 @@ async def test_phoenix_channel_query_error(event_loop, server, query_str): await session.execute(query) -invalid_subscription_id_server_answer = ( +query2_str = """ + subscription getContinents { + continents { + code + name + } + } +""" + +default_subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +ref_is_not_an_integer_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":"not_an_integer",' + '"topic":"test_topic"}' +) + +missing_ref_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"topic":"test_topic"}' +) + +missing_subscription_id_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{},"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +null_subscription_id_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":{"subscriptionId":null},"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +default_subscription_data_answer = ( '{"event":"subscription:data","payload":' - '{"subscriptionId":"INVALID","result":' + '{"subscriptionId":"test_subscription","result":' '{"data":{"continents":[' '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' '{"code":"NA","name":"North 
America"},{"code":"OC","name":"Oceania"},' '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +default_subscription_unsubscribe_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' '"ref":3,' '"topic":"test_topic"}' ) -invalid_payload_server_answer = ( +missing_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +null_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":null,"result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_subscription_id_data_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"INVALID","result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_payload_data_answer = ( '{"event":"subscription:data",' '"payload":"INVALID",' - '"ref":3,' - '"topic":"test_topic"}' + '"ref":null,' + '"topic":"test_subscription"}' ) -invalid_result_server_answer = ( +invalid_result_data_answer = ( 
'{"event":"subscription:data","payload":' - '{"subscriptionId":"test_subscription","result": "INVALID"},' - '"ref":3,' - '"topic":"test_topic"}' + '{"subscriptionId":"test_subscription","result":"INVALID"},' + '"ref":null,' + '"topic":"test_subscription"}' ) -generic_error_server_answer = ( +invalid_result_keys_data_answer = ( + '{"event":"subscription:data",' + '"payload":{"subscriptionId":"test_subscription",' + '"result":{"data":{"continents":null},"invalid":null}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +invalid_subscription_ref_answer = ( '{"event":"phx_reply",' - '"payload":' - '{"status":"error"},' - '"ref":2,' + '"payload":{"response":{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":99,' '"topic":"test_topic"}' ) -protocol_server_answer = '{"event":"unknown"}' - -invalid_payload_subscription_server_answer = ( - '{"event":"phx_reply", "payload":"INVALID", "ref":2, "topic":"test_topic"}' +mismatched_unsubscribe_answer = ( + '{"event":"phx_reply",' + '"payload":{"response":{"subscriptionId":"no_such_subscription"},' + '"status":"ok"},' + '"ref":3,' + '"topic":"test_topic"}' ) -async def no_connection_ack_phoenix_server(ws, path): +def subscription_server( + server_answers=default_subscription_server_answer, + data_answers=default_subscription_data_answer, + unsubscribe_answers=default_subscription_unsubscribe_answer, +): from .conftest import PhoenixChannelServerHelper + import json - await ws.recv() - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() + async def phoenix_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + if server_answers is not None: + for server_answer in ensure_list(server_answers): + await ws.send(server_answer) + if data_answers is not None: + for data_answer in ensure_list(data_answers): + await ws.send(data_answer) + if unsubscribe_answers is not None: + result = await ws.recv() + json_result = json.loads(result) + assert 
json_result["event"] == "unsubscribe" + for unsubscribe_answer in ensure_list(unsubscribe_answers): + await ws.send(unsubscribe_answer) + else: + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + return phoenix_server @pytest.mark.asyncio @pytest.mark.parametrize( "server", [ - server(invalid_subscription_id_server_answer), - server(invalid_result_server_answer), - server(generic_error_server_answer), - no_connection_ack_phoenix_server, - server(protocol_server_answer), - server(invalid_payload_server_answer), - server(None, invalid_payload_subscription_server_answer), + subscription_server(invalid_subscription_ref_answer), + subscription_server(missing_subscription_id_server_answer), + subscription_server(null_subscription_id_server_answer), + subscription_server( + [default_subscription_server_answer, default_subscription_server_answer] + ), + subscription_server(data_answers=missing_subscription_id_data_answer), + subscription_server(data_answers=null_subscription_id_data_answer), + subscription_server(data_answers=invalid_subscription_id_data_answer), + subscription_server(data_answers=ref_is_not_an_integer_server_answer), + subscription_server(data_answers=missing_ref_server_answer), + subscription_server(data_answers=invalid_payload_data_answer), + subscription_server(data_answers=invalid_result_data_answer), + subscription_server(data_answers=invalid_result_keys_data_answer), ], indirect=True, ) -@pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_protocol_error(event_loop, server, query_str): +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_subscription_protocol_error( + event_loop, server, query_str +): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -183,17 +424,17 @@ async def test_phoenix_channel_protocol_error(event_loop, server, query_str): query = gql(query_str) with pytest.raises(TransportProtocolError): async 
with Client(transport=sample_transport) as session: - await session.execute(query) + async for _result in session.subscribe(query): + await asyncio.sleep(10 * MS) + break -server_error_subscription_server_answer = ( - '{"event":"phx_error", "ref":2, "topic":"test_topic"}' -) +server_error_server_answer = '{"event":"phx_error", "ref":2, "topic":"test_topic"}' @pytest.mark.asyncio @pytest.mark.parametrize( - "server", [server(None, server_error_subscription_server_answer)], indirect=True, + "server", [query_server(server_error_server_answer)], indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_phoenix_channel_server_error(event_loop, server, query_str): @@ -212,3 +453,65 @@ async def test_phoenix_channel_server_error(event_loop, server, query_str): with pytest.raises(TransportServerError): async with Client(transport=sample_transport) as session: await session.execute(query) + + +# These cannot be caught by the client +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + subscription_server(unsubscribe_answers=invalid_subscription_ref_answer), + subscription_server(unsubscribe_answers=mismatched_unsubscribe_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + # Reduce close_timeout. These tests will wait for an unsubscribe + # reply that will never come... 
+ sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, close_timeout=1 + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + async for _result in session.subscribe(query): + break + + +# We can force the error if somehow the generator is still running while +# we receive a mismatched unsubscribe answer +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [subscription_server(unsubscribe_answers=mismatched_unsubscribe_answer)], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_unsubscribe_error_forcing(event_loop, server, query_str): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, close_timeout=1 + ) + + query = gql(query_str) + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + async for _result in session.subscribe(query): + await session.transport._send_stop_message(2) + await asyncio.sleep(10 * MS) diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index c3679ac6..b13a8c55 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -14,6 +14,68 @@ } """ +default_query_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +@pytest.fixture +def ws_server_helper(request): + from .conftest import 
PhoenixChannelServerHelper + + yield PhoenixChannelServerHelper + + +async def query_server(ws, path): + from .conftest import PhoenixChannelServerHelper + + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(default_query_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [query_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query(event_loop, server, query_str): + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + result = await session.execute(query) + + print("Client received:", result) + + +query2_str = """ + subscription getContinents { + continents { + code + name + } + } +""" + subscription_server_answer = ( '{"event":"phx_reply",' '"payload":' @@ -24,7 +86,7 @@ '"topic":"test_topic"}' ) -query1_server_answer = ( +subscription_data_server_answer = ( '{"event":"subscription:data","payload":' '{"subscriptionId":"test_subscription","result":' '{"data":{"continents":[' @@ -32,33 +94,39 @@ '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' '{"code":"SA","name":"South America"}]}}},' + '"ref":null,' + '"topic":"test_subscription"}' +) + +unsubscribe_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' '"ref":3,' '"topic":"test_topic"}' ) -@pytest.fixture -def ws_server_helper(request): - from .conftest import PhoenixChannelServerHelper - - yield PhoenixChannelServerHelper - - -async def 
phoenix_server(ws, path): +async def subscription_server(ws, path): from .conftest import PhoenixChannelServerHelper await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(subscription_server_answer) - await ws.send(query1_server_answer) - await PhoenixChannelServerHelper.send_close(ws) + await ws.send(subscription_data_server_answer) + await ws.recv() + await ws.send(unsubscribe_server_answer) + # Unsubscribe will remove the listener + # await PhoenixChannelServerHelper.send_close(ws) await ws.wait_closed() @pytest.mark.asyncio -@pytest.mark.parametrize("server", [phoenix_server], indirect=True) -@pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_simple_query(event_loop, server, query_str): +@pytest.mark.parametrize("server", [subscription_server], indirect=True) +@pytest.mark.parametrize("query_str", [query2_str]) +async def test_phoenix_channel_subscription(event_loop, server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -69,8 +137,11 @@ async def test_phoenix_channel_simple_query(event_loop, server, query_str): channel_name="test_channel", url=url ) + first_result = None query = gql(query_str) async with Client(transport=sample_transport) as session: - result = await session.execute(query) + async for result in session.subscribe(query): + first_result = result + break - print("Client received:", result) + print("Client received:", first_result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index ef46db47..3c6ec2b2 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -1,5 +1,6 @@ import asyncio import json +import sys import pytest from parse import search @@ -9,26 +10,77 @@ # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets -subscription_server_answer = ( - '{"event":"phx_reply",' - 
'"payload":' - '{"response":' - '{"subscriptionId":"test_subscription"},' - '"status":"ok"},' - '"ref":2,' - '"topic":"test_topic"}' +test_channel = "test_channel" +test_subscription_id = "test_subscription" + +# A server should send this after receiving a 'phx_leave' request message. +# 'query_id' should be the value of the 'ref' in the 'phx_leave' request. +# With only one listener, the transport is closed automatically when +# it exits a subscription, so this is not used in current tests. +channel_leave_reply_template = ( + "{{" + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{}},' + '"status":"ok"' + "}}," + '"ref":{query_id}' + "}}" ) -countdown_server_answer = ( - '{{"event":"subscription:data",' - '"payload":{{"subscriptionId":"test_subscription","result":' - '{{"data":{{"number":{number}}}}}}},' - '"ref":{query_id}}}' +# A server should send this after sending the 'channel_leave_reply' +# above, to confirm to the client that the channel was actually closed. +# With only one listener, the transport is closed automatically when +# it exits a subscription, so this is not used in current tests. +channel_close_reply_template = ( + "{{" + '"topic":"{channel_name}",' + '"event":"phx_close",' + '"payload":{{}},' + '"ref":null' + "}}" +) + +# A server sends this when it receives a 'subscribe' request, +# after creating a unique subscription id. 'query_id' should be the +# value of the 'ref' in the 'subscribe' request. 
+subscription_reply_template = ( + "{{" + '"topic":"{channel_name}",' + '"event":"phx_reply",' + '"payload":{{' + '"response":{{' + '"subscriptionId":"{subscription_id}"' + "}}," + '"status":"ok"' + "}}," + '"ref":{query_id}' + "}}" +) + +countdown_data_template = ( + "{{" + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"countdown":{{' + '"number":{number}' + "}}" + "}}" + "}}" + "}}," + '"ref":null' + "}}" ) async def server_countdown(ws, path): import websockets + from .conftest import MS, PhoenixChannelServerHelper try: @@ -37,20 +89,29 @@ async def server_countdown(ws, path): result = await ws.recv() json_result = json.loads(result) assert json_result["event"] == "doc" - payload = json_result["payload"] - query = payload["query"] + channel_name = json_result["topic"] query_id = json_result["ref"] + payload = json_result["payload"] + query = payload["query"] count_found = search("count: {:d}", query) count = count_found[0] print(f"Countdown started from: {count}") - await ws.send(subscription_server_answer) + await ws.send( + subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) async def counting_coro(): for number in range(count, -1, -1): await ws.send( - countdown_server_answer.format(query_id=query_id, number=number) + countdown_data_template.format( + subscription_id=test_subscription_id, number=number + ) ) await asyncio.sleep(2 * MS) @@ -59,12 +120,23 @@ async def counting_coro(): async def stopping_coro(): nonlocal counting_task while True: - result = await ws.recv() json_result = json.loads(result) - if json_result["type"] == "stop" and json_result["id"] == str(query_id): - print("Cancelling counting task now") + if json_result["event"] == "unsubscribe": + query_id = json_result["ref"] + payload = json_result["payload"] + subscription_id = payload["subscriptionId"] + 
assert subscription_id == test_subscription_id + + print("Sending unsubscribe reply") + await ws.send( + subscription_reply_template.format( + subscription_id=subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) counting_task.cancel() stopping_task = asyncio.ensure_future(stopping_coro()) @@ -74,16 +146,17 @@ async def stopping_coro(): except asyncio.CancelledError: print("Now counting task is cancelled") - stopping_task.cancel() - + # Waiting for a clean stop try: - await stopping_task + await asyncio.wait_for(stopping_task, 3) except asyncio.CancelledError: print("Now stopping task is cancelled") + except asyncio.TimeoutError: + print("Now stopping task is in timeout") - await PhoenixChannelServerHelper.send_close(ws) + # await PhoenixChannelServerHelper.send_close(ws) except websockets.exceptions.ConnectionClosedOK: - pass + print("Connection closed") finally: await ws.wait_closed() @@ -100,15 +173,29 @@ async def stopping_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_phoenix_channel_subscription(event_loop, server, subscription_str): +@pytest.mark.parametrize("end_count", [0, 5]) +async def test_phoenix_channel_subscription( + event_loop, server, subscription_str, end_count +): + """Parameterized test. + + :param end_count: Target count at which the test will 'break' to unsubscribe. 
+ """ + import logging + from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) + from gql.transport.phoenix_channel_websockets import log as phoenix_logger + from gql.transport.websockets import log as websockets_logger + + websockets_logger.setLevel(logging.DEBUG) + phoenix_logger.setLevel(logging.DEBUG) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url + channel_name=test_channel, url=url, close_timeout=5 ) count = 10 @@ -116,39 +203,156 @@ async def test_phoenix_channel_subscription(event_loop, server, subscription_str async with Client(transport=sample_transport) as session: async for result in session.subscribe(subscription): - - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count + if number == end_count: + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + # In more recent versions, 'break' will trigger __aexit__. 
+ if sys.version_info < (3, 7): + await session._generator.aclose() + print("break") + break + count -= 1 - assert count == -1 + assert count == end_count + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_phoenix_channel_subscription_no_break( + event_loop, server, subscription_str +): + import logging + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + from gql.transport.phoenix_channel_websockets import log as phoenix_logger + from gql.transport.websockets import log as websockets_logger + + websockets_logger.setLevel(logging.DEBUG) + phoenix_logger.setLevel(logging.DEBUG) + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + async def testing_stopping_without_break(): + + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name=test_channel, url=url, close_timeout=5 + ) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with Client(transport=sample_transport) as session: + async for result in session.subscribe(subscription): + number = result["countdown"]["number"] + print(f"Number received: {number}") + + # Simulate a slow consumer + await asyncio.sleep(0.1) -heartbeat_server_answer = ( - '{{"event":"subscription:data",' - '"payload":{{"subscriptionId":"test_subscription","result":' - '{{"data":{{"heartbeat_count":{count}}}}}}},' - '"ref":1}}' + if number == 9: + # When we consume the number 9 here in the async generator, + # all the 10 numbers have already been sent by the backend and + # are present in the listener queue + # we simulate here an unsubscribe message + # In that case, all the 10 numbers should be consumed in the + # generator and then the generator should be closed properly + await session.transport._send_stop_message(2) + + assert number == count + + count -= 1 + + assert count == -1 + + 
try: + await asyncio.wait_for(testing_stopping_without_break(), timeout=5) + except asyncio.TimeoutError: + assert False, "The async generator did not stop" + + +heartbeat_data_template = ( + "{{" + '"topic":"{subscription_id}",' + '"event":"subscription:data",' + '"payload":{{' + '"subscriptionId":"{subscription_id}",' + '"result":{{' + '"data":{{' + '"heartbeat":{{' + '"heartbeat_count":{count}' + "}}" + "}}" + "}}" + "}}," + '"ref":null' + "}}" ) async def phoenix_heartbeat_server(ws, path): + import websockets + from .conftest import PhoenixChannelServerHelper - await PhoenixChannelServerHelper.send_connection_ack(ws) - await ws.recv() - await ws.send(subscription_server_answer) + try: + await PhoenixChannelServerHelper.send_connection_ack(ws) - for i in range(3): - heartbeat_result = await ws.recv() - json_result = json.loads(heartbeat_result) - assert json_result["event"] == "heartbeat" - await ws.send(heartbeat_server_answer.format(count=i)) + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "doc" + channel_name = json_result["topic"] + query_id = json_result["ref"] - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() + await ws.send( + subscription_reply_template.format( + subscription_id=test_subscription_id, + channel_name=channel_name, + query_id=query_id, + ) + ) + + async def heartbeat_coro(): + i = 0 + while True: + heartbeat_result = await ws.recv() + json_result = json.loads(heartbeat_result) + if json_result["event"] == "heartbeat": + await ws.send( + heartbeat_data_template.format( + subscription_id=test_subscription_id, count=i + ) + ) + i = i + 1 + elif json_result["event"] == "unsubscribe": + query_id = json_result["ref"] + payload = json_result["payload"] + subscription_id = payload["subscriptionId"] + assert subscription_id == test_subscription_id + + print("Sending unsubscribe reply") + await ws.send( + subscription_reply_template.format( + subscription_id=subscription_id, + 
channel_name=channel_name, + query_id=query_id, + ) + ) + + await asyncio.wait_for(heartbeat_coro(), 60) + # await PhoenixChannelServerHelper.send_close(ws) + except websockets.exceptions.ConnectionClosedOK: + print("Connection closed") + finally: + await ws.wait_closed() heartbeat_subscription_str = """ @@ -171,15 +375,23 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url, heartbeat_interval=1 + channel_name=test_channel, url=url, heartbeat_interval=0.1 ) subscription = gql(heartbeat_subscription_str) async with Client(transport=sample_transport) as session: i = 0 async for result in session.subscribe(subscription): - heartbeat_count = result["heartbeat_count"] + heartbeat_count = result["heartbeat"]["heartbeat_count"] print(f"Heartbeat count received: {heartbeat_count}") assert heartbeat_count == i + if heartbeat_count == 5: + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + # In more recent versions, 'break' will trigger __aexit__. 
+ if sys.version_info < (3, 7): + await session._generator.aclose() + break + i += 1 From b7d8150c444db97d7b2191c1585bd7a3461b0f9c Mon Sep 17 00:00:00 2001 From: "pony.ma" <38464007+ma-pony@users.noreply.github.com> Date: Mon, 23 Aug 2021 15:04:21 +0800 Subject: [PATCH 015/239] feat(transport/requests): RequestsHTTPTransport execute func add extra_args param (#232) --- gql/transport/requests.py | 6 ++++ tests/fixtures/vcr_cassettes/queries.yaml | 44 +++++++++++++++++++++++ tests/test_transport.py | 19 ++++++++++ 3 files changed, 69 insertions(+) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index d0bc1467..7f9ff26a 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -103,6 +103,7 @@ def execute( # type: ignore variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, timeout: Optional[int] = None, + extra_args: Dict[str, Any] = None, ) -> ExecutionResult: """Execute GraphQL query. @@ -114,6 +115,7 @@ def execute( # type: ignore :param operation_name: Name of the operation that shall be executed. Only required in multi-operation documents (Default: None). :param timeout: Specifies a default timeout for requests (Default: None). + :param extra_args: additional arguments to send to the requests post method :return: The result of execution. `data` is the result of executing the query, `errors` is null if no errors occurred, and is a non-empty array if an error occurred. 
@@ -146,6 +148,10 @@ def execute( # type: ignore # Pass kwargs to requests post method post_args.update(self.kwargs) + # Pass post_args to requests post method + if extra_args: + post_args.update(extra_args) + # Using the created session to perform requests response = self.session.request( self.method, self.url, **post_args # type: ignore diff --git a/tests/fixtures/vcr_cassettes/queries.yaml b/tests/fixtures/vcr_cassettes/queries.yaml index 526a5273..f3fa1c96 100644 --- a/tests/fixtures/vcr_cassettes/queries.yaml +++ b/tests/fixtures/vcr_cassettes/queries.yaml @@ -338,4 +338,48 @@ interactions: status: code: 200 message: OK +- request: + body: '{"query": "query Planet($id: ID!) {\n planet(id: $id) {\n id\n name\n }\n}\n"}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '86' + Content-Type: + - application/json + Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + User-Agent: + - python-requests/2.26.0 + authorization: + - xxx-123 + method: POST + uri: http://127.0.0.1:8000/graphql + response: + body: + string: '{"data":{"planet":{"id":"UGxhbmV0OjEx","name":"Geonosis"}}}' + headers: + Content-Length: + - '59' + Content-Type: + - application/json + Date: + - Fri, 06 Nov 2020 11:30:21 GMT + Server: + - WSGIServer/0.1 Python/2.7.18 + Set-Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + expires=Fri, 05-Nov-2021 11:30:21 GMT; Max-Age=31449600; Path=/ + Vary: + - Cookie + X-Frame-Options: + - SAMEORIGIN + status: + code: 200 + message: OK version: 1 diff --git a/tests/test_transport.py b/tests/test_transport.py index fa6a681a..d9a3eced 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -123,3 +123,22 @@ def test_named_query(client): with use_cassette("queries"): result = 
client.execute(query, operation_name="Planet2") assert result == expected + + +def test_header_query(client): + query = gql( + """ + query Planet($id: ID!) { + planet(id: $id) { + id + name + } + } + """ + ) + expected = {"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}} + with use_cassette("queries"): + result = client.execute( + query, extra_args={"headers": {"authorization": "xxx-123"}} + ) + assert result == expected From ca4021df2cd7ee9b5057ebddb7bf0408e45a1dba Mon Sep 17 00:00:00 2001 From: Michael Ossareh Date: Thu, 26 Aug 2021 04:26:54 -0500 Subject: [PATCH 016/239] Update typing hints of timeouts to allow passing floats (#234) --- gql/client.py | 3 ++- gql/transport/websockets.py | 34 ++++++++++++++++------------------ 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/gql/client.py b/gql/client.py index e750c63c..6017ab69 100644 --- a/gql/client.py +++ b/gql/client.py @@ -46,7 +46,7 @@ def __init__( type_def: Optional[str] = None, transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, - execute_timeout: Optional[int] = 10, + execute_timeout: Optional[Union[int, float]] = 10, ): """Initialize the client with the given parameters. @@ -57,6 +57,7 @@ def __init__( the schema from the transport using an introspection query :param execute_timeout: The maximum time in seconds for the execution of a request before a TimeoutError is raised. Only used for async transports. + Passing None results in waiting forever for a response. 
""" assert not ( type_def and introspection diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 50eeb6b0..4ec8ce89 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -92,10 +92,10 @@ def __init__( headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, init_payload: Dict[str, Any] = {}, - connect_timeout: int = 10, - close_timeout: int = 10, - ack_timeout: int = 10, - keep_alive_timeout: Optional[int] = None, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -105,10 +105,11 @@ def __init__( :param ssl: ssl_context of the connection. Use ssl=False to disable encryption :param init_payload: Dict of the payload sent in the connection_init message. :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. - :param close_timeout: Timeout in seconds for the close. + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. + from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. 
:param connect_args: Other parameters forwarded to websockets.connect @@ -118,10 +119,10 @@ def __init__( self.headers: Optional[HeadersLike] = headers self.init_payload: Dict[str, Any] = init_payload - self.connect_timeout: int = connect_timeout - self.close_timeout: int = close_timeout - self.ack_timeout: int = ack_timeout - self.keep_alive_timeout: Optional[int] = keep_alive_timeout + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout self.connect_args = connect_args @@ -156,8 +157,7 @@ def __init__( self.close_exception: Optional[Exception] = None async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message - """ + """Send the provided message to the websocket connection and log the message""" if not self.websocket: raise TransportClosed( @@ -172,8 +172,7 @@ async def _send(self, message: str) -> None: raise e async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer - """ + """Wait the next message from the websocket connection and log the answer""" # It is possible that the websocket has been already closed in another task if self.websocket is None: @@ -194,8 +193,7 @@ async def _receive(self) -> str: return answer async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored - """ + """Wait for the connection_ack message. 
Keep alive messages are ignored""" while True: init_answer = await self._receive() @@ -575,7 +573,7 @@ async def connect(self) -> None: # Set the _connecting flag to False after in all cases try: self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args,), + websockets.client.connect(self.url, **connect_args), self.connect_timeout, ) finally: From 7236bbdc7be51d36420b9c18853f16e96f321299 Mon Sep 17 00:00:00 2001 From: Jimmy Merrild Krag Date: Mon, 6 Sep 2021 10:34:44 +0200 Subject: [PATCH 017/239] DOC: Rephrase and remove redundant word (#237) --- docs/intro.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index 0032d138..e377c56e 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -26,10 +26,9 @@ Less dependencies ^^^^^^^^^^^^^^^^^ GQL supports multiple :ref:`transports ` to communicate with the backend. -Each transport can each necessitate specific dependencies. -If you only need one transport, instead of using the "`all`" extra dependency -as described above which installs everything, -you might want to install only the dependency needed for your transport. +Each transport can necessitate specific dependencies. +If you only need one transport you might want to install only the dependency needed for your transport, +instead of using the "`all`" extra dependency as described above, which installs everything. 
If for example you only need the :ref:`AIOHTTPTransport `, which needs the :code:`aiohttp` dependency, then you can install GQL with:: From bd96caaaa8ae1aa3fbc2bbc79b193bb2115f8cad Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 12 Sep 2021 19:29:52 +0200 Subject: [PATCH 018/239] DSL implementation of fragments (#235) --- docs/advanced/dsl_module.rst | 91 ++++- docs/code_examples/aiohttp_async_dsl.py | 2 +- docs/code_examples/requests_sync_dsl.py | 2 +- gql/cli.py | 6 +- gql/dsl.py | 428 ++++++++++++++++++------ tests/starwars/test_dsl.py | 164 ++++++++- 6 files changed, 584 insertions(+), 109 deletions(-) diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index 2ec544b7..afaa3bc6 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -164,7 +164,7 @@ Variable arguments To provide variables instead of argument values directly for an operation, you have to: -* Instanciate a :class:`DSLVariableDefinitions `:: +* Instantiate a :class:`DSLVariableDefinitions `:: var = DSLVariableDefinitions() @@ -252,6 +252,93 @@ It is possible to create an Document with multiple operations:: operation_name_3=DSLMutation( ... 
), ) +Fragments +^^^^^^^^^ + +To define a `Fragment`_, you have to: + +* Instantiate a :class:`DSLFragment ` with a name:: + + name_and_appearances = DSLFragment("NameAndAppearances") + +* Provide the GraphQL type of the fragment with the + :meth:`on ` method:: + + name_and_appearances.on(ds.Character) + +* Add children fields using the :meth:`select ` method:: + + name_and_appearances.select(ds.Character.name, ds.Character.appearsIn) + +Once your fragment is defined, to use it you should: + +* select it as a field somewhere in your query:: + + query_with_fragment = DSLQuery(ds.Query.hero.select(name_and_appearances)) + +* add it as an argument of :func:`dsl_gql ` with your query:: + + query = dsl_gql(name_and_appearances, query_with_fragment) + +The above example will generate the following request:: + + fragment NameAndAppearances on Character { + name + appearsIn + } + + { + hero { + ...NameAndAppearances + } + } + +Inline Fragments +^^^^^^^^^^^^^^^^ + +To define an `Inline Fragment`_, you have to: + +* Instantiate a :class:`DSLInlineFragment `:: + + human_fragment = DSLInlineFragment() + +* Provide the GraphQL type of the fragment with the + :meth:`on ` method:: + + human_fragment.on(ds.Human) + +* Add children fields using the :meth:`select ` method:: + + human_fragment.select(ds.Human.homePlanet) + +Once your inline fragment is defined, to use it you should: + +* select it as a field somewhere in your query:: + + query_with_inline_fragment = ds.Query.hero.args(episode=6).select( + ds.Character.name, + human_fragment + ) + +The above example will generate the following request:: + + hero(episode: JEDI) { + name + ... 
on Human { + homePlanet + } + } + +Note: because the :meth:`on ` and +:meth:`select ` methods return :code:`self`, +this can be written in a concise manner:: + + query_with_inline_fragment = ds.Query.hero.args(episode=6).select( + ds.Character.name, + DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet) + ) + + Executable examples ------------------- @@ -265,3 +352,5 @@ Sync example .. literalinclude:: ../code_examples/requests_sync_dsl.py +.. _Fragment: https://graphql.org/learn/queries/#fragments +.. _Inline Fragment: https://graphql.org/learn/queries/#inline-fragments diff --git a/docs/code_examples/aiohttp_async_dsl.py b/docs/code_examples/aiohttp_async_dsl.py index d558ef6d..958ea490 100644 --- a/docs/code_examples/aiohttp_async_dsl.py +++ b/docs/code_examples/aiohttp_async_dsl.py @@ -17,7 +17,7 @@ async def main(): # GQL will fetch the schema just after the establishment of the first session async with client as session: - # Instanciate the root of the DSL Schema as ds + # Instantiate the root of the DSL Schema as ds ds = DSLSchema(client.schema) # Create the query using dynamically generated attributes from ds diff --git a/docs/code_examples/requests_sync_dsl.py b/docs/code_examples/requests_sync_dsl.py index 23c40e18..925b9aa2 100644 --- a/docs/code_examples/requests_sync_dsl.py +++ b/docs/code_examples/requests_sync_dsl.py @@ -17,7 +17,7 @@ # We should have received the schema now that the session is established assert client.schema is not None - # Instanciate the root of the DSL Schema as ds + # Instantiate the root of the DSL Schema as ds ds = DSLSchema(client.schema) # Create the query using dynamically generated attributes from ds diff --git a/gql/cli.py b/gql/cli.py index f971859e..c75ad120 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -183,7 +183,7 @@ def get_execute_args(args: Namespace) -> Dict[str, Any]: def get_transport(args: Namespace) -> AsyncTransport: 
- """Instanciate a transport from the parsed command line arguments + """Instantiate a transport from the parsed command line arguments :param args: parsed command line arguments """ @@ -196,7 +196,7 @@ def get_transport(args: Namespace) -> AsyncTransport: # (headers) transport_args = get_transport_args(args) - # Instanciate transport depending on url scheme + # Instantiate transport depending on url scheme transport: AsyncTransport if scheme in ["ws", "wss"]: from gql.transport.websockets import WebsocketsTransport @@ -226,7 +226,7 @@ async def main(args: Namespace) -> int: logging.basicConfig(level=args.loglevel) try: - # Instanciate transport from command line arguments + # Instantiate transport from command line arguments transport = get_transport(args) # Get extra execute parameters from command line arguments diff --git a/gql/dsl.py b/gql/dsl.py index 6542d6a6..f3bd1fe2 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,11 +1,13 @@ import logging -from abc import ABC +from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from graphql import ( ArgumentNode, DocumentNode, FieldNode, + FragmentDefinitionNode, + FragmentSpreadNode, GraphQLArgument, GraphQLField, GraphQLInputObjectType, @@ -17,6 +19,7 @@ GraphQLObjectType, GraphQLSchema, GraphQLWrappingType, + InlineFragmentNode, ListTypeNode, ListValueNode, NamedTypeNode, @@ -109,10 +112,10 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: def dsl_gql( - *operations: "DSLOperation", **operations_with_name: "DSLOperation" + *operations: "DSLExecutable", **operations_with_name: "DSLExecutable" ) -> DocumentNode: - r"""Given arguments instances of :class:`DSLOperation` - containing GraphQL operations, + r"""Given arguments instances of :class:`DSLExecutable` + containing GraphQL operations or fragments, generate a Document which can be executed later in a gql client or a gql session. 
@@ -122,21 +125,22 @@ def dsl_gql( by instances of :class:`DSLType` which themselves originated from a :class:`DSLSchema` class. - :param \*operations: the GraphQL operations - :type \*operations: DSLOperation (DSLQuery, DSLMutation, DSLSubscription) + :param \*operations: the GraphQL operations and fragments + :type \*operations: DSLQuery, DSLMutation, DSLSubscription, DSLFragment :param \**operations_with_name: the GraphQL operations with an operation name - :type \**operations_with_name: DSLOperation (DSLQuery, DSLMutation, DSLSubscription) + :type \**operations_with_name: DSLQuery, DSLMutation, DSLSubscription :return: a Document which can be later executed or subscribed by a :class:`Client `, by an :class:`async session ` or by a :class:`sync session ` - :raises TypeError: if an argument is not an instance of :class:`DSLOperation` + :raises TypeError: if an argument is not an instance of :class:`DSLExecutable` + :raises AttributeError: if a type has not been provided in a :class:`DSLFragment` """ # Concatenate operations without and with name - all_operations: Tuple["DSLOperation", ...] = ( + all_operations: Tuple["DSLExecutable", ...] = ( *operations, *(operation for operation in operations_with_name.values()), ) @@ -147,25 +151,15 @@ def dsl_gql( # Check the type for operation in all_operations: - if not isinstance(operation, DSLOperation): + if not isinstance(operation, DSLExecutable): raise TypeError( - "Operations should be instances of DSLOperation " - "(DSLQuery, DSLMutation or DSLSubscription).\n" + "Operations should be instances of DSLExecutable " + "(DSLQuery, DSLMutation, DSLSubscription or DSLFragment).\n" f"Received: {type(operation)}." 
) return DocumentNode( - definitions=[ - OperationDefinitionNode( - operation=OperationType(operation.operation_type), - selection_set=operation.selection_set, - variable_definitions=FrozenList( - operation.variable_definitions.get_ast_definitions() - ), - **({"name": NameNode(value=operation.name)} if operation.name else {}), - ) - for operation in all_operations - ] + definitions=[operation.executable_ast for operation in all_operations] ) @@ -201,26 +195,33 @@ def __getattr__(self, name: str) -> "DSLType": if type_def is None: raise AttributeError(f"Type '{name}' not found in the schema!") - assert isinstance(type_def, GraphQLObjectType) or isinstance( - type_def, GraphQLInterfaceType - ) + assert isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)) return DSLType(type_def) -class DSLOperation(ABC): - """Interface for GraphQL operations. +class DSLExecutable(ABC): + """Interface for the root elements which can be executed + in the :func:`dsl_gql ` function Inherited by - :class:`DSLQuery `, - :class:`DSLMutation ` and - :class:`DSLSubscription ` + :class:`DSLOperation ` and + :class:`DSLFragment ` """ - operation_type: OperationType + variable_definitions: "DSLVariableDefinitions" + name: Optional[str] + + @property + @abstractmethod + def executable_ast(self): + """Generates the ast for :func:`dsl_gql `.""" + raise NotImplementedError( + "Any DSLExecutable subclass must have executable_ast property" + ) # pragma: no cover def __init__( - self, *fields: "DSLField", **fields_with_alias: "DSLField", + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", ): r"""Given arguments of type :class:`DSLField` containing GraphQL requests, generate an operation which can be converted to a Document @@ -240,11 +241,11 @@ def __init__( to the operation type """ - self.name: Optional[str] = None - self.variable_definitions: DSLVariableDefinitions = DSLVariableDefinitions() + self.name = None + self.variable_definitions = 
DSLVariableDefinitions() # Concatenate fields without and with alias - all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields( + all_fields: Tuple["DSLSelectable", ...] = DSLField.get_aliased_fields( fields, fields_with_alias ) @@ -258,13 +259,39 @@ def __init__( f"Received type: {type(field)}" ) ) - assert field.type_name.upper() == self.operation_type.name, ( - f"Invalid root field for operation {self.operation_type.name}.\n" - f"Received: {field.type_name}" - ) + if isinstance(self, DSLOperation): + assert field.type_name.upper() == self.operation_type.name, ( + f"Invalid root field for operation {self.operation_type.name}.\n" + f"Received: {field.type_name}" + ) + + self.selection_set = SelectionSetNode( + selections=FrozenList(DSLSelectable.get_ast_fields(all_fields)) + ) + + +class DSLOperation(DSLExecutable): + """Interface for GraphQL operations. + + Inherited by + :class:`DSLQuery `, + :class:`DSLMutation ` and + :class:`DSLSubscription ` + """ + + operation_type: OperationType - self.selection_set: SelectionSetNode = SelectionSetNode( - selections=FrozenList(DSLField.get_ast_fields(all_fields)) + @property + def executable_ast(self) -> OperationDefinitionNode: + """Generates the ast for :func:`dsl_gql `.""" + + return OperationDefinitionNode( + operation=OperationType(self.operation_type), + selection_set=self.selection_set, + variable_definitions=FrozenList( + self.variable_definitions.get_ast_definitions() + ), + **({"name": NameNode(value=self.name)} if self.name else {}), ) @@ -396,42 +423,23 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" -class DSLField: - """The DSLField represents a GraphQL field for the DSL code. - - Instances of this class are generated for you automatically as attributes - of the :class:`DSLType` +class DSLSelectable(ABC): + """DSLSelectable is an abstract class which indicates that + the subclasses can be used as arguments of the + :meth:`select ` method. 
- If this field contains children fields, then you need to select which ones - you want in the request using the :meth:`select ` - method. + Inherited by + :class:`DSLField `, + :class:`DSLFragment ` + :class:`DSLInlineFragment ` """ - def __init__( - self, - name: str, - graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], - graphql_field: GraphQLField, - ): - """Initialize the DSLField. - - .. warning:: - Don't instantiate this class yourself. - Use attributes of the :class:`DSLType` instead. - - :param name: the name of the field - :param graphql_type: the GraphQL type definition from the schema - :param graphql_field: the GraphQL field definition from the schema - """ - self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type - self.field: GraphQLField = graphql_field - self.ast_field: FieldNode = FieldNode( - name=NameNode(value=name), arguments=FrozenList() - ) - log.debug(f"Creating {self!r}") + ast_field: Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] @staticmethod - def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: + def get_ast_fields( + fields: Iterable["DSLSelectable"], + ) -> List[Union[FieldNode, InlineFragmentNode, FragmentSpreadNode]]: """ :meta private: @@ -439,11 +447,11 @@ def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: But with a type check for each field in the list. :raises TypeError: if any of the provided fields are not instances - of the :class:`DSLField` class. + of the :class:`DSLSelectable` class. 
""" ast_fields = [] for field in fields: - if isinstance(field, DSLField): + if isinstance(field, DSLSelectable): ast_fields.append(field.ast_field) else: raise TypeError(f'Received incompatible field: "{field}".') @@ -452,8 +460,9 @@ def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: @staticmethod def get_aliased_fields( - fields: Iterable["DSLField"], fields_with_alias: Dict[str, "DSLField"] - ) -> Tuple["DSLField", ...]: + fields: Iterable["DSLSelectable"], + fields_with_alias: Dict[str, "DSLSelectableWithAlias"], + ) -> Tuple["DSLSelectable", ...]: """ :meta private: @@ -467,9 +476,29 @@ def get_aliased_fields( *(field.alias(alias) for alias, field in fields_with_alias.items()), ) + def __str__(self) -> str: + return print_ast(self.ast_field) + + +class DSLSelector(ABC): + """DSLSelector is an abstract class which defines the + :meth:`select ` method to select + children fields in the query. + + Inherited by + :class:`DSLField `, + :class:`DSLFragment `, + :class:`DSLInlineFragment ` + """ + + selection_set: SelectionSetNode + + def __init__(self): + self.selection_set = SelectionSetNode(selections=FrozenList([])) + def select( - self, *fields: "DSLField", **fields_with_alias: "DSLField" - ) -> "DSLField": + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" + ) -> "DSLSelector": r"""Select the new children fields that we want to receive in the request. @@ -477,46 +506,48 @@ def select( to the existing children fields. :param \*fields: new children fields - :type \*fields: DSLField + :type \*fields: DSLSelectable (DSLField, DSLFragment or DSLInlineFragment) :param \**fields_with_alias: new children fields with alias as key :type \**fields_with_alias: DSLField :return: itself :raises TypeError: if any of the provided fields are not instances - of the :class:`DSLField` class. + of the :class:`DSLSelectable` class. """ # Concatenate fields without and with alias - added_fields: Tuple["DSLField", ...] 
= self.get_aliased_fields( + added_fields: Tuple["DSLSelectable", ...] = DSLSelectable.get_aliased_fields( fields, fields_with_alias ) - added_selections: List[FieldNode] = self.get_ast_fields(added_fields) + # Get a list of AST Nodes for each added field + added_selections: List[ + Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] + ] = DSLSelectable.get_ast_fields(added_fields) - current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set - - if current_selection_set is None: - self.ast_field.selection_set = SelectionSetNode( - selections=FrozenList(added_selections) - ) - else: - current_selection_set.selections = FrozenList( - current_selection_set.selections + added_selections - ) + # Update the current selection list with new selections + self.selection_set.selections = FrozenList( + self.selection_set.selections + added_selections + ) - log.debug(f"Added fields: {fields} in {self!r}") + log.debug(f"Added fields: {added_fields} in {self!r}") return self - def __call__(self, **kwargs) -> "DSLField": - return self.args(**kwargs) - def alias(self, alias: str) -> "DSLField": +class DSLSelectableWithAlias(DSLSelectable): + """DSLSelectableWithAlias is an abstract class which indicates that + the subclasses can be selected with an alias. + """ + + ast_field: FieldNode + + def alias(self, alias: str) -> "DSLSelectableWithAlias": """Set an alias .. note:: You can also pass the alias directly at the - :meth:`select ` method. + :meth:`select ` method. :code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to: :code:`ds.Query.human.select(ds.Character.name.alias("my_name"))` @@ -528,6 +559,47 @@ def alias(self, alias: str) -> "DSLField": self.ast_field.alias = NameNode(value=alias) return self + +class DSLField(DSLSelectableWithAlias, DSLSelector): + """The DSLField represents a GraphQL field for the DSL code. 
+ + Instances of this class are generated for you automatically as attributes + of the :class:`DSLType` + + If this field contains children fields, then you need to select which ones + you want in the request using the :meth:`select ` + method. + """ + + _type: Union[GraphQLObjectType, GraphQLInterfaceType] + ast_field: FieldNode + field: GraphQLField + + def __init__( + self, + name: str, + graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], + graphql_field: GraphQLField, + ): + """Initialize the DSLField. + + .. warning:: + Don't instantiate this class yourself. + Use attributes of the :class:`DSLType` instead. + + :param name: the name of the field + :param graphql_type: the GraphQL type definition from the schema + :param graphql_field: the GraphQL field definition from the schema + """ + DSLSelector.__init__(self) + self._type = graphql_type + self.field = graphql_field + self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) + log.debug(f"Creating {self!r}") + + def __call__(self, **kwargs) -> "DSLField": + return self.args(**kwargs) + def args(self, **kwargs) -> "DSLField": r"""Set the arguments of a field @@ -576,16 +648,170 @@ def _get_argument(self, name: str) -> GraphQLArgument: return arg + def select( + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" + ) -> "DSLField": + """Calling :meth:`select ` method with + corrected typing hints + """ + + super().select(*fields, **fields_with_alias) + self.ast_field.selection_set = self.selection_set + + return self + @property def type_name(self): """:meta private:""" return self._type.name - def __str__(self) -> str: - return print_ast(self.ast_field) - def __repr__(self) -> str: return ( f"<{self.__class__.__name__} {self._type.name}" f"::{self.ast_field.name.value}>" ) + + +class DSLInlineFragment(DSLSelectable, DSLSelector): + """DSLInlineFragment represents an inline fragment for the DSL code.""" + + _type: Union[GraphQLObjectType, 
GraphQLInterfaceType] + ast_field: InlineFragmentNode + + def __init__( + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", + ): + r"""Initialize the DSLInlineFragment. + + :param \*fields: new children fields + :type \*fields: DSLSelectable (DSLField, DSLFragment or DSLInlineFragment) + :param \**fields_with_alias: new children fields with alias as key + :type \**fields_with_alias: DSLField + """ + + DSLSelector.__init__(self) + self.ast_field = InlineFragmentNode() + self.select(*fields, **fields_with_alias) + log.debug(f"Creating {self!r}") + + def select( + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" + ) -> "DSLInlineFragment": + """Calling :meth:`select ` method with + corrected typing hints + """ + super().select(*fields, **fields_with_alias) + self.ast_field.selection_set = self.selection_set + + return self + + def on(self, type_condition: DSLType) -> "DSLInlineFragment": + """Provides the GraphQL type of this inline fragment.""" + + self._type = type_condition._type + self.ast_field.type_condition = NamedTypeNode( + name=NameNode(value=self._type.name) + ) + return self + + def __repr__(self) -> str: + type_info = "" + + try: + type_info += f" on {self._type.name}" + except AttributeError: + pass + + return f"<{self.__class__.__name__}{type_info}>" + + +class DSLFragment(DSLSelectable, DSLSelector, DSLExecutable): + """DSLFragment represents a named GraphQL fragment for the DSL code.""" + + _type: Optional[Union[GraphQLObjectType, GraphQLInterfaceType]] + ast_field: FragmentSpreadNode + name: str + + def __init__( + self, + name: str, + *fields: "DSLSelectable", + **fields_with_alias: "DSLSelectableWithAlias", + ): + r"""Initialize the DSLFragment. 
+ + :param name: the name of the fragment + :type name: str + :param \*fields: new children fields + :type \*fields: DSLSelectable (DSLField, DSLFragment or DSLInlineFragment) + :param \**fields_with_alias: new children fields with alias as key + :type \**fields_with_alias: DSLField + """ + + DSLSelector.__init__(self) + DSLExecutable.__init__(self, *fields, **fields_with_alias) + + self.name = name + self._type = None + + log.debug(f"Creating {self!r}") + + @property # type: ignore + def ast_field(self) -> FragmentSpreadNode: # type: ignore + """ast_field property will generate a FragmentSpreadNode with the + provided name. + + Note: We need to ignore the type because of + `issue #4125 of mypy `_. + """ + + spread_node = FragmentSpreadNode() + spread_node.name = NameNode(value=self.name) + + return spread_node + + def select( + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" + ) -> "DSLFragment": + """Calling :meth:`select ` method with + corrected typing hints + """ + super().select(*fields, **fields_with_alias) + + return self + + def on(self, type_condition: DSLType) -> "DSLFragment": + """Provides the GraphQL type of this fragment. + + :param type_condition: the provided type + :type type_condition: DSLType + """ + + self._type = type_condition._type + + return self + + @property + def executable_ast(self) -> FragmentDefinitionNode: + """Generates the ast for :func:`dsl_gql `. + + :raises AttributeError: if a type has not been provided + """ + assert self.name is not None + + if self._type is None: + raise AttributeError( + "Missing type condition. 
Please use .on(type_condition) method" + ) + + return FragmentDefinitionNode( + type_condition=NamedTypeNode(name=NameNode(value=self._type.name)), + selection_set=self.selection_set, + variable_definitions=FrozenList( + self.variable_definitions.get_ast_definitions() + ), + name=NameNode(value=self.name), + ) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.name!s}>" diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 8fdaf426..93de6c03 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -15,6 +15,8 @@ from gql import Client from gql.dsl import ( + DSLFragment, + DSLInlineFragment, DSLMutation, DSLQuery, DSLSchema, @@ -187,6 +189,14 @@ def test_hero_name_and_friends_query(ds): ) assert query == str(query_dsl) + # Should also work with a chain of selects + query_dsl = ( + ds.Query.hero.select(ds.Character.id) + .select(ds.Character.name) + .select(ds.Character.friends.select(ds.Character.name,),) + ) + assert query == str(query_dsl) + def test_hero_id_and_name(ds): query = """ @@ -244,6 +254,10 @@ def test_fetch_luke_aliased(ds): query_dsl = ds.Query.human.args(id=1000).alias("luke").select(ds.Character.name,) assert query == str(query_dsl) + # Should also work with select before alias + query_dsl = ds.Query.human.args(id=1000).select(ds.Character.name,).alias("luke") + assert query == str(query_dsl) + def test_fetch_name_aliased(ds: DSLSchema): query = """ @@ -416,6 +430,152 @@ def test_multiple_operations(ds): ) +def test_inline_fragments(ds): + query = """hero(episode: JEDI) { + name + ... on Droid { + primaryFunction + } + ... 
on Human {
+    homePlanet
+  }
+}"""
+    query_dsl = ds.Query.hero.args(episode=6).select(
+        ds.Character.name,
+        DSLInlineFragment().on(ds.Droid).select(ds.Droid.primaryFunction),
+        DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet),
+    )
+    assert query == str(query_dsl)
+
+
+def test_fragments_repr(ds):
+
+    assert repr(DSLInlineFragment()) == "<DSLInlineFragment>"
+    assert repr(DSLInlineFragment().on(ds.Droid)) == "<DSLInlineFragment on Droid>"
+    assert repr(DSLFragment("fragment_1")) == "<DSLFragment fragment_1>"
+    assert repr(DSLFragment("fragment_2").on(ds.Droid)) == "<DSLFragment fragment_2>"
+
+
+def test_fragments(ds):
+    query = """fragment NameAndAppearances on Character {
+  name
+  appearsIn
+}
+
+{
+  hero {
+    ...NameAndAppearances
+  }
+}
+"""
+
+    name_and_appearances = (
+        DSLFragment("NameAndAppearances")
+        .on(ds.Character)
+        .select(ds.Character.name, ds.Character.appearsIn)
+    )
+
+    query_dsl = DSLQuery(ds.Query.hero.select(name_and_appearances))
+
+    document = dsl_gql(name_and_appearances, query_dsl)
+
+    print(print_ast(document))
+
+    assert query == print_ast(document)
+
+
+def test_fragment_without_type_condition_error(ds):
+
+    # We create a fragment without using the .on(type_condition) method
+    name_and_appearances = DSLFragment("NameAndAppearances").select(
+        ds.Character.name, ds.Character.appearsIn
+    )
+
+    # If we try to use this fragment, gql generates an error
+    with pytest.raises(
+        AttributeError,
+        match=r"Missing type condition. 
Please use .on\(type_condition\) method", + ): + dsl_gql(name_and_appearances) + + +def test_fragment_with_name_changed(ds): + + fragment = DSLFragment("ABC") + + assert str(fragment) == "...ABC" + + fragment.name = "DEF" + + assert str(fragment) == "...DEF" + + +def test_dsl_nested_query_with_fragment(ds): + query = """fragment NameAndAppearances on Character { + name + appearsIn +} + +query NestedQueryWithFragment { + hero { + ...NameAndAppearances + friends { + ...NameAndAppearances + friends { + ...NameAndAppearances + } + } + } +} +""" + + name_and_appearances = ( + DSLFragment("NameAndAppearances") + .on(ds.Character) + .select(ds.Character.name, ds.Character.appearsIn) + ) + + query_dsl = DSLQuery( + ds.Query.hero.select( + name_and_appearances, + ds.Character.friends.select( + name_and_appearances, ds.Character.friends.select(name_and_appearances) + ), + ) + ) + + document = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + + print(print_ast(document)) + + assert query == print_ast(document) + + # Same thing, but incrementaly + + name_and_appearances = DSLFragment("NameAndAppearances") + name_and_appearances.on(ds.Character) + name_and_appearances.select(ds.Character.name) + name_and_appearances.select(ds.Character.appearsIn) + + level_2 = ds.Character.friends + level_2.select(name_and_appearances) + level_1 = ds.Character.friends + level_1.select(name_and_appearances) + level_1.select(level_2) + + hero = ds.Query.hero + hero.select(name_and_appearances) + hero.select(level_1) + + query_dsl = DSLQuery(hero) + + document = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + + print(print_ast(document)) + + assert query == print_ast(document) + + def test_dsl_query_all_fields_should_be_instances_of_DSLField(): with pytest.raises( TypeError, match="fields must be instances of DSLField. 
Received type:" @@ -432,9 +592,9 @@ def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): ) -def test_dsl_gql_all_arguments_should_be_operations(): +def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( - TypeError, match="Operations should be instances of DSLOperation " + TypeError, match="Operations should be instances of DSLExecutable " ): dsl_gql("I am a string") From af8f22362d70a6989a703bca5c941b06065d69c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kucmus?= Date: Thu, 16 Sep 2021 16:39:32 +0200 Subject: [PATCH 019/239] AIOHTTP Filename upload (#241) * add the option to specify a filename on AIOHTTPTransport multipart upload --- CONTRIBUTING.md | 2 +- gql/transport/aiohttp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e4df615a..70750f8e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,7 +31,7 @@ virtualenv gql-dev Activate the virtualenv and install dependencies by running: ```console -python pip install -e.[dev] +python -m pip install -e.[dev] ``` If you are using Linux or MacOS, you can make use of Makefile command diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 84679365..04c85edf 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -185,7 +185,7 @@ async def execute( # Add the extracted files as remaining fields for k, v in file_streams.items(): - data.add_field(k, v, filename=k) + data.add_field(k, v, filename=getattr(v, "name", k)) post_args: Dict[str, Any] = {"data": data} From 3db7a8610ab0a453c059c372f3076bfdf1edc401 Mon Sep 17 00:00:00 2001 From: DENKweit GmbH <82954629+DENKweit@users.noreply.github.com> Date: Wed, 6 Oct 2021 14:47:20 +0200 Subject: [PATCH 020/239] Add upload_files functionality for requests transport (#244) --- docs/usage/file_upload.rst | 6 + gql/transport/requests.py | 73 ++++++- setup.py | 1 + tests/test_requests.py | 412 
+++++++++++++++++++++++++++++++++++++ 4 files changed, 487 insertions(+), 5 deletions(-) diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index 18718e75..cfc85df9 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -2,6 +2,7 @@ File uploads ============ GQL supports file uploads with the :ref:`aiohttp transport ` +and the :ref:`requests transport ` using the `GraphQL multipart request spec`_. .. _GraphQL multipart request spec: https://github.com/jaydenseric/graphql-multipart-request-spec @@ -18,6 +19,7 @@ In order to upload a single file, you need to: .. code-block:: python transport = AIOHTTPTransport(url='YOUR_URL') + # Or transport = RequestsHTTPTransport(url='YOUR_URL') client = Client(transport=transport) @@ -45,6 +47,7 @@ It is also possible to upload multiple files using a list. .. code-block:: python transport = AIOHTTPTransport(url='YOUR_URL') + # Or transport = RequestsHTTPTransport(url='YOUR_URL') client = Client(transport=transport) @@ -84,6 +87,9 @@ We provide methods to do that for two different uses cases: * Sending local files * Streaming downloaded files from an external URL to the GraphQL API +.. 
note:: + Streaming is only supported with the :ref:`aiohttp transport ` + Streaming local files ^^^^^^^^^^^^^^^^^^^^^ diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 7f9ff26a..68b4144b 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,15 +1,18 @@ +import io import json import logging -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Type, Union import requests from graphql import DocumentNode, ExecutionResult, print_ast from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar +from requests_toolbelt.multipart.encoder import MultipartEncoder from gql.transport import Transport +from ..utils import extract_files from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -27,6 +30,8 @@ class RequestsHTTPTransport(Transport): The transport uses the requests library to send HTTP POST requests. """ + file_classes: Tuple[Type[Any], ...] = (io.IOBase,) + def __init__( self, url: str, @@ -104,6 +109,7 @@ def execute( # type: ignore operation_name: Optional[str] = None, timeout: Optional[int] = None, extra_args: Dict[str, Any] = None, + upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. @@ -116,6 +122,7 @@ def execute( # type: ignore Only required in multi-operation documents (Default: None). :param timeout: Specifies a default timeout for requests (Default: None). :param extra_args: additional arguments to send to the requests post method + :param upload_files: Set to True if you want to put files in the variable values :return: The result of execution. `data` is the result of executing the query, `errors` is null if no errors occurred, and is a non-empty array if an error occurred. 
@@ -126,21 +133,77 @@ def execute( # type: ignore query_str = print_ast(document) payload: Dict[str, Any] = {"query": query_str} - if variable_values: - payload["variables"] = variable_values + if operation_name: payload["operationName"] = operation_name - data_key = "json" if self.use_json else "data" post_args = { "headers": self.headers, "auth": self.auth, "cookies": self.cookies, "timeout": timeout or self.default_timeout, "verify": self.verify, - data_key: payload, } + if upload_files: + # If the upload_files flag is set, then we need variable_values + assert variable_values is not None + + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values + nulled_variable_values, files = extract_files( + variables=variable_values, file_classes=self.file_classes, + ) + + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values + + # Add the payload to the operations field + operations_str = json.dumps(payload) + log.debug("operations %s", operations_str) + + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. 
+ # Will generate something like {"0": ["variables.file"]} + file_map = {str(i): [path] for i, path in enumerate(files)} + + # Enumerate the file streams + # Will generate something like {'0': <_io.BufferedReader ...>} + file_streams = {str(i): files[path] for i, path in enumerate(files)} + + # Add the file map field + file_map_str = json.dumps(file_map) + log.debug("file_map %s", file_map_str) + + fields = {"operations": operations_str, "map": file_map_str} + + # Add the extracted files as remaining fields + for k, v in file_streams.items(): + fields[k] = (getattr(v, "name", k), v) + + # Prepare requests http to send multipart-encoded data + data = MultipartEncoder(fields=fields) + + post_args["data"] = data + + if post_args["headers"] is None: + post_args["headers"] = {} + else: + post_args["headers"] = {**post_args["headers"]} + + post_args["headers"]["Content-Type"] = data.content_type + + else: + if variable_values: + payload["variables"] = variable_values + + if log.isEnabledFor(logging.INFO): + log.info(">>> %s", json.dumps(payload)) + + data_key = "json" if self.use_json else "data" + post_args[data_key] = payload + # Log the payload if log.isEnabledFor(logging.INFO): log.info(">>> %s", json.dumps(payload)) diff --git a/setup.py b/setup.py index 248099ab..ead75821 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ install_requests_requires = [ "requests>=2.23,<3", + "requests_toolbelt>=0.9.1,<1", ] install_websockets_requires = [ diff --git a/tests/test_requests.py b/tests/test_requests.py index e18875a2..d0cc7eb7 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -8,6 +8,7 @@ TransportQueryError, TransportServerError, ) +from tests.conftest import TemporaryFile # Marking all tests in this file with the requests marker pytestmark = pytest.mark.requests @@ -332,3 +333,414 @@ def test_code(): assert execution_result.extensions["key1"] == "val1" await run_sync_test(event_loop, server, test_code) + + +file_upload_server_answer = 
'{"data":{"success":true}}' + +file_upload_mutation_1 = """ + mutation($file: Upload!) { + uploadFile(input:{other_var:$other_var, file:$file}) { + success + } + } +""" + +file_upload_mutation_1_operations = ( + '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' + '$other_var, file: $file}) {\\n success\\n }\\n}\\n", "variables": ' + '{"file": null, "other_var": 42}}' +) + +file_upload_mutation_1_map = '{"0": ["variables.file"]}' + +file_1_content = """ +This is a test file +This file will be sent in the GraphQL mutation +""" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def single_upload_handler(request): + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + with TemporaryFile(file_1_content) as test_file: + with Client(transport=sample_transport) as session: + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + 
execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload_additional_headers( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def single_upload_handler(request): + from aiohttp import web + + assert request.headers["X-Auth"] == "foobar" + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url, headers={"X-Auth": "foobar"}) + + with TemporaryFile(file_1_content) as test_file: + with Client(transport=sample_transport) as session: + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def 
test_requests_binary_file_upload(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + + async def binary_upload_handler(request): + + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_binary = await field_2.read() + assert field_2_binary == binary_file_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", binary_upload_handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = RequestsHTTPTransport(url=url) + + def test_code(): + with TemporaryFile(binary_file_content) as test_file: + with Client(transport=sample_transport,) as session: + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + +file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}\\n", ' + '"variables": {"file1": null, "file2": null}}' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload_two_files( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) { + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } + """ + + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' + + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ + + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_2_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_2_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + + with Client(transport=sample_transport,) as session: + + query = gql(file_upload_mutation_2) + + file_path_1 = 
test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = { + "file1": f1, + "file2": f2, + } + + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + f1.close() + f2.close() + + await run_sync_test(event_loop, server, test_code) + + +file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(input: {files: $files})' + ' {\\n success\\n }\\n}\\n", "variables": {"files": [null, null]}}' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload_list_of_two_files( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + file_upload_mutation_3 = """ + mutation($files: [Upload!]!) { + uploadFiles(input:{files:$files}) { + success + } + } + """ + + file_upload_mutation_3_map = ( + '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + ) + + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ + + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_3_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_3_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, 
content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + with Client(transport=sample_transport,) as session: + + query = gql(file_upload_mutation_3) + + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = {"files": [f1, f2]} + + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + f1.close() + f2.close() + + await run_sync_test(event_loop, server, test_code) From 7992721801d551a9d2a9f39784973a08ee146a80 Mon Sep 17 00:00:00 2001 From: mirkan1 Date: Wed, 20 Oct 2021 08:36:21 -0400 Subject: [PATCH 021/239] AIOHTTPTransport ignore backend mimetype (#248) --- gql/transport/aiohttp.py | 2 +- tests/test_aiohttp.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 04c85edf..090463e9 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -225,7 +225,7 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): ) try: - result = await resp.json() + result = await resp.json(content_type=None) if log.isEnabledFor(logging.INFO): result_text = await resp.text() diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 3fb85cd0..df954f12 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -69,6 +69,35 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.asyncio +async def test_aiohttp_ignore_backend_content_type(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + 
async def handler(request): + return web.Response(text=query1_server_answer, content_type="text/plain") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + @pytest.mark.asyncio async def test_aiohttp_cookies(event_loop, aiohttp_server): from aiohttp import web From feea531037bab206cd2dae36c89ac1bff2c38be7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 24 Oct 2021 17:11:59 +0200 Subject: [PATCH 022/239] Implementation of graphql-ws protocol (#242) * Supporting both apollo and graphql-ws protocol in same class --- docs/modules/transport.rst | 2 + docs/transports/aiohttp.rst | 2 + docs/transports/requests.rst | 2 + docs/transports/websockets.rst | 63 +- gql/transport/websockets.py | 279 +++++++- tests/conftest.py | 92 ++- tests/test_graphqlws_exceptions.py | 289 ++++++++ tests/test_graphqlws_subscription.py | 745 +++++++++++++++++++++ tests/test_phoenix_channel_subscription.py | 9 +- tests/test_websocket_subscription.py | 2 +- tox.ini | 1 + 11 files changed, 1462 insertions(+), 24 deletions(-) create mode 100644 tests/test_graphqlws_exceptions.py create mode 100644 tests/test_graphqlws_subscription.py diff --git a/docs/modules/transport.rst b/docs/modules/transport.rst index 9a3caf6e..1b250d7a 100644 --- a/docs/modules/transport.rst +++ b/docs/modules/transport.rst @@ -14,3 +14,5 @@ gql.transport .. autoclass:: gql.transport.aiohttp.AIOHTTPTransport .. autoclass:: gql.transport.websockets.WebsocketsTransport + +.. 
autoclass:: gql.transport.phoenix_channel_websockets.PhoenixChannelWebsocketsTransport diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index 4b792232..91f2bf40 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -5,6 +5,8 @@ AIOHTTPTransport This transport uses the `aiohttp`_ library and allows you to send GraphQL queries using the HTTP protocol. +Reference: :py:class:`gql.transport.aiohttp.AIOHTTPTransport` + .. note:: GraphQL subscriptions are not supported on the HTTP transport. diff --git a/docs/transports/requests.rst b/docs/transports/requests.rst index f920f3e0..15eaedb5 100644 --- a/docs/transports/requests.rst +++ b/docs/transports/requests.rst @@ -6,6 +6,8 @@ RequestsHTTPTransport The RequestsHTTPTransport is a sync transport using the `requests`_ library and allows you to send GraphQL queries using the HTTP protocol. +Reference: :py:class:`gql.transport.requests.RequestsHTTPTransport` + .. literalinclude:: ../code_examples/requests_sync.py .. _requests: https://requests.readthedocs.io diff --git a/docs/transports/websockets.rst b/docs/transports/websockets.rst index a082d887..7c91efb6 100644 --- a/docs/transports/websockets.rst +++ b/docs/transports/websockets.rst @@ -3,10 +3,17 @@ WebsocketsTransport =================== -The websockets transport implements the `Apollo websockets transport protocol`_. +The websockets transport supports both: + + - the `Apollo websockets transport protocol`_. + - the `GraphQL-ws websockets transport protocol`_ + +It will detect the backend supported protocol from the response http headers returned. This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection. +Reference: :py:class:`gql.transport.websockets.WebsocketsTransport` + .. 
literalinclude:: ../code_examples/websockets_async.py Websockets SSL @@ -14,11 +21,11 @@ Websockets SSL If you need to connect to an ssl encrypted endpoint: -* use _wss_ instead of _ws_ in the url of the transport +* use :code:`wss` instead of :code:`ws` in the url of the transport .. code-block:: python - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', headers={'Authorization': 'token'} ) @@ -34,7 +41,7 @@ If you have a self-signed ssl certificate, you need to provide an ssl_context wi localhost_pem = pathlib.Path(__file__).with_name("YOUR_SERVER_PUBLIC_CERTIFICATE.pem") ssl_context.load_verify_locations(localhost_pem) - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', ssl=ssl_context ) @@ -54,7 +61,7 @@ There are two ways to send authentication tokens with websockets depending on th .. code-block:: python - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', headers={'Authorization': 'token'} ) @@ -63,9 +70,53 @@ There are two ways to send authentication tokens with websockets depending on th .. code-block:: python - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url='wss://SERVER_URL:SERVER_PORT/graphql', init_payload={'Authorization': 'token'} ) +Keep-Alives +----------- + +Apollo protocol +^^^^^^^^^^^^^^^ + +With the Apollo protocol, the backend can optionally send unidirectional keep-alive ("ka") messages +(only from the server to the client). + +It is possible to configure the transport to close if we don't receive a "ka" message +within a specified time using the :code:`keep_alive_timeout` parameter. 
+ +Here is an example with 60 seconds:: + + transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + keep_alive_timeout=60, + ) + +One disadvantage of the Apollo protocol is that because the keep-alives are only sent from the server +to the client, it can be difficult to detect the loss of a connection quickly from the server side. + +GraphQL-ws protocol +^^^^^^^^^^^^^^^^^^^ + +With the GraphQL-ws protocol, it is possible to send bidirectional ping/pong messages. +Pings can be sent either from the client or the server and the other party should answer with a pong. + +As with the Apollo protocol, it is possible to configure the transport to close if we don't +receive any message from the backend within the specified time using the :code:`keep_alive_timeout` parameter. + +But there is also the possibility for the client to send pings at a regular interval and verify +that the backend sends a pong within a specified delay. +This can be done using the :code:`ping_interval` and :code:`pong_timeout` parameters. + +Here is an example with a ping sent every 60 seconds, expecting a pong within 10 seconds:: + + transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + ping_interval=60, + pong_timeout=10, + ) + .. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +.. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 4ec8ce89..06552d2f 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -86,6 +86,11 @@ class WebsocketsTransport(AsyncTransport): on a websocket connection. 
""" + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") + GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") + def __init__( self, url: str, @@ -96,6 +101,9 @@ def __init__( close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -112,8 +120,18 @@ def __init__( from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). 
+ By default: True :param connect_args: Other parameters forwarded to websockets.connect """ + self.url: str = url self.ssl: Union[SSLContext, bool] = ssl self.headers: Optional[HeadersLike] = headers @@ -123,6 +141,15 @@ def __init__( self.close_timeout: Optional[Union[int, float]] = close_timeout self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout self.connect_args = connect_args @@ -132,6 +159,7 @@ def __init__( self.receive_data_task: Optional[asyncio.Future] = None self.check_keep_alive_task: Optional[asyncio.Future] = None + self.send_ping_task: Optional[asyncio.Future] = None self.close_task: Optional[asyncio.Future] = None # We need to set an event loop here if there is none @@ -152,10 +180,28 @@ def __init__( self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + with the graphql-ws protocol. 
+ Possible keys are: 'ping', 'pong', 'connection_ack'""" + self._connecting: bool = False self.close_exception: Optional[Exception] = None + self.supported_subprotocols = [ + self.GRAPHQLWS_SUBPROTOCOL, + self.APOLLO_SUBPROTOCOL, + ] + async def _send(self, message: str) -> None: """Send the provided message to the websocket connection and log the message""" @@ -223,6 +269,28 @@ async def _send_init_message_and_wait_ack(self) -> None: # Wait for the connection_ack message or raise a TimeoutError await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol + """ + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(json.dumps(ping_message)) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol + """ + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(json.dumps(pong_message)) + async def _send_stop_message(self, query_id: int) -> None: """Send stop message to the provided websocket connection and query_id. @@ -233,6 +301,32 @@ async def _send_stop_message(self, query_id: int) -> None: await self._send(stop_message) + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = json.dumps({"id": str(query_id), "type": "complete"}) + + await self._send(complete_message) + + async def _stop_listener(self, query_id: int) -> None: + """Stop the listener corresponding to the query_id depending on the + detected backend protocol. 
+ + For apollo: send a "stop" message + (a "complete" message will be sent from the backend) + + For graphql-ws: send a "complete" message and simulate the reception + of a "complete" message from the backend + """ + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + async def _send_connection_terminate_message(self) -> None: """Send a connection_terminate message to the provided websocket connection. @@ -265,18 +359,102 @@ async def _send_query( if operation_name: payload["operationName"] = operation_name + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + query_str = json.dumps( - {"id": str(query_id), "type": "start", "payload": payload} + {"id": str(query_id), "type": query_type, "payload": payload} ) await self._send(query_str) return query_id - def _parse_answer( + def _parse_answer_graphqlws( self, answer: str ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server + """Parse the answer received from the server if the server supports the + graphql-ws protocol. 
+ + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + json_answer = json.loads(answer) + + answer_type = str(json_answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "next": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a 
GraphQL result: {answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. Returns a list consisting of: - the answer_type (between: @@ -342,6 +520,17 @@ def _parse_answer( return answer_type, answer_id, execution_result + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(answer) + + return self._parse_answer_apollo(answer) + async def _check_ws_liveness(self) -> None: """Coroutine which will periodically check the liveness of the connection through keep-alive messages @@ -376,6 +565,39 @@ async def _check_ws_liveness(self) -> None: # The client is probably closing, handle it properly pass + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. 
+ """ + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + async def _receive_data_loop(self) -> None: try: while True: @@ -428,6 +650,7 @@ async def _handle_answer( answer_id: Optional[int], execution_result: Optional[ExecutionResult], ) -> None: + try: # Put the answer in the queue if answer_id is not None: @@ -436,6 +659,15 @@ async def _handle_answer( # Do nothing if no one is listening to this query_id. pass + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + async def subscribe( self, document: DocumentNode, @@ -486,7 +718,7 @@ async def subscribe( except (asyncio.CancelledError, GeneratorExit) as e: log.debug(f"Exception in subscribe: {e!r}") if listener.send_stop: - await self._send_stop_message(query_id) + await self._stop_listener(query_id) listener.send_stop = False finally: @@ -540,8 +772,6 @@ async def connect(self) -> None: Should be cleaned with a call to the close coroutine """ - GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws") - log.debug("connect: starting") if self.websocket is None and not self._connecting: @@ -562,7 +792,7 @@ async def connect(self) -> None: connect_args: Dict[str, Any] = { "ssl": ssl, "extra_headers": self.headers, - "subprotocols": [GRAPHQLWS_SUBPROTOCOL], + "subprotocols": self.supported_subprotocols, } # 
Adding custom parameters passed from init @@ -579,6 +809,19 @@ async def connect(self) -> None: finally: self._connecting = False + self.websocket = cast(WebSocketClientProtocol, self.websocket) + + # Find the backend subprotocol returned in the response headers + response_headers = self.websocket.response_headers + try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + self.next_query_id = 1 self.close_exception = None self._wait_closed.clear() @@ -601,6 +844,14 @@ async def connect(self) -> None: self._check_ws_liveness() ) + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + # Create a task to listen to the incoming websocket messages self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) @@ -633,7 +884,7 @@ async def _clean_close(self, e: Exception) -> None: for query_id, listener in self.listeners.items(): if listener.send_stop: - await self._send_stop_message(query_id) + await self._stop_listener(query_id) listener.send_stop = False # Wait that there is no more listeners (we received 'complete' for all queries) @@ -642,8 +893,9 @@ async def _clean_close(self, e: Exception) -> None: except asyncio.TimeoutError: # pragma: no cover log.debug("Timer close_timeout fired") - # Finally send the 'connection_terminate' message - await self._send_connection_terminate_message() + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + # Finally send the 'connection_terminate' message + await self._send_connection_terminate_message() async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: """Coroutine which 
will: @@ -669,6 +921,12 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: with suppress(asyncio.CancelledError): await self.check_keep_alive_task + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError): + await self.send_ping_task + # Saving exception to raise it later if trying to use the transport # after it has already closed. self.close_exception = e @@ -702,6 +960,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: self.websocket = None self.close_task = None self.check_keep_alive_task = None + self.send_ping_task = None self._wait_closed.set() diff --git a/tests/conftest.py b/tests/conftest.py index df69c121..004fa9df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,13 +128,14 @@ class WebSocketServer: def __init__(self, with_ssl: bool = False): self.with_ssl = with_ssl - async def start(self, handler): + async def start(self, handler, extra_serve_args=None): import websockets.server print("Starting server") - extra_serve_args = {} + if extra_serve_args is None: + extra_serve_args = {} if self.with_ssl: # This is a copy of certificate from websockets tests folder @@ -192,7 +193,21 @@ async def send_keepalive(ws): await ws.send('{"type":"ka"}') @staticmethod - async def send_connection_ack(ws): + async def send_ping(ws, payload=None): + if payload is None: + await ws.send('{"type":"ping"}') + else: + await ws.send(json.dumps({"type": "ping", "payload": payload})) + + @staticmethod + async def send_pong(ws, payload=None): + if payload is None: + await ws.send('{"type":"pong"}') + else: + await ws.send(json.dumps({"type": "pong", "payload": payload})) + + @staticmethod + async def send_connection_ack(ws, payload=None): # Line return for easy debugging print("") @@ -203,7 +218,10 @@ async def send_connection_ack(ws): assert json_result["type"] == "connection_init" # Send ack - await 
ws.send('{"type":"connection_ack"}') + if payload is None: + await ws.send('{"type":"connection_ack"}') + else: + await ws.send(json.dumps({"type": "connection_ack", "payload": payload})) @staticmethod async def wait_connection_terminate(ws): @@ -352,6 +370,54 @@ async def server(request): await test_server.stop() +@pytest.fixture +async def graphqlws_server(request): + """Fixture used to start a dummy server with the graphql-ws protocol. + + Similar to the server fixture above but will return "graphql-transport-ws" + as the server subprotocol. + + It can take as argument either a handler function for the websocket server for + complete control OR an array of answers to be sent by the default server handler. + """ + + subprotocol = "graphql-transport-ws" + + from websockets.server import WebSocketServerProtocol + + class CustomSubprotocol(WebSocketServerProtocol): + def select_subprotocol(self, client_subprotocols, server_subprotocols): + print(f"Client subprotocols: {client_subprotocols!r}") + print(f"Server subprotocols: {server_subprotocols!r}") + + return subprotocol + + def process_subprotocol(self, headers, available_subprotocols): + # Overwriting available subprotocols + available_subprotocols = [subprotocol] + + print(f"headers: {headers!r}") + # print (f"Available subprotocols: {available_subprotocols!r}") + + return super().process_subprotocol(headers, available_subprotocols) + + server_handler = get_server_handler(request) + + try: + test_server = WebSocketServer() + + # Starting the server with the fixture param as the handler function + await test_server.start( + server_handler, extra_serve_args={"create_protocol": CustomSubprotocol} + ) + + yield test_server + except Exception as e: + print("Exception received in server fixture:", e) + finally: + await test_server.stop() + + @pytest.fixture async def client_and_server(server): """Helper fixture to start a server and a client connected to its port.""" @@ -369,6 +435,24 @@ async def 
client_and_server(server): yield session, server +@pytest.fixture +async def client_and_graphqlws_server(graphqlws_server): + """Helper fixture to start a server with the graphql-ws prototocol + and a client connected to its port.""" + + from gql.transport.websockets import WebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + sample_transport = WebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, graphqlws_server + + @pytest.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py new file mode 100644 index 00000000..8a2e7495 --- /dev/null +++ b/tests/test_graphqlws_exceptions.py @@ -0,0 +1,289 @@ +import asyncio +import json +import types +from typing import List + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from .conftest import WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"next","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_graphqlws_invalid_query( + event_loop, 
client_and_graphqlws_server, query_str +): + + session, server = client_and_graphqlws_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_invalid_subscription], indirect=True +) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_graphqlws_invalid_subscription( + event_loop, client_and_graphqlws_server, query_str +): + + session, server = client_and_graphqlws_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_graphqlws_server_does_not_send_ack( + event_loop, graphqlws_server, query_str +): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + + sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + 
async with Client(transport=sample_transport): + pass + + +invalid_payload_server_answer = ( + '{"type":"error","id":"1","payload":{"message":"Must provide document"}}' +) + + +async def server_invalid_payload(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_payload_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_invalid_payload], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_graphqlws_sending_invalid_payload( + event_loop, client_and_graphqlws_server, query_str +): + + session, server = client_and_graphqlws_server + + # Monkey patching the _send_query method to send an invalid payload + + async def monkey_patch_send_query( + self, document, variable_values=None, operation_name=None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + {"id": str(query_id), "type": "subscribe", "payload": "BLAHBLAH"} + ) + + await self._send(query_str) + return query_id + + session.transport._send_query = types.MethodType( + monkey_patch_send_query, session.transport + ) + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["message"] == "Must provide document" + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "next"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "next", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": 
"next", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_graphqlws_transport_protocol_errors( + event_loop, client_and_graphqlws_server +): + + session, server = client_and_graphqlws_server + + query = gql("query { hello }") + + with pytest.raises(TransportProtocolError): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) +async def test_graphqlws_server_does_not_ack(event_loop, graphqlws_server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) +async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): + import websockets + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(websockets.exceptions.ConnectionClosed): + async with Client(transport=sample_transport): + pass + + +async def server_closing_after_ack(ws, path): + await 
WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) +async def test_graphqlws_server_closing_after_ack( + event_loop, client_and_graphqlws_server +): + + import websockets + + session, server = client_and_graphqlws_server + + query = gql("query { hello }") + + with pytest.raises(websockets.exceptions.ConnectionClosed): + await session.execute(query) + + await session.transport.wait_closed() + + with pytest.raises(TransportClosed): + await session.execute(query) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py new file mode 100644 index 00000000..8f38d101 --- /dev/null +++ b/tests/test_graphqlws_subscription.py @@ -0,0 +1,745 @@ +import asyncio +import json +import sys +from typing import List + +import pytest +from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportServerError + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +countdown_server_answer = ( + '{{"type":"next","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +COUNTING_DELAY = 2 * MS +PING_SENDING_DELAY = 5 * MS +PONG_TIMEOUT = 2 * MS + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def server_countdown_factory(keepalive=False, answer_pings=True): + async def server_countdown_template(ws, path): + import websockets + + logged_messages.clear() + + try: + await WebSocketServerHelper.send_connection_ack( + ws, payload="dummy_connection_ack_payload" + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: 
{:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + pong_received: asyncio.Event = asyncio.Event() + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(COUNTING_DELAY) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def keepalive_coro(): + while True: + await asyncio.sleep(PING_SENDING_DELAY) + try: + # Send a ping + await WebSocketServerHelper.send_ping( + ws, payload="dummy_ping_payload" + ) + + # Wait for a pong + try: + await asyncio.wait_for(pong_received.wait(), PONG_TIMEOUT) + except asyncio.TimeoutError: + print("\nNo pong received in time!\n") + break + + pong_received.clear() + + except websockets.exceptions.ConnectionClosed: + break + + if keepalive: + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + answer_type = json_result["type"] + + if answer_type == "complete" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + if keepalive: + print("Cancelling keep alive task now") + keepalive_task.cancel() + + elif answer_type == "ping": + if answer_pings: + payload = json_result.get("payload", None) + await WebSocketServerHelper.send_pong(ws, payload=payload) + + elif answer_type == "pong": + pong_received.set() + + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print("Now receiving task is cancelled") + + if keepalive: + keepalive_task.cancel() + try: + await keepalive_task + 
except asyncio.CancelledError: + print("Now keepalive task is cancelled") + + await WebSocketServerHelper.send_complete(ws, query_id) + except websockets.exceptions.ConnectionClosedOK: + pass + except AssertionError as e: + print(f"\nAssertion failed: {e!s}\n") + finally: + await ws.wait_closed() + + return server_countdown_template + + +async def server_countdown(ws, path): + + server = server_countdown_factory() + await server(ws, path) + + +async def server_countdown_keepalive(ws, path): + + server = server_countdown_factory(keepalive=True) + await server(ws, path) + + +async def server_countdown_dont_answer_pings(ws, path): + + server = server_countdown_factory(answer_pings=False) + await server(ws, path) + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_break( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count 
+ + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_task_cancel( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_close_transport( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + await session.transport.close() + + close_transport_task = 
asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(COUNTING_DELAY) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_server_connection_closed( + event_loop, client_and_graphqlws_server, subscription_str +): + import websockets + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(websockets.exceptions.ConnectionClosedOK): + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_operation_name( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, 
server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_keepalive( + event_loop, client_and_graphqlws_server, subscription_str +): + + session, server = client_and_graphqlws_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + assert session.transport.payloads["ping"] == "dummy_ping_payload" + assert ( + session.transport.payloads["connection_ack"] == "dummy_connection_ack_payload" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_keepalive_with_timeout_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, keep_alive_timeout=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in 
session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_keepalive_with_timeout_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, keep_alive_timeout=(COUNTING_DELAY / 2)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_ping_interval_ok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport( + url=url, ping_interval=(5 * COUNTING_DELAY), pong_timeout=(4 * COUNTING_DELAY), + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + 
print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_dont_answer_pings], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_with_ping_interval_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, ping_interval=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No pong received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_manual_pings_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + payload = {"count_received": count} + + await transport.send_ping(payload=payload) + + 
await transport.pong_received.wait() + transport.pong_received.clear() + + assert transport.payloads["pong"] == payload + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_graphqlws_subscription_manual_pong_answers_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url, answer_pings=False) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + async def answer_ping_coro(): + while True: + await transport.ping_received.wait() + transport.ping_received.clear() + await transport.send_pong(payload={"some": "data"}) + + answer_ping_task = asyncio.ensure_future(answer_ping_coro()) + + try: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + finally: + answer_ping_task.cancel() + + assert count == -1 + + +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_graphqlws_subscription_sync(graphqlws_server, subscription_str): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: 
{number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_graphqlws_subscription_sync_graceful_shutdown( + graphqlws_server, subscription_str +): + """ Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. + """ + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Check that the server received a connection_terminate message last + # assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def 
test_graphqlws_subscription_running_in_thread( + event_loop, graphqlws_server, subscription_str, run_sync_test +): + from gql.transport.websockets import WebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, graphqlws_server, test_code) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 3c6ec2b2..6367945d 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -235,6 +235,8 @@ async def test_phoenix_channel_subscription_no_break( from gql.transport.phoenix_channel_websockets import log as phoenix_logger from gql.transport.websockets import log as websockets_logger + from .conftest import MS + websockets_logger.setLevel(logging.DEBUG) phoenix_logger.setLevel(logging.DEBUG) @@ -244,7 +246,7 @@ async def test_phoenix_channel_subscription_no_break( async def testing_stopping_without_break(): sample_transport = PhoenixChannelWebsocketsTransport( - channel_name=test_channel, url=url, close_timeout=5 + channel_name=test_channel, url=url, close_timeout=(5000 * MS) ) count = 10 @@ -256,7 +258,8 @@ async def testing_stopping_without_break(): print(f"Number received: {number}") # Simulate a slow consumer - await asyncio.sleep(0.1) + if number == 10: + await asyncio.sleep(50 * MS) if number == 9: # When we consume the number 9 here in the async generator, @@ -274,7 +277,7 @@ async def testing_stopping_without_break(): assert count == -1 try: - await asyncio.wait_for(testing_stopping_without_break(), timeout=5) + await 
asyncio.wait_for(testing_stopping_without_break(), timeout=(5000 * MS)) except asyncio.TimeoutError: assert False, "The async generator did not stop" diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 7d87ee81..d5167720 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -391,7 +391,7 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(500 * MS)) + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(10 * MS)) client = Client(transport=sample_transport) diff --git a/tox.ini b/tox.ini index 19231ed7..414f083b 100644 --- a/tox.ini +++ b/tox.ini @@ -21,6 +21,7 @@ setenv = PYTHONPATH = {toxinidir} MULTIDICT_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/multidict YARL_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/yarl + GQL_TESTS_TIMEOUT_FACTOR = 10 install_command = python -m pip install --ignore-installed {opts} {packages} whitelist_externals = python From 9fbb851d6e1c79c8f50e6075bbb4518e023843b1 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 24 Oct 2021 18:00:11 +0200 Subject: [PATCH 023/239] Add .readthedocs.yaml file Requesting to use the 'all' extra dependency to try to fix class references in the docs --- .readthedocs.yaml | 28 ++++++++++++++++++++++++++++ MANIFEST.in | 1 + 2 files changed, 29 insertions(+) create mode 100644 .readthedocs.yaml diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..6517a097 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,28 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required 
+version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.9" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# Optionally build your docs in additional formats such as PDF +formats: + - pdf + +# Optionally declare the Python requirements required to build your docs +python: + install: + - requirements: docs/requirements.txt + + extra_requirements: + - all diff --git a/MANIFEST.in b/MANIFEST.in index fbaa10b4..4d7eaef4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,7 @@ include CODEOWNERS include LICENSE include README.md include CONTRIBUTING.md +include .readthedocs.yaml include dev_requirements.txt include Makefile From aeb2402144fc62e523d02992ffcdc517cbbef989 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 24 Oct 2021 18:08:11 +0200 Subject: [PATCH 024/239] Fix ReadTheDocs documentation compilation --- .readthedocs.yaml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 6517a097..749771cf 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -19,10 +19,11 @@ sphinx: formats: - pdf -# Optionally declare the Python requirements required to build your docs python: install: - requirements: docs/requirements.txt - - extra_requirements: - - all + - method: pip + path: . 
+ extra_requirements: + - all + system_packages: true From a26cc747909d7de436309e196d0164367a788c1e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 24 Oct 2021 19:08:30 +0200 Subject: [PATCH 025/239] DOCS fix transport class references --- docs/transports/aiohttp.rst | 2 +- docs/transports/requests.rst | 2 +- docs/transports/websockets.rst | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index 91f2bf40..68b3eb99 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -5,7 +5,7 @@ AIOHTTPTransport This transport uses the `aiohttp`_ library and allows you to send GraphQL queries using the HTTP protocol. -Reference: :py:class:`gql.transport.aiohttp.AIOHTTPTransport` +Reference: :class:`gql.transport.aiohttp.AIOHTTPTransport` .. note:: diff --git a/docs/transports/requests.rst b/docs/transports/requests.rst index 15eaedb5..93e18926 100644 --- a/docs/transports/requests.rst +++ b/docs/transports/requests.rst @@ -6,7 +6,7 @@ RequestsHTTPTransport The RequestsHTTPTransport is a sync transport using the `requests`_ library and allows you to send GraphQL queries using the HTTP protocol. -Reference: :py:class:`gql.transport.requests.RequestsHTTPTransport` +Reference: :class:`gql.transport.requests.RequestsHTTPTransport` .. literalinclude:: ../code_examples/requests_sync.py diff --git a/docs/transports/websockets.rst b/docs/transports/websockets.rst index 7c91efb6..689cc136 100644 --- a/docs/transports/websockets.rst +++ b/docs/transports/websockets.rst @@ -12,7 +12,7 @@ It will detect the backend supported protocol from the response http headers ret This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection. -Reference: :py:class:`gql.transport.websockets.WebsocketsTransport` +Reference: :class:`gql.transport.websockets.WebsocketsTransport` .. 
literalinclude:: ../code_examples/websockets_async.py From e2223ace4b226692bd395ba499d9ebb040abe19d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 24 Oct 2021 21:14:20 +0200 Subject: [PATCH 026/239] RequestsHTTPTransport fix query logged twice --- gql/transport/requests.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 68b4144b..8b95722d 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -198,9 +198,6 @@ def execute( # type: ignore if variable_values: payload["variables"] = variable_values - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", json.dumps(payload)) - data_key = "json" if self.use_json else "data" post_args[data_key] = payload From 14397ade4a3315efbc727369ccd979f1d26d3afb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 25 Oct 2021 17:42:58 +0200 Subject: [PATCH 027/239] Update requests/urllib3 dependency and allow retries on POST requests (#249) --- gql/transport/requests.py | 1 + setup.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 8b95722d..31b52809 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -95,6 +95,7 @@ def connect(self): total=self.retries, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504], + allowed_methods=None, ) ) for prefix in "https://", "https://": diff --git a/setup.py b/setup.py index ead75821..94f3a9ee 100644 --- a/setup.py +++ b/setup.py @@ -37,8 +37,9 @@ ] install_requests_requires = [ - "requests>=2.23,<3", + "requests>=2.26,<3", "requests_toolbelt>=0.9.1,<1", + "urllib3>=1.26", ] install_websockets_requires = [ From 33948bb8012d10525b9c319a7346158a35ca201a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 25 Oct 2021 22:25:05 +0200 Subject: [PATCH 028/239] TESTS fix flaky tests on windows --- 
tests/test_graphqlws_subscription.py | 174 +++++++++++++++++---------- tests/test_websocket_subscription.py | 2 +- 2 files changed, 109 insertions(+), 67 deletions(-) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 8f38d101..2c7cff23 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -18,9 +18,9 @@ '{{"type":"next","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' ) -COUNTING_DELAY = 2 * MS -PING_SENDING_DELAY = 5 * MS -PONG_TIMEOUT = 2 * MS +COUNTING_DELAY = 20 * MS +PING_SENDING_DELAY = 50 * MS +PONG_TIMEOUT = 100 * MS # List which can used to store received messages by the server logged_messages: List[str] = [] @@ -48,100 +48,140 @@ async def server_countdown_template(ws, path): count_found = search("count: {:d}", query) count = count_found[0] - print(f"Countdown started from: {count}") + print(f" Server: Countdown started from: {count}") pong_received: asyncio.Event = asyncio.Event() async def counting_coro(): - for number in range(count, -1, -1): - await ws.send( - countdown_server_answer.format(query_id=query_id, number=number) - ) - await asyncio.sleep(COUNTING_DELAY) + print(" Server: counting task started") + try: + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format( + query_id=query_id, number=number + ) + ) + await asyncio.sleep(COUNTING_DELAY) + finally: + print(" Server: counting task ended") + print(" Server: starting counting task") counting_task = asyncio.ensure_future(counting_coro()) async def keepalive_coro(): - while True: - await asyncio.sleep(PING_SENDING_DELAY) - try: - # Send a ping - await WebSocketServerHelper.send_ping( - ws, payload="dummy_ping_payload" - ) - - # Wait for a pong + print(" Server: keepalive task started") + try: + while True: + await asyncio.sleep(PING_SENDING_DELAY) try: - await asyncio.wait_for(pong_received.wait(), PONG_TIMEOUT) - except asyncio.TimeoutError: - print("\nNo pong 
received in time!\n") + # Send a ping + await WebSocketServerHelper.send_ping( + ws, payload="dummy_ping_payload" + ) + + # Wait for a pong + try: + await asyncio.wait_for( + pong_received.wait(), PONG_TIMEOUT + ) + except asyncio.TimeoutError: + print( + "\n Server: No pong received in time!\n" + ) + break + + pong_received.clear() + + except websockets.exceptions.ConnectionClosed: break - - pong_received.clear() - - except websockets.exceptions.ConnectionClosed: - break + finally: + print(" Server: keepalive task ended") if keepalive: + print(" Server: starting keepalive task") keepalive_task = asyncio.ensure_future(keepalive_coro()) async def receiving_coro(): - nonlocal counting_task - while True: - - try: - result = await ws.recv() - logged_messages.append(result) - except websockets.exceptions.ConnectionClosed: - break - - json_result = json.loads(result) - - answer_type = json_result["type"] - - if answer_type == "complete" and json_result["id"] == str(query_id): - print("Cancelling counting task now") - counting_task.cancel() - if keepalive: - print("Cancelling keep alive task now") - keepalive_task.cancel() - - elif answer_type == "ping": - if answer_pings: - payload = json_result.get("payload", None) - await WebSocketServerHelper.send_pong(ws, payload=payload) + print(" Server: receiving task started") + try: + nonlocal counting_task + while True: - elif answer_type == "pong": - pong_received.set() + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + json_result = json.loads(result) + + answer_type = json_result["type"] + + if answer_type == "complete" and json_result["id"] == str( + query_id + ): + print("Cancelling counting task now") + counting_task.cancel() + if keepalive: + print("Cancelling keep alive task now") + keepalive_task.cancel() + + elif answer_type == "ping": + if answer_pings: + payload = json_result.get("payload", None) + await WebSocketServerHelper.send_pong( + ws, 
payload=payload + ) + + elif answer_type == "pong": + pong_received.set() + finally: + print(" Server: receiving task ended") + if keepalive: + keepalive_task.cancel() + + print(" Server: starting receiving task") receiving_task = asyncio.ensure_future(receiving_coro()) try: + print(" Server: waiting for counting task to complete") await counting_task except asyncio.CancelledError: - print("Now counting task is cancelled") + print(" Server: Now counting task is cancelled") - receiving_task.cancel() - - try: - await receiving_task - except asyncio.CancelledError: - print("Now receiving task is cancelled") + print(" Server: sending complete message") + await WebSocketServerHelper.send_complete(ws, query_id) if keepalive: + print(" Server: cancelling keepalive task") keepalive_task.cancel() try: await keepalive_task except asyncio.CancelledError: - print("Now keepalive task is cancelled") + print(" Server: Now keepalive task is cancelled") + + print(" Server: waiting for client to close the connection") + try: + await asyncio.wait_for(receiving_task, 1000 * MS) + except asyncio.TimeoutError: + pass + + print(" Server: cancelling receiving task") + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print(" Server: Now receiving task is cancelled") - await WebSocketServerHelper.send_complete(ws, query_id) except websockets.exceptions.ConnectionClosedOK: pass except AssertionError as e: - print(f"\nAssertion failed: {e!s}\n") + print(f"\n Server: Assertion failed: {e!s}\n") finally: + print(" Server: waiting for websocket connection to close") await ws.wait_closed() + print(" Server: connection closed") return server_countdown_template @@ -406,6 +446,7 @@ async def test_graphqlws_subscription_with_keepalive( count -= 1 assert count == -1 + assert "ping" in session.transport.payloads assert session.transport.payloads["ping"] == "dummy_ping_payload" assert ( session.transport.payloads["connection_ack"] == 
"dummy_connection_ack_payload" @@ -570,18 +611,19 @@ async def test_graphqlws_subscription_manual_pings_with_payload( number = result["number"] print(f"Number received: {number}") - assert number == count - count -= 1 - payload = {"count_received": count} await transport.send_ping(payload=payload) - await transport.pong_received.wait() + await asyncio.wait_for(transport.pong_received.wait(), 10000 * MS) + transport.pong_received.clear() assert transport.payloads["pong"] == payload + assert number == count + count -= 1 + assert count == -1 diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index d5167720..5300333d 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -391,7 +391,7 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(10 * MS)) + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) client = Client(transport=sample_transport) From d5e3e6db07f7d00069c5217b84747e6f85b3dc87 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 26 Oct 2021 10:28:23 +0200 Subject: [PATCH 029/239] Bump version number to 3.0.0b0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index c28a7154..b3d3a3b4 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.0.0a6" +__version__ = "3.0.0b0" From 37f1917b846db0d5b93a9d5b19e6e4edd0699a9b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 19 Nov 2021 23:50:39 +0100 Subject: [PATCH 030/239] Custom Scalars input serialization in variables (#253) --- docs/modules/gql.rst | 1 + docs/modules/utilities.rst | 6 + docs/usage/custom_scalars.rst | 134 ++++ docs/usage/index.rst | 1 + gql/client.py | 199 +++++- gql/utilities/__init__.py | 5 + 
gql/utilities/update_schema_scalars.py | 32 + gql/variable_values.py | 117 ++++ tests/custom_scalars/__init__.py | 0 .../test_custom_scalar_datetime.py | 220 ++++++ .../test_custom_scalar_money.py | 635 ++++++++++++++++++ 11 files changed, 1338 insertions(+), 12 deletions(-) create mode 100644 docs/modules/utilities.rst create mode 100644 docs/usage/custom_scalars.rst create mode 100644 gql/utilities/__init__.py create mode 100644 gql/utilities/update_schema_scalars.py create mode 100644 gql/variable_values.py create mode 100644 tests/custom_scalars/__init__.py create mode 100644 tests/custom_scalars/test_custom_scalar_datetime.py create mode 100644 tests/custom_scalars/test_custom_scalar_money.py diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index aac47c86..06a89a96 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,3 +21,4 @@ Sub-Packages client transport dsl + utilities diff --git a/docs/modules/utilities.rst b/docs/modules/utilities.rst new file mode 100644 index 00000000..47043b98 --- /dev/null +++ b/docs/modules/utilities.rst @@ -0,0 +1,6 @@ +gql.utilities +============= + +.. currentmodule:: gql.utilities + +.. automodule:: gql.utilities diff --git a/docs/usage/custom_scalars.rst b/docs/usage/custom_scalars.rst new file mode 100644 index 00000000..baee441e --- /dev/null +++ b/docs/usage/custom_scalars.rst @@ -0,0 +1,134 @@ +Custom Scalars +============== + +Scalar types represent primitive values at the leaves of a query. + +GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend +can add additional custom scalars to its schema to better express values in their data model. + +For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. + +The schema will then only contain: + +.. 
code-block:: python + + scalar Datetime + +When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), +their values need to be serialized to be composed +of only built-in scalars, then at the destination the serialized values will be parsed again to +be able to represent the scalar in its local internal representation. + +Because this serialization/unserialization is dependent on the language used at both sides, it is not +described in the schema and needs to be defined independently at both sides (client, backend). + +A custom scalar value can have two different representations during its transport: + + - as a serialized value (usually as json): + + * in the results sent by the backend + * in the variables sent by the client alongside the query + + - as "literal" inside the query itself sent by the client + +To define a custom scalar, you need 3 methods: + + - a :code:`serialize` method used: + + * by the backend to serialize a custom scalar output in the result + * by the client to serialize a custom scalar input in the variables + + - a :code:`parse_value` method used: + + * by the backend to unserialize custom scalars inputs in the variables sent by the client + * by the client to unserialize custom scalars outputs from the results + + - a :code:`parse_literal` method used: + + * by the backend to unserialize custom scalars inputs inside the query itself + +To define a custom scalar object, we define a :code:`GraphQLScalarType` from graphql-core with +its name and the implementation of the above methods. + +Example for Datetime: + +.. 
code-block:: python + + from datetime import datetime + from typing import Any, Dict, Optional + + from graphql import GraphQLScalarType, ValueNode + from graphql.utilities import value_from_ast_untyped + + + def serialize_datetime(value: Any) -> str: + return value.isoformat() + + + def parse_datetime_value(value: Any) -> datetime: + return datetime.fromisoformat(value) + + + def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None + ) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + return parse_datetime_value(ast_value) + + + DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, + ) + +Custom Scalars in inputs +------------------------ + +To provide custom scalars in input with gql, you can: + +- serialize the scalar yourself as "literal" in the query: + +.. code-block:: python + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + +- serialize the scalar yourself in a variable: + +.. code-block:: python + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + +- add a custom scalar to the schema with :func:`update_schema_scalars ` + and execute the query with :code:`serialize_variables=True` + and gql will serialize the variable values from a Python object representation. + +For this, you need to provide a schema or set :code:`fetch_schema_from_transport=True` +in the client to request the schema from the backend. + +.. 
code-block:: python + + from gql.utilities import update_schema_scalars + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = {"time": datetime.now()} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) diff --git a/docs/usage/index.rst b/docs/usage/index.rst index a7dd4d56..4a38093a 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -10,3 +10,4 @@ Usage variables headers file_upload + custom_scalars diff --git a/gql/client.py b/gql/client.py index 6017ab69..368193cc 100644 --- a/gql/client.py +++ b/gql/client.py @@ -17,6 +17,7 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport +from .variable_values import serialize_variable_values class Client: @@ -289,18 +290,79 @@ def __init__(self, client: Client): """:param client: the :class:`client ` used""" self.client = client - def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: + def _execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> ExecutionResult: + """Execute the provided document AST synchronously using + the sync transport, returning an ExecutionResult object. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. 
+ + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) - return self.transport.execute(document, *args, **kwargs) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + + return self.transport.execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) - def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: + """Execute the provided document AST synchronously using + the sync transport. + + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. 
+ + The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = self._execute(document, *args, **kwargs) + result = self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: @@ -341,17 +403,52 @@ def __init__(self, client: Client): self.client = client async def _subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: + """Coroutine to subscribe asynchronously to the provided document AST + asynchronously using the async transport, + returning an async generator producing ExecutionResult objects. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. 
+ + The extra arguments are passed to the transport subscribe method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Subscribe to the transport inner_generator: AsyncGenerator[ ExecutionResult, None - ] = self.transport.subscribe(document, *args, **kwargs) + ] = self.transport.subscribe( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) # Keep a reference to the inner generator to allow the user to call aclose() # before a break if python version is too old (pypy3 py 3.6.1) @@ -364,15 +461,35 @@ async def _subscribe( await inner_generator.aclose() async def subscribe( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> AsyncGenerator[Dict, None]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. 
+ The extra arguments are passed to the transport subscribe method.""" inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( - document, *args, **kwargs + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, ) try: @@ -391,27 +508,85 @@ async def subscribe( await inner_generator.aclose() async def _execute( - self, document: DocumentNode, *args, **kwargs + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, ) -> ExecutionResult: + """Coroutine to execute the provided document AST asynchronously using + the async transport, returning an ExecutionResult object. + + * Validate the query with the schema if provided. + * Serialize the variable_values if requested. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. 
+ + The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: self.client.validate(document) + # Parse variable values for custom scalars if requested + if serialize_variables and variable_values is not None: + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) + # Execute the query with the transport with a timeout return await asyncio.wait_for( - self.transport.execute(document, *args, **kwargs), + self.transport.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + *args, + **kwargs, + ), self.client.execute_timeout, ) - async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + async def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: bool = False, + **kwargs, + ) -> Dict: """Coroutine to execute the provided document AST asynchronously using the async transport. + Raises a TransportQueryError if an error has been returned in + the ExecutionResult. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters. + :param operation_name: Name of the operation that shall be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. 
+ The extra arguments are passed to the transport execute method.""" # Validate and execute on the transport - result = await self._execute(document, *args, **kwargs) + result = await self._execute( + document, + *args, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + **kwargs, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py new file mode 100644 index 00000000..68b80156 --- /dev/null +++ b/gql/utilities/__init__.py @@ -0,0 +1,5 @@ +from .update_schema_scalars import update_schema_scalars + +__all__ = [ + "update_schema_scalars", +] diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py new file mode 100644 index 00000000..d5434c6b --- /dev/null +++ b/gql/utilities/update_schema_scalars.py @@ -0,0 +1,32 @@ +from typing import Iterable, List + +from graphql import GraphQLError, GraphQLScalarType, GraphQLSchema + + +def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): + """Update the scalars in a schema with the scalars provided. + + This can be used to update the default Custom Scalar implementation + when the schema has been provided from a text file or from introspection. 
+ """ + + if not isinstance(scalars, Iterable): + raise GraphQLError("Scalars argument should be a list of scalars.") + + for scalar in scalars: + if not isinstance(scalar, GraphQLScalarType): + raise GraphQLError("Scalars should be instances of GraphQLScalarType.") + + try: + schema_scalar = schema.type_map[scalar.name] + except KeyError: + raise GraphQLError(f"Scalar '{scalar.name}' not found in schema.") + + assert isinstance(schema_scalar, GraphQLScalarType) + + # Update the conversion methods + # Using setattr because mypy has a false positive + # https://github.com/python/mypy/issues/2427 + setattr(schema_scalar, "serialize", scalar.serialize) + setattr(schema_scalar, "parse_value", scalar.parse_value) + setattr(schema_scalar, "parse_literal", scalar.parse_literal) diff --git a/gql/variable_values.py b/gql/variable_values.py new file mode 100644 index 00000000..7db7091a --- /dev/null +++ b/gql/variable_values.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, Optional + +from graphql import ( + DocumentNode, + GraphQLEnumType, + GraphQLError, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull, + GraphQLScalarType, + GraphQLSchema, + GraphQLType, + GraphQLWrappingType, + OperationDefinitionNode, + type_from_ast, +) +from graphql.pyutils import inspect + + +def get_document_operation( + document: DocumentNode, operation_name: Optional[str] = None +) -> OperationDefinitionNode: + """Returns the operation which should be executed in the document. + + Raises a GraphQLError if a single operation cannot be retrieved. + """ + + operation: Optional[OperationDefinitionNode] = None + + for definition in document.definitions: + if isinstance(definition, OperationDefinitionNode): + if operation_name is None: + if operation: + raise GraphQLError( + "Must provide operation name" + " if query contains multiple operations." 
+ ) + operation = definition + elif definition.name and definition.name.value == operation_name: + operation = definition + + if not operation: + if operation_name is not None: + raise GraphQLError(f"Unknown operation named '{operation_name}'.") + + # The following line should never happen normally as the document is + # already verified before calling this function. + raise GraphQLError("Must provide an operation.") # pragma: no cover + + return operation + + +def serialize_value(type_: GraphQLType, value: Any) -> Any: + """Given a GraphQL type and a Python value, return the serialized value. + + Can be used to serialize Enums and/or Custom Scalars in variable values. + """ + + if value is None: + if isinstance(type_, GraphQLNonNull): + # raise GraphQLError(f"Type {type_.of_type.name} Cannot be None.") + raise GraphQLError(f"Type {inspect(type_)} Cannot be None.") + else: + return None + + if isinstance(type_, GraphQLWrappingType): + inner_type = type_.of_type + + if isinstance(type_, GraphQLNonNull): + return serialize_value(inner_type, value) + + elif isinstance(type_, GraphQLList): + return [serialize_value(inner_type, v) for v in value] + + elif isinstance(type_, (GraphQLScalarType, GraphQLEnumType)): + return type_.serialize(value) + + elif isinstance(type_, GraphQLInputObjectType): + return { + field_name: serialize_value(field.type, value[field_name]) + for field_name, field in type_.fields.items() + } + + raise GraphQLError(f"Impossible to serialize value with type: {inspect(type_)}.") + + +def serialize_variable_values( + schema: GraphQLSchema, + document: DocumentNode, + variable_values: Dict[str, Any], + operation_name: Optional[str] = None, +) -> Dict[str, Any]: + """Given a GraphQL document and a schema, serialize the Dictionary of + variable values. 
+ + Useful to serialize Enums and/or Custom Scalars in variable values + """ + + parsed_variable_values: Dict[str, Any] = {} + + # Find the operation in the document + operation = get_document_operation(document, operation_name=operation_name) + + # Serialize every variable value defined for the operation + for var_def_node in operation.variable_definitions: + var_name = var_def_node.variable.name.value + var_type = type_from_ast(schema, var_def_node.type) + + if var_name in variable_values: + + assert var_type is not None + + var_value = variable_values[var_name] + + parsed_variable_values[var_name] = serialize_value(var_type, var_value) + + return parsed_variable_values diff --git a/tests/custom_scalars/__init__.py b/tests/custom_scalars/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_custom_scalar_datetime.py new file mode 100644 index 00000000..25c6bb31 --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_datetime.py @@ -0,0 +1,220 @@ +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +import pytest +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLInt, + GraphQLList, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql + + +def serialize_datetime(value: Any) -> str: + if not isinstance(value, datetime): + raise GraphQLError("Cannot serialize datetime value: " + inspect(value)) + return value.isoformat() + + +def parse_datetime_value(value: Any) -> datetime: + + if isinstance(value, str): + try: + # Note: a more solid custom scalar should use dateutil.parser.isoparse + # Not using it here in the test to avoid adding another dependency + return 
datetime.fromisoformat(value) + except Exception: + raise GraphQLError("Cannot parse datetime value : " + inspect(value)) + + else: + raise GraphQLError("Cannot parse datetime value: " + inspect(value)) + + +def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + if not isinstance(ast_value, str): + raise GraphQLError("Cannot parse literal datetime value: " + inspect(ast_value)) + + return parse_datetime_value(ast_value) + + +DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, +) + + +def resolve_shift_days(root, _info, time, days): + return time + timedelta(days=days) + + +def resolve_latest(root, _info, times): + return max(times) + + +def resolve_seconds(root, _info, interval): + print(f"interval={interval!r}") + return (interval["end"] - interval["start"]).total_seconds() + + +IntervalInputType = GraphQLInputObjectType( + "IntervalInput", + fields={ + "start": GraphQLInputField(DatetimeScalar), + "end": GraphQLInputField(DatetimeScalar), + }, +) + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "shiftDays": GraphQLField( + DatetimeScalar, + args={ + "time": GraphQLArgument(DatetimeScalar), + "days": GraphQLArgument(GraphQLInt), + }, + resolve=resolve_shift_days, + ), + "latest": GraphQLField( + DatetimeScalar, + args={"times": GraphQLArgument(GraphQLList(DatetimeScalar))}, + resolve=resolve_latest, + ), + "seconds": GraphQLField( + GraphQLInt, + args={"interval": GraphQLArgument(IntervalInputType)}, + resolve=resolve_seconds, + ), + }, +) + +schema = GraphQLSchema(query=queryType) + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + 
+ query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": now, + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_query(): + + client = Client(schema=schema) + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + + result = client.execute(query) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_shift_days_serialized_manually_in_variables(): + + client = Client(schema=schema) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + + print(result) + + assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_latest(): + + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql("query latest($times: [Datetime!]!) 
{latest(times: $times)}") + + variable_values = { + "times": [now, in_five_days], + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["latest"] == in_five_days.isoformat() + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_seconds(): + client = Client(schema=schema) + + now = datetime.fromisoformat("2021-11-12T11:58:13.461161") + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql( + "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" + ) + + variable_values = {"interval": {"start": now, "end": in_five_days}} + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["seconds"] == 432000 diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_custom_scalar_money.py new file mode 100644 index 00000000..238308a9 --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_money.py @@ -0,0 +1,635 @@ +import asyncio +from typing import Any, Dict, NamedTuple, Optional + +import pytest +from graphql import graphql_sync +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect, is_finite +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLFloat, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql +from gql.transport.exceptions import TransportQueryError +from gql.utilities import update_schema_scalars +from gql.variable_values import serialize_value + +from ..conftest import MS + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +class Money(NamedTuple): + amount: float + currency: str + + +def 
serialize_money(output_value: Any) -> Dict[str, Any]: + if not isinstance(output_value, Money): + raise GraphQLError("Cannot serialize money value: " + inspect(output_value)) + return output_value._asdict() + + +def parse_money_value(input_value: Any) -> Money: + """Using Money custom scalar from graphql-core tests except here the + input value is supposed to be a dict instead of a Money object.""" + + """ + if isinstance(input_value, Money): + return input_value + """ + + if isinstance(input_value, dict): + amount = input_value.get("amount", None) + currency = input_value.get("currency", None) + + if not is_finite(amount) or not isinstance(currency, str): + raise GraphQLError("Cannot parse money value dict: " + inspect(input_value)) + + return Money(float(amount), currency) + else: + raise GraphQLError("Cannot parse money value: " + inspect(input_value)) + + +def parse_money_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Money: + money = value_from_ast_untyped(value_node, variables) + if variables is not None and ( + # variables are not set when checked with ValuesIOfCorrectTypeRule + not money + or not is_finite(money.get("amount")) + or not isinstance(money.get("currency"), str) + ): + raise GraphQLError("Cannot parse literal money value: " + inspect(money)) + return Money(**money) + + +MoneyScalar = GraphQLScalarType( + name="Money", + serialize=serialize_money, + parse_value=parse_money_value, + parse_literal=parse_money_literal, +) + + +def resolve_balance(root, _info): + return root + + +def resolve_to_euros(_root, _info, money): + amount = money.amount + currency = money.currency + if not amount or currency == "EUR": + return amount + if currency == "DM": + return amount * 0.5 + raise ValueError("Cannot convert to euros: " + inspect(money)) + + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "balance": GraphQLField(MoneyScalar, resolve=resolve_balance), + "toEuros": GraphQLField( + GraphQLFloat, + 
args={"money": GraphQLArgument(MoneyScalar)}, + resolve=resolve_to_euros, + ), + }, +) + + +def resolve_spent_money(spent_money, _info, **kwargs): + return spent_money + + +async def subscribe_spend_all(_root, _info, money): + while money.amount > 0: + money = Money(money.amount - 1, money.currency) + yield money + await asyncio.sleep(1 * MS) + + +subscriptionType = GraphQLObjectType( + "Subscription", + fields=lambda: { + "spend": GraphQLField( + MoneyScalar, + args={"money": GraphQLArgument(MoneyScalar)}, + subscribe=subscribe_spend_all, + resolve=resolve_spent_money, + ) + }, +) + +root_value = Money(42, "DM") + +schema = GraphQLSchema(query=queryType, subscription=subscriptionType,) + + +def test_custom_scalar_in_output(): + + client = Client(schema=schema) + + query = gql("{balance}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +def test_custom_scalar_in_input_query(): + + client = Client(schema=schema) + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = client.execute(query, root_value=root_value) + + assert result["toEuros"] == 10 + + +def test_custom_scalar_in_input_variable_values(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + + variable_values = {"money": money_value} + + result = client.execute( + query, variable_values=variable_values, root_value=root_value + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = 
client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + assert result["toEuros"] == 5 + + +def test_custom_scalar_in_input_variable_values_serialized_with_operation_name(): + + client = Client(schema=schema) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="myquery", + ) + + assert result["toEuros"] == 5 + + +def test_serialize_variable_values_exception_multiple_ops_without_operation_name(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } + + query mybalance { + balance + }""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ) + + exception = exc_info.value + + assert ( + str(exception) + == "Must provide operation name if query contains multiple operations." + ) + + +def test_serialize_variable_values_exception_operation_name_not_found(): + + client = Client(schema=schema) + + query = gql( + """ + query myconversion($money: Money) { + toEuros(money: $money) + } +""" + ) + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + with pytest.raises(GraphQLError) as exc_info: + client.execute( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + operation_name="invalid_operation_name", + ) + + exception = exc_info.value + + assert str(exception) == "Unknown operation named 'invalid_operation_name'." 
+ + +def test_custom_scalar_subscribe_in_input_variable_values_serialized(): + + client = Client(schema=schema) + + query = gql("subscription spendAll($money: Money) {spend(money: $money)}") + + money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + expected_result = {"spend": {"amount": 10, "currency": "DM"}} + + for result in client.subscribe( + query, + variable_values=variable_values, + root_value=root_value, + serialize_variables=True, + ): + print(f"result = {result!r}") + expected_result["spend"]["amount"] = expected_result["spend"]["amount"] - 1 + assert expected_result == result + + +async def make_money_backend(aiohttp_server): + from aiohttp import web + + async def handler(request): + data = await request.json() + source = data["query"] + + print(f"data keys = {data.keys()}") + try: + variables = data["variables"] + print(f"variables = {variables!r}") + except KeyError: + variables = None + + result = graphql_sync( + schema, source, variable_values=variables, root_value=root_value + ) + + print(f"backend result = {result!r}") + + return web.json_response( + { + "data": result.data, + "errors": [str(e) for e in result.errors] if result.errors else None, + } + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + return server + + +async def make_money_transport(aiohttp_server): + from gql.transport.aiohttp import AIOHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + return transport + + +async def make_sync_money_transport(aiohttp_server): + from gql.transport.requests import RequestsHTTPTransport + + server = await make_money_backend(aiohttp_server) + + url = server.make_url("/") + + transport = RequestsHTTPTransport(url=url, timeout=10) + + return (server, transport) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_output_with_transport(event_loop, 
aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("{balance}") + + result = await session.execute(query) + + print(result) + + assert result["balance"] == serialize_money(root_value) + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 5 + + query = gql('{toEuros(money: {amount: 10, currency: "EUR"})}') + + result = await session.execute(query) + + assert result["toEuros"] == 10 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + money_value = {"amount": 10, "currency": "DM"} + # money_value = Money(10, "DM") + + variable_values = {"money": money_value} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_in_input_variable_values_split_with_transport( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql( + """ +query myquery($amount: Float, $currency: String) { + toEuros(money: {amount: $amount, currency: $currency}) +}""" + ) + + variable_values = {"amount": 10, "currency": "DM"} + + result = await session.execute(query, variable_values=variable_values) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + 
+@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + with pytest.raises(TransportQueryError): + await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + +@pytest.mark.asyncio +async def test_custom_scalar_serialize_variables_schema_from_introspection( + event_loop, aiohttp_server +): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + schema = session.client.schema + + # Updating the Money Scalar in the schema + # We cannot replace it because some other objects keep a reference + # to the existing Scalar + # cannot do: schema.type_map["Money"] = MoneyScalar + + money_scalar = schema.type_map["Money"] + + money_scalar.serialize = MoneyScalar.serialize + money_scalar.parse_value = MoneyScalar.parse_value + money_scalar.parse_literal = MoneyScalar.parse_literal + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = 
{result!r}") + assert result["toEuros"] == 5 + + +@pytest.mark.asyncio +async def test_update_schema_scalars(event_loop, aiohttp_server): + + transport = await make_money_transport(aiohttp_server) + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # Update the schema MoneyScalar default implementation from + # introspection with our provided conversion methods + update_schema_scalars(session.client.schema, [MoneyScalar]) + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + +def test_update_schema_scalars_invalid_scalar(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [int]) + + exception = exc_info.value + + assert str(exception) == "Scalars should be instances of GraphQLScalarType." + + +def test_update_schema_scalars_invalid_scalar_argument(): + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, MoneyScalar) + + exception = exc_info.value + + assert str(exception) == "Scalars argument should be a list of scalars." + + +def test_update_schema_scalars_scalar_not_found_in_schema(): + + NotFoundScalar = GraphQLScalarType(name="abcd",) + + with pytest.raises(GraphQLError) as exc_info: + update_schema_scalars(schema, [MoneyScalar, NotFoundScalar]) + + exception = exc_info.value + + assert str(exception) == "Scalar 'abcd' not found in schema." 
+ + +@pytest.mark.asyncio +@pytest.mark.requests +async def test_custom_scalar_serialize_variables_sync_transport( + event_loop, aiohttp_server, run_sync_test +): + + server, transport = await make_sync_money_transport(aiohttp_server) + + def test_code(): + with Client(schema=schema, transport=transport,) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + result = session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(f"result = {result!r}") + assert result["toEuros"] == 5 + + await run_sync_test(event_loop, server, test_code) + + +def test_serialize_value_with_invalid_type(): + + with pytest.raises(GraphQLError) as exc_info: + serialize_value("Not a valid type", 50) + + exception = exc_info.value + + assert ( + str(exception) == "Impossible to serialize value with type: 'Not a valid type'." + ) + + +def test_serialize_value_with_non_null_type_null(): + + non_null_int = GraphQLNonNull(GraphQLInt) + + with pytest.raises(GraphQLError) as exc_info: + serialize_value(non_null_int, None) + + exception = exc_info.value + + assert str(exception) == "Type Int! Cannot be None." 
+ + +def test_serialize_value_with_nullable_type(): + + nullable_int = GraphQLInt + + assert serialize_value(nullable_int, None) is None From d1ba78dafe33b95eee4a6d6abcde26c764c69c44 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 19 Nov 2021 23:56:27 +0100 Subject: [PATCH 031/239] DSL serialize complex arguments to literals (#255) --- gql/dsl.py | 102 +++++++- .../custom_scalars/test_custom_scalar_json.py | 241 ++++++++++++++++++ tests/starwars/test_dsl.py | 33 ++- 3 files changed, 368 insertions(+), 8 deletions(-) create mode 100644 tests/custom_scalars/test_custom_scalar_json.py diff --git a/gql/dsl.py b/gql/dsl.py index f3bd1fe2..1646d402 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,15 +1,22 @@ import logging +import re from abc import ABC, abstractmethod +from math import isfinite from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast from graphql import ( ArgumentNode, + BooleanValueNode, DocumentNode, + EnumValueNode, FieldNode, + FloatValueNode, FragmentDefinitionNode, FragmentSpreadNode, GraphQLArgument, + GraphQLError, GraphQLField, + GraphQLID, GraphQLInputObjectType, GraphQLInputType, GraphQLInterfaceType, @@ -20,6 +27,7 @@ GraphQLSchema, GraphQLWrappingType, InlineFragmentNode, + IntValueNode, ListTypeNode, ListValueNode, NamedTypeNode, @@ -31,25 +39,76 @@ OperationDefinitionNode, OperationType, SelectionSetNode, + StringValueNode, TypeNode, Undefined, ValueNode, VariableDefinitionNode, VariableNode, assert_named_type, + is_enum_type, is_input_object_type, + is_leaf_type, is_list_type, is_non_null_type, is_wrapping_type, print_ast, ) -from graphql.pyutils import FrozenList -from graphql.utilities import ast_from_value as default_ast_from_value +from graphql.pyutils import FrozenList, inspect from .utils import to_camel_case log = logging.getLogger(__name__) +_re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$") + + +def ast_from_serialized_value_untyped(serialized: Any) -> Optional[ValueNode]: + """Given a 
serialized value, try our best to produce an AST. + + Anything ressembling an array (instance of Mapping) will be converted + to an ObjectFieldNode. + + Anything ressembling a list (instance of Iterable - except str) + will be converted to a ListNode. + + In some cases, a custom scalar can be serialized differently in the query + than in the variables. In that case, this function will not work.""" + + if serialized is None or serialized is Undefined: + return NullValueNode() + + if isinstance(serialized, Mapping): + field_items = ( + (key, ast_from_serialized_value_untyped(value)) + for key, value in serialized.items() + ) + field_nodes = ( + ObjectFieldNode(name=NameNode(value=field_name), value=field_value) + for field_name, field_value in field_items + if field_value + ) + return ObjectValueNode(fields=FrozenList(field_nodes)) + + if isinstance(serialized, Iterable) and not isinstance(serialized, str): + maybe_nodes = (ast_from_serialized_value_untyped(item) for item in serialized) + nodes = filter(None, maybe_nodes) + return ListValueNode(values=FrozenList(nodes)) + + if isinstance(serialized, bool): + return BooleanValueNode(value=serialized) + + if isinstance(serialized, int): + return IntValueNode(value=f"{serialized:d}") + + if isinstance(serialized, float) and isfinite(serialized): + return FloatValueNode(value=f"{serialized:g}") + + if isinstance(serialized, str): + return StringValueNode(value=serialized) + + raise TypeError(f"Cannot convert value to AST: {inspect(serialized)}.") + def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: """ @@ -60,15 +119,21 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: VariableNode when value is a DSLVariable Produce a GraphQL Value AST given a Python object. + + Raises a GraphQLError instead of returning None if we receive an Undefined + of if we receive a Null value for a Non-Null type. 
""" if isinstance(value, DSLVariable): return value.set_type(type_).ast_variable if is_non_null_type(type_): type_ = cast(GraphQLNonNull, type_) - ast_value = ast_from_value(value, type_.of_type) + inner_type = type_.of_type + ast_value = ast_from_value(value, inner_type) if isinstance(ast_value, NullValueNode): - return None + raise GraphQLError( + "Received Null value for a Non-Null type " f"{inspect(inner_type)}." + ) return ast_value # only explicit None, not Undefined or NaN @@ -77,7 +142,7 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: # undefined if value is Undefined: - return None + raise GraphQLError(f"Received Undefined value for type {inspect(type_)}.") # Convert Python list to GraphQL list. If the GraphQLType is a list, but the value # is not a list, convert the value using the list's item type. @@ -108,7 +173,32 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: ) return ObjectValueNode(fields=FrozenList(field_nodes)) - return default_ast_from_value(value, type_) + if is_leaf_type(type_): + # Since value is an internally represented value, it must be serialized to an + # externally represented value before converting into an AST. + serialized = type_.serialize(value) # type: ignore + + # if the serialized value is a string, then we should use the + # type to determine if it is an enum, an ID or a normal string + if isinstance(serialized, str): + # Enum types use Enum literals. + if is_enum_type(type_): + return EnumValueNode(value=serialized) + + # ID types can use Int literals. 
+ if type_ is GraphQLID and _re_integer_string.match(serialized): + return IntValueNode(value=serialized) + + return StringValueNode(value=serialized) + + # Some custom scalars will serialize to dicts or lists + # Providing here a default conversion to AST using our best judgment + # until graphql-js issue #1817 is solved + # https://github.com/graphql/graphql-js/issues/1817 + return ast_from_serialized_value_untyped(serialized) + + # Not reachable. All possible input types have been considered. + raise TypeError(f"Unexpected input type: {inspect(type_)}.") def dsl_gql( diff --git a/tests/custom_scalars/test_custom_scalar_json.py b/tests/custom_scalars/test_custom_scalar_json.py new file mode 100644 index 00000000..80f99850 --- /dev/null +++ b/tests/custom_scalars/test_custom_scalar_json.py @@ -0,0 +1,241 @@ +from typing import Any, Dict, Optional + +import pytest +from graphql import ( + GraphQLArgument, + GraphQLError, + GraphQLField, + GraphQLFloat, + GraphQLInt, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.language import ValueNode +from graphql.utilities import value_from_ast_untyped + +from gql import Client, gql +from gql.dsl import DSLSchema + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +def serialize_json(value: Any) -> Dict[str, Any]: + return value + + +def parse_json_value(value: Any) -> Any: + return value + + +def parse_json_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Any: + return value_from_ast_untyped(value_node, variables) + + +JsonScalar = GraphQLScalarType( + name="JSON", + serialize=serialize_json, + parse_value=parse_json_value, + parse_literal=parse_json_literal, +) + +root_value = { + "players": [ + { + "name": "John", + "level": 3, + "is_connected": True, + "score": 123.45, + "friends": ["Alex", "Alicia"], + }, + { + "name": "Alex", + "level": 4, + 
"is_connected": False, + "score": 1337.69, + "friends": None, + }, + ] +} + + +def resolve_players(root, _info): + return root["players"] + + +queryType = GraphQLObjectType( + name="Query", fields={"players": GraphQLField(JsonScalar, resolve=resolve_players)}, +) + + +def resolve_add_player(root, _info, player): + print(f"player = {player!r}") + root["players"].append(player) + return {"players": root["players"]} + + +mutationType = GraphQLObjectType( + name="Mutation", + fields={ + "addPlayer": GraphQLField( + JsonScalar, + args={"player": GraphQLArgument(GraphQLNonNull(JsonScalar))}, + resolve=resolve_add_player, + ) + }, +) + +schema = GraphQLSchema(query=queryType, mutation=mutationType) + + +def test_json_value_output(): + + client = Client(schema=schema) + + query = gql("query {players}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["players"] == serialize_json(root_value["players"]) + + +def test_json_value_input_in_ast(): + + client = Client(schema=schema) + + query = gql( + """ + mutation adding_player { + addPlayer(player: { + name: "Tom", + level: 1, + is_connected: True, + score: 0, + friends: [ + "John" + ] + }) +}""" + ) + + result = client.execute(query, root_value=root_value) + + print(result) + + players = result["addPlayer"]["players"] + + assert players == serialize_json(root_value["players"]) + assert players[-1]["name"] == "Tom" + + +def test_json_value_input_in_ast_with_variables(): + + print(f"{schema.type_map!r}") + client = Client(schema=schema) + + # Note: we need to manually add the built-in types which + # are not present in the schema + schema.type_map["Int"] = GraphQLInt + schema.type_map["Float"] = GraphQLFloat + + query = gql( + """ + mutation adding_player( + $name: String!, + $level: Int!, + $is_connected: Boolean, + $score: Float!, + $friends: [String!]!) 
{ + + addPlayer(player: { + name: $name, + level: $level, + is_connected: $is_connected, + score: $score, + friends: $friends, + }) +}""" + ) + + variable_values = { + "name": "Barbara", + "level": 1, + "is_connected": False, + "score": 69, + "friends": ["Alex", "John"], + } + + result = client.execute( + query, variable_values=variable_values, root_value=root_value + ) + + print(result) + + players = result["addPlayer"]["players"] + + assert players == serialize_json(root_value["players"]) + assert players[-1]["name"] == "Barbara" + + +def test_json_value_input_in_dsl_argument(): + + ds = DSLSchema(schema) + + new_player = { + "name": "Tim", + "level": 0, + "is_connected": False, + "score": 5, + "friends": ["Lea"], + } + + query = ds.Mutation.addPlayer(player=new_player) + + print(str(query)) + + assert ( + str(query) + == """addPlayer( + player: {name: "Tim", level: 0, is_connected: false, score: 5, friends: ["Lea"]} +)""" + ) + + +def test_none_json_value_input_in_dsl_argument(): + + ds = DSLSchema(schema) + + with pytest.raises(GraphQLError) as exc_info: + ds.Mutation.addPlayer(player=None) + + assert "Received Null value for a Non-Null type JSON." 
in str(exc_info.value) + + +def test_json_value_input_with_none_list_in_dsl_argument(): + + ds = DSLSchema(schema) + + new_player = { + "name": "Bob", + "level": 9001, + "is_connected": True, + "score": 666.66, + "friends": None, + } + + query = ds.Mutation.addPlayer(player=new_player) + + print(str(query)) + + assert ( + str(query) + == """addPlayer( + player: {name: "Bob", level: 9001, is_connected: true, score: 666.66, friends: null} +)""" + ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 93de6c03..d18bb37d 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,5 +1,7 @@ import pytest from graphql import ( + GraphQLError, + GraphQLID, GraphQLInt, GraphQLList, GraphQLNonNull, @@ -23,6 +25,7 @@ DSLSubscription, DSLVariable, DSLVariableDefinitions, + ast_from_serialized_value_untyped, ast_from_value, dsl_gql, ) @@ -54,12 +57,38 @@ def test_ast_from_value_with_none(): def test_ast_from_value_with_undefined(): - assert ast_from_value(Undefined, GraphQLInt) is None + with pytest.raises(GraphQLError) as exc_info: + ast_from_value(Undefined, GraphQLInt) + + assert "Received Undefined value for type Int." in str(exc_info.value) + + +def test_ast_from_value_with_graphqlid(): + + assert ast_from_value("12345", GraphQLID) == IntValueNode(value="12345") + + +def test_ast_from_value_with_invalid_type(): + with pytest.raises(TypeError) as exc_info: + ast_from_value(4, None) + + assert "Unexpected input type: None." in str(exc_info.value) def test_ast_from_value_with_non_null_type_and_none(): typ = GraphQLNonNull(GraphQLInt) - assert ast_from_value(None, typ) is None + + with pytest.raises(GraphQLError) as exc_info: + ast_from_value(None, typ) + + assert "Received Null value for a Non-Null type Int." in str(exc_info.value) + + +def test_ast_from_serialized_value_untyped_typeerror(): + with pytest.raises(TypeError) as exc_info: + ast_from_serialized_value_untyped(GraphQLInt) + + assert "Cannot convert value to AST: Int." 
in str(exc_info.value) def test_variable_to_ast_type_passing_wrapping_type(): From 46252d1642325082dde6017be1cb86409c6e5dfc Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 20:42:35 +0100 Subject: [PATCH 032/239] Parse custom scalar and enums in outputs (#256) --- README.md | 1 + docs/usage/custom_scalars.rst | 134 ------ docs/usage/custom_scalars_and_enums.rst | 333 +++++++++++++ docs/usage/index.rst | 2 +- gql/client.py | 137 ++++-- gql/utilities/__init__.py | 10 +- gql/utilities/parse_result.py | 446 ++++++++++++++++++ .../serialize_variable_values.py} | 18 +- gql/utilities/update_schema_enum.py | 69 +++ gql/utilities/update_schema_scalars.py | 60 ++- tests/conftest.py | 1 + ...om_scalar_datetime.py => test_datetime.py} | 20 +- tests/custom_scalars/test_enum_colors.py | 325 +++++++++++++ ...est_custom_scalar_json.py => test_json.py} | 2 +- ...t_custom_scalar_money.py => test_money.py} | 141 +++++- tests/starwars/test_parse_results.py | 191 ++++++++ tests/starwars/test_query.py | 2 +- tests/starwars/test_subscription.py | 4 +- tests/test_async_client_validation.py | 2 +- 19 files changed, 1682 insertions(+), 216 deletions(-) delete mode 100644 docs/usage/custom_scalars.rst create mode 100644 docs/usage/custom_scalars_and_enums.rst create mode 100644 gql/utilities/parse_result.py rename gql/{variable_values.py => utilities/serialize_variable_values.py} (86%) create mode 100644 gql/utilities/update_schema_enum.py rename tests/custom_scalars/{test_custom_scalar_datetime.py => test_datetime.py} (89%) create mode 100644 tests/custom_scalars/test_enum_colors.py rename tests/custom_scalars/{test_custom_scalar_json.py => test_json.py} (98%) rename tests/custom_scalars/{test_custom_scalar_money.py => test_money.py} (80%) create mode 100644 tests/starwars/test_parse_results.py diff --git a/README.md b/README.md index 8fefeb2f..a85761e1 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ The main features of GQL are: * Supports GraphQL queries, 
mutations and [subscriptions](https://gql.readthedocs.io/en/latest/usage/subscriptions.html) * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) +* Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html) * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries from the command line * [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically diff --git a/docs/usage/custom_scalars.rst b/docs/usage/custom_scalars.rst deleted file mode 100644 index baee441e..00000000 --- a/docs/usage/custom_scalars.rst +++ /dev/null @@ -1,134 +0,0 @@ -Custom Scalars -============== - -Scalar types represent primitive values at the leaves of a query. - -GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend -can add additional custom scalars to its schema to better express values in their data model. - -For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. - -The schema will then only contain: - -.. 
code-block:: python - - scalar Datetime - -When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), -their values need to be serialized to be composed -of only built-in scalars, then at the destination the serialized values will be parsed again to -be able to represent the scalar in its local internal representation. - -Because this serialization/unserialization is dependent on the language used at both sides, it is not -described in the schema and needs to be defined independently at both sides (client, backend). - -A custom scalar value can have two different representations during its transport: - - - as a serialized value (usually as json): - - * in the results sent by the backend - * in the variables sent by the client alongside the query - - - as "literal" inside the query itself sent by the client - -To define a custom scalar, you need 3 methods: - - - a :code:`serialize` method used: - - * by the backend to serialize a custom scalar output in the result - * by the client to serialize a custom scalar input in the variables - - - a :code:`parse_value` method used: - - * by the backend to unserialize custom scalars inputs in the variables sent by the client - * by the client to unserialize custom scalars outputs from the results - - - a :code:`parse_literal` method used: - - * by the backend to unserialize custom scalars inputs inside the query itself - -To define a custom scalar object, we define a :code:`GraphQLScalarType` from graphql-core with -its name and the implementation of the above methods. - -Example for Datetime: - -.. 
code-block:: python - - from datetime import datetime - from typing import Any, Dict, Optional - - from graphql import GraphQLScalarType, ValueNode - from graphql.utilities import value_from_ast_untyped - - - def serialize_datetime(value: Any) -> str: - return value.isoformat() - - - def parse_datetime_value(value: Any) -> datetime: - return datetime.fromisoformat(value) - - - def parse_datetime_literal( - value_node: ValueNode, variables: Optional[Dict[str, Any]] = None - ) -> datetime: - ast_value = value_from_ast_untyped(value_node, variables) - return parse_datetime_value(ast_value) - - - DatetimeScalar = GraphQLScalarType( - name="Datetime", - serialize=serialize_datetime, - parse_value=parse_datetime_value, - parse_literal=parse_datetime_literal, - ) - -Custom Scalars in inputs ------------------------- - -To provide custom scalars in input with gql, you can: - -- serialize the scalar yourself as "literal" in the query: - -.. code-block:: python - - query = gql( - """{ - shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) - }""" - ) - -- serialize the scalar yourself in a variable: - -.. code-block:: python - - query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - - variable_values = { - "time": "2021-11-12T11:58:13.461161", - } - - result = client.execute(query, variable_values=variable_values) - -- add a custom scalar to the schema with :func:`update_schema_scalars ` - and execute the query with :code:`serialize_variables=True` - and gql will serialize the variable values from a Python object representation. - -For this, you need to provide a schema or set :code:`fetch_schema_from_transport=True` -in the client to request the schema from the backend. - -.. 
code-block:: python - - from gql.utilities import update_schema_scalars - - async with Client(transport=transport, fetch_schema_from_transport=True) as session: - - update_schema_scalars(session.client.schema, [DatetimeScalar]) - - query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - - variable_values = {"time": datetime.now()} - - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) diff --git a/docs/usage/custom_scalars_and_enums.rst b/docs/usage/custom_scalars_and_enums.rst new file mode 100644 index 00000000..fc9008d8 --- /dev/null +++ b/docs/usage/custom_scalars_and_enums.rst @@ -0,0 +1,333 @@ +Custom scalars and enums +======================== + +.. _custom_scalars: + +Custom scalars +-------------- + +Scalar types represent primitive values at the leaves of a query. + +GraphQL provides a number of built-in scalars (Int, Float, String, Boolean and ID), but a GraphQL backend +can add additional custom scalars to its schema to better express values in their data model. + +For example, a schema can define the Datetime scalar to represent an ISO-8601 encoded date. + +The schema will then only contain:: + + scalar Datetime + +When custom scalars are sent to the backend (as inputs) or from the backend (as outputs), +their values need to be serialized to be composed +of only built-in scalars, then at the destination the serialized values will be parsed again to +be able to represent the scalar in its local internal representation. + +Because this serialization/unserialization is dependent on the language used at both sides, it is not +described in the schema and needs to be defined independently at both sides (client, backend). 
+ +A custom scalar value can have two different representations during its transport: + + - as a serialized value (usually as json): + + * in the results sent by the backend + * in the variables sent by the client alongside the query + + - as "literal" inside the query itself sent by the client + +To define a custom scalar, you need 3 methods: + + - a :code:`serialize` method used: + + * by the backend to serialize a custom scalar output in the result + * by the client to serialize a custom scalar input in the variables + + - a :code:`parse_value` method used: + + * by the backend to unserialize custom scalars inputs in the variables sent by the client + * by the client to unserialize custom scalars outputs from the results + + - a :code:`parse_literal` method used: + + * by the backend to unserialize custom scalars inputs inside the query itself + +To define a custom scalar object, graphql-core provides the :code:`GraphQLScalarType` class +which contains the implementation of the above methods. + +Example for Datetime: + +.. 
code-block:: python + + from datetime import datetime + from typing import Any, Dict, Optional + + from graphql import GraphQLScalarType, ValueNode + from graphql.utilities import value_from_ast_untyped + + + def serialize_datetime(value: Any) -> str: + return value.isoformat() + + + def parse_datetime_value(value: Any) -> datetime: + return datetime.fromisoformat(value) + + + def parse_datetime_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None + ) -> datetime: + ast_value = value_from_ast_untyped(value_node, variables) + return parse_datetime_value(ast_value) + + + DatetimeScalar = GraphQLScalarType( + name="Datetime", + serialize=serialize_datetime, + parse_value=parse_datetime_value, + parse_literal=parse_datetime_literal, + ) + +If you get your schema from a "schema.graphql" file or from introspection, +then the generated schema in the gql Client will contain default :code:`GraphQLScalarType` instances +where the serialize and parse_value methods simply return the serialized value without modification. + +In that case, if you want gql to parse custom scalars to a more useful Python representation, +or to serialize custom scalars variables from a Python representation, +then you can use the :func:`update_schema_scalars ` +or :func:`update_schema_scalar ` methods +to modify the definition of a scalar in your schema so that gql could do the parsing/serialization. + +.. code-block:: python + + from gql.utilities import update_schema_scalar + + with open('path/to/schema.graphql') as f: + schema_str = f.read() + + client = Client(schema=schema_str, ...) + + update_schema_scalar(client.schema, "Datetime", DatetimeScalar) + + # or update_schema_scalars(client.schema, [DatetimeScalar]) + +.. _enums: + +Enums +----- + +GraphQL Enum types are a special kind of scalar that is restricted to a particular set of allowed values. 
+ +For example, the schema may have a Color enum and contain:: + + enum Color { + RED + GREEN + BLUE + } + +Graphql-core provides the :code:`GraphQLEnumType` class to define an enum in the schema +(See `graphql-core schema building docs`_). + +This class defines how the enum is serialized and parsed. + +If you get your schema from a "schema.graphql" file or from introspection, +then the generated schema in the gql Client will contain default :code:`GraphQLEnumType` instances +which should serialize/parse enums to/from its String representation (the :code:`RED` enum +will be serialized to :code:`'RED'`). + +You may want to parse enums to convert them to Python Enum types. +In that case, you can use the :func:`update_schema_enum ` +to modify the default :code:`GraphQLEnumType` to use your defined Enum. + +Example: + +.. code-block:: python + + from enum import Enum + from gql.utilities import update_schema_enum + + class Color(Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + + with open('path/to/schema.graphql') as f: + schema_str = f.read() + + client = Client(schema=schema_str, ...) + + update_schema_enum(client.schema, 'Color', Color) + +Serializing Inputs +------------------ + +To provide custom scalars and/or enums in inputs with gql, you can: + +- serialize the inputs manually +- let gql serialize the inputs using the custom scalars and enums defined in the schema + +Manually +^^^^^^^^ + +You can serialize inputs yourself: + + - in the query itself + - in variables + +This has the advantage that you don't need a schema... + +In the query +"""""""""""" + +- custom scalar: + +.. code-block:: python + + query = gql( + """{ + shiftDays(time: "2021-11-12T11:58:13.461161", days: 5) + }""" + ) + +- enum: + +.. code-block:: python + + query = gql("{opposite(color: RED)}") + +In a variable +""""""""""""" + +- custom scalar: + +.. 
code-block:: python + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = { + "time": "2021-11-12T11:58:13.461161", + } + + result = client.execute(query, variable_values=variable_values) + +- enum: + +.. code-block:: python + + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + + variable_values = { + "color": 'RED', + } + + result = client.execute(query, variable_values=variable_values) + +Automatically +^^^^^^^^^^^^^ + +If you have custom scalar and/or enums defined in your schema +(See: :ref:`custom_scalars` and :ref:`enums`), +then you can request gql to serialize your variables automatically. + +- use :code:`Client(..., serialize_variables=True)` to request serializing variables for all queries +- use :code:`execute(..., serialize_variables=True)` or :code:`subscribe(..., serialize_variables=True)` if + you want gql to serialize the variables only for a single query. + +Examples: + +- custom scalars: + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + from .myscalars import DatetimeScalar + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # We update the schema we got from introspection with our custom scalar type + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + # In the query, the custom scalar in the input is set to a variable + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + # the argument for time is a datetime instance + variable_values = {"time": datetime.now()} + + # we execute the query with serialize_variables set to True + result = await session.execute( + query, variable_values=variable_values, serialize_variables=True + ) + +- enums: + +.. 
code-block:: python + + from gql.utilities import update_schema_enum + + from .myenums import Color + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + # We update the schema we got from introspection with our custom enum + update_schema_enum(session.client.schema, 'Color', Color) + + # In the query, the enum in the input is set to a variable + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + + # the argument for time is an instance of our Enum type + variable_values = { + "color": Color.RED, + } + + # we execute the query with serialize_variables set to True + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + +Parsing output +-------------- + +By default, gql returns the serialized result from the backend without parsing +(except json unserialization to Python default types). + +if you want to convert the result of custom scalars to custom objects, +you can request gql to parse the results. + +- use :code:`Client(..., parse_results=True)` to request parsing for all queries +- use :code:`execute(..., parse_result=True)` or :code:`subscribe(..., parse_result=True)` if + you want gql to parse only the result of a single query. + +Same example as above, with result parsing enabled: + +.. code-block:: python + + from gql.utilities import update_schema_scalars + + async with Client(transport=transport, fetch_schema_from_transport=True) as session: + + update_schema_scalars(session.client.schema, [DatetimeScalar]) + + query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") + + variable_values = {"time": datetime.now()} + + result = await session.execute( + query, + variable_values=variable_values, + serialize_variables=True, + parse_result=True, + ) + + # now result["time"] type is a datetime instead of string + +.. 
_graphql-core schema building docs: https://graphql-core-3.readthedocs.io/en/latest/usage/schema.html diff --git a/docs/usage/index.rst b/docs/usage/index.rst index 4a38093a..eebf9fd2 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -10,4 +10,4 @@ Usage variables headers file_upload - custom_scalars + custom_scalars_and_enums diff --git a/gql/client.py b/gql/client.py index 368193cc..079bb552 100644 --- a/gql/client.py +++ b/gql/client.py @@ -17,7 +17,8 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport -from .variable_values import serialize_variable_values +from .utilities import parse_result as parse_result_fn +from .utilities import serialize_variable_values class Client: @@ -48,6 +49,8 @@ def __init__( transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, execute_timeout: Optional[Union[int, float]] = 10, + serialize_variables: bool = False, + parse_results: bool = False, ): """Initialize the client with the given parameters. @@ -59,6 +62,10 @@ def __init__( :param execute_timeout: The maximum time in seconds for the execution of a request before a TimeoutError is raised. Only used for async transports. Passing None results in waiting forever for a response. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. Default: False. + :param parse_results: Whether gql will try to parse the serialized output + sent by the backend. Can be used to unserialize custom scalars or enums. 
""" assert not ( type_def and introspection @@ -108,6 +115,9 @@ def __init__( # Enforced timeout of the execute function (only for async transports) self.execute_timeout = execute_timeout + self.serialize_variables = serialize_variables + self.parse_results = parse_results + def validate(self, document: DocumentNode): """:meta private:""" assert ( @@ -296,7 +306,8 @@ def _execute( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, **kwargs, ) -> ExecutionResult: """Execute the provided document AST synchronously using @@ -307,6 +318,8 @@ def _execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. 
The extra arguments are passed to the transport execute method.""" @@ -315,15 +328,18 @@ def _execute( self.client.validate(document) # Parse variable values for custom scalars if requested - if serialize_variables and variable_values is not None: - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + if variable_values is not None: + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) - return self.transport.execute( + result = self.transport.execute( document, *args, variable_values=variable_values, @@ -331,13 +347,26 @@ def _execute( **kwargs, ) + # Unserialize the result if requested + if self.client.schema: + if parse_result or (parse_result is None and self.client.parse_results): + result.data = parse_result_fn( + self.client.schema, + document, + result.data, + operation_name=operation_name, + ) + + return result + def execute( self, document: DocumentNode, *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, **kwargs, ) -> Dict: """Execute the provided document AST synchronously using @@ -351,6 +380,8 @@ def execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. 
The extra arguments are passed to the transport execute method.""" @@ -361,6 +392,7 @@ def execute( variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, + parse_result=parse_result, **kwargs, ) @@ -408,7 +440,8 @@ async def _subscribe( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: """Coroutine to subscribe asynchronously to the provided document AST @@ -423,6 +456,8 @@ async def _subscribe( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport subscribe method.""" @@ -431,13 +466,16 @@ async def _subscribe( self.client.validate(document) # Parse variable values for custom scalars if requested - if serialize_variables and variable_values is not None: - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + if variable_values is not None: + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) # Subscribe to the transport inner_generator: AsyncGenerator[ @@ -456,7 +494,20 @@ async def _subscribe( try: async for result in inner_generator: + + if self.client.schema: + if parse_result or ( + parse_result is None and self.client.parse_results + ): + result.data = parse_result_fn( + self.client.schema, + document, + result.data, 
+ operation_name=operation_name, + ) + yield result + finally: await inner_generator.aclose() @@ -466,7 +517,8 @@ async def subscribe( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, **kwargs, ) -> AsyncGenerator[Dict, None]: """Coroutine to subscribe asynchronously to the provided document AST @@ -480,6 +532,8 @@ async def subscribe( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. The extra arguments are passed to the transport subscribe method.""" @@ -489,6 +543,7 @@ async def subscribe( variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, + parse_result=parse_result, **kwargs, ) @@ -513,7 +568,8 @@ async def _execute( *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, **kwargs, ) -> ExecutionResult: """Coroutine to execute the provided document AST asynchronously using @@ -527,6 +583,8 @@ async def _execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. 
The extra arguments are passed to the transport execute method.""" @@ -535,16 +593,19 @@ async def _execute( self.client.validate(document) # Parse variable values for custom scalars if requested - if serialize_variables and variable_values is not None: - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + if variable_values is not None: + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + variable_values = serialize_variable_values( + self.client.schema, + document, + variable_values, + operation_name=operation_name, + ) # Execute the query with the transport with a timeout - return await asyncio.wait_for( + result = await asyncio.wait_for( self.transport.execute( document, variable_values=variable_values, @@ -555,13 +616,26 @@ async def _execute( self.client.execute_timeout, ) + # Unserialize the result if requested + if self.client.schema: + if parse_result or (parse_result is None and self.client.parse_results): + result.data = parse_result_fn( + self.client.schema, + document, + result.data, + operation_name=operation_name, + ) + + return result + async def execute( self, document: DocumentNode, *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - serialize_variables: bool = False, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, **kwargs, ) -> Dict: """Coroutine to execute the provided document AST asynchronously using @@ -575,6 +649,8 @@ async def execute( :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results attribute of the client. 
The extra arguments are passed to the transport execute method.""" @@ -585,6 +661,7 @@ async def execute( variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, + parse_result=parse_result, **kwargs, ) diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index 68b80156..d17f9b2d 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,5 +1,13 @@ -from .update_schema_scalars import update_schema_scalars +from .parse_result import parse_result +from .serialize_variable_values import serialize_value, serialize_variable_values +from .update_schema_enum import update_schema_enum +from .update_schema_scalars import update_schema_scalar, update_schema_scalars __all__ = [ "update_schema_scalars", + "update_schema_scalar", + "update_schema_enum", + "parse_result", + "serialize_variable_values", + "serialize_value", ] diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py new file mode 100644 index 00000000..ecb73474 --- /dev/null +++ b/gql/utilities/parse_result.py @@ -0,0 +1,446 @@ +import logging +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast + +from graphql import ( + IDLE, + REMOVE, + DocumentNode, + FieldNode, + FragmentDefinitionNode, + FragmentSpreadNode, + GraphQLError, + GraphQLInterfaceType, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLType, + InlineFragmentNode, + NameNode, + Node, + OperationDefinitionNode, + SelectionSetNode, + TypeInfo, + TypeInfoVisitor, + Visitor, + is_leaf_type, + print_ast, + visit, +) +from graphql.language.visitor import VisitorActionEnum +from graphql.pyutils import inspect + +log = logging.getLogger(__name__) + +# Equivalent to QUERY_DOCUMENT_KEYS but only for fields interesting to +# visit to parse the results +RESULT_DOCUMENT_KEYS: Dict[str, Tuple[str, ...]] = { + "document": ("definitions",), + "operation_definition": ("selection_set",), + "selection_set": 
("selections",), + "field": ("selection_set",), + "inline_fragment": ("selection_set",), + "fragment_definition": ("selection_set",), +} + + +def _ignore_non_null(type_: GraphQLType): + """Removes the GraphQLNonNull wrappings around types.""" + if isinstance(type_, GraphQLNonNull): + return type_.of_type + else: + return type_ + + +def _get_fragment(document, fragment_name): + """Returns a fragment from the document.""" + for definition in document.definitions: + if isinstance(definition, FragmentDefinitionNode): + if definition.name.value == fragment_name: + return definition + + raise GraphQLError(f'Fragment "{fragment_name}" not found in document!') + + +class ParseResultVisitor(Visitor): + def __init__( + self, + schema: GraphQLSchema, + document: DocumentNode, + node: Node, + result: Dict[str, Any], + type_info: TypeInfo, + visit_fragment: bool = False, + inside_list_level: int = 0, + operation_name: Optional[str] = None, + ): + """Recursive Implementation of a Visitor class to parse results + correspondind to a schema and a document. + + Using a TypeInfo class to get the node types during traversal. + + If we reach a list in the results, then we parse each + item of the list recursively, traversing the same nodes + of the query again. + + During traversal, we keep the current position in the result + in the result_stack field. + + Alongside the field type, we calculate the "result type" + which is computed from the field type and the current + recursive level we are for this field + (:code:`inside_list_level` argument). 
+ """ + self.schema: GraphQLSchema = schema + self.document: DocumentNode = document + self.node: Node = node + self.result: Dict[str, Any] = result + self.type_info: TypeInfo = type_info + self.visit_fragment: bool = visit_fragment + self.inside_list_level = inside_list_level + self.operation_name = operation_name + + self.result_stack: List[Any] = [] + + @property + def current_result(self): + try: + return self.result_stack[-1] + except IndexError: + return self.result + + @staticmethod + def leave_document(node: DocumentNode, *_args: Any) -> Dict[str, Any]: + results = cast(List[Dict[str, Any]], node.definitions) + return {k: v for result in results for k, v in result.items()} + + def enter_operation_definition( + self, node: OperationDefinitionNode, *_args: Any + ) -> Union[None, VisitorActionEnum]: + + if self.operation_name is not None: + if not hasattr(node.name, "value"): + return REMOVE # pragma: no cover + + node.name = cast(NameNode, node.name) + + if node.name.value != self.operation_name: + log.debug(f"SKIPPING operation {node.name.value}") + return REMOVE + + return IDLE + + @staticmethod + def leave_operation_definition( + node: OperationDefinitionNode, *_args: Any + ) -> Dict[str, Any]: + selections = cast(List[Dict[str, Any]], node.selection_set) + return {k: v for s in selections for k, v in s.items()} + + @staticmethod + def leave_selection_set(node: SelectionSetNode, *_args: Any) -> Dict[str, Any]: + partial_results = cast(Dict[str, Any], node.selections) + return partial_results + + @staticmethod + def in_first_field(path): + return path.count("selections") <= 1 + + def get_current_result_type(self, path): + field_type = self.type_info.get_type() + + list_level = self.inside_list_level + + result_type = _ignore_non_null(field_type) + + if self.in_first_field(path): + + while list_level > 0: + assert isinstance(result_type, GraphQLList) + result_type = _ignore_non_null(result_type.of_type) + + list_level -= 1 + + return result_type + + def 
enter_field( + self, + node: FieldNode, + key: str, + parent: Node, + path: List[Node], + ancestors: List[Node], + ) -> Union[None, VisitorActionEnum, Dict[str, Any]]: + + name = node.alias.value if node.alias else node.name.value + + if log.isEnabledFor(logging.DEBUG): + log.debug(f"Enter field {name}") + log.debug(f" path={path!r}") + log.debug(f" current_result={self.current_result!r}") + + if self.current_result is None: + # Result was null for this field -> remove + return REMOVE + + elif isinstance(self.current_result, Mapping): + + try: + result_value = self.current_result[name] + except KeyError: + # Key not found in result. + # Should never happen in theory with a correct GraphQL backend + # Silently ignoring this field + log.debug(f"Key {name} not found in result --> REMOVE") + return REMOVE + + log.debug(f" result_value={result_value}") + + # We get the field_type from type_info + field_type = self.type_info.get_type() + + # We calculate a virtual "result type" depending on our recursion level. + result_type = self.get_current_result_type(path) + + # If the result for this field is a list, then we need + # to recursively visit the same node multiple times for each + # item in the list. 
+ if ( + not isinstance(result_value, Mapping) + and isinstance(result_value, Iterable) + and not isinstance(result_value, str) + and not is_leaf_type(result_type) + ): + + # Finding out the inner type of the list + inner_type = _ignore_non_null(result_type.of_type) + + if log.isEnabledFor(logging.DEBUG): + log.debug(" List detected:") + log.debug(f" field_type={inspect(field_type)}") + log.debug(f" result_type={inspect(result_type)}") + log.debug(f" inner_type={inspect(inner_type)}\n") + + visits: List[Dict[str, Any]] = [] + + # Get parent type + initial_type = self.type_info.get_parent_type() + assert isinstance( + initial_type, (GraphQLObjectType, GraphQLInterfaceType) + ) + + # Get parent SelectionSet node + new_node = ancestors[-1] + assert isinstance(new_node, SelectionSetNode) + + for item in result_value: + + new_result = {name: item} + + if log.isEnabledFor(logging.DEBUG): + log.debug(f" recursive new_result={new_result}") + log.debug(f" recursive ast={print_ast(node)}") + log.debug(f" recursive path={path!r}") + log.debug(f" recursive initial_type={initial_type!r}\n") + + if self.in_first_field(path): + inside_list_level = self.inside_list_level + 1 + else: + inside_list_level = 1 + + inner_visit = parse_result_recursive( + self.schema, + self.document, + new_node, + new_result, + initial_type=initial_type, + inside_list_level=inside_list_level, + ) + log.debug(f" recursive result={inner_visit}\n") + + inner_visit = cast(List[Dict[str, Any]], inner_visit) + visits.append(inner_visit[0][name]) + + result_value = {name: visits} + log.debug(f" recursive visits final result = {result_value}\n") + return result_value + + # If the result for this field is not a list, then add it + # to the result stack so that it becomes the current_value + # for the next inner fields + self.result_stack.append(result_value) + + return IDLE + + raise GraphQLError( + f"Invalid result for container of field {name}: {self.current_result!r}" + ) + + def leave_field( + self, + node: 
FieldNode, + key: str, + parent: Node, + path: List[Node], + ancestors: List[Node], + ) -> Dict[str, Any]: + + name = cast(str, node.alias.value if node.alias else node.name.value) + + log.debug(f"Leave field {name}") + + if self.current_result is None: + + log.debug(f"Leave field {name}: returning None") + return {name: None} + + elif node.selection_set is None: + + field_type = self.type_info.get_type() + result_type = self.get_current_result_type(path) + + if log.isEnabledFor(logging.DEBUG): + log.debug(f" field type of {name} is {inspect(field_type)}") + log.debug(f" result type of {name} is {inspect(result_type)}") + + assert is_leaf_type(result_type) + + # Finally parsing a single scalar using the parse_value method + parsed_value = result_type.parse_value(self.current_result) + + return_value = {name: parsed_value} + else: + + partial_results = cast(List[Dict[str, Any]], node.selection_set) + + return_value = { + name: {k: v for pr in partial_results for k, v in pr.items()} + } + + # Go up a level in the result stack + self.result_stack.pop() + + log.debug(f"Leave field {name}: returning {return_value}") + + return return_value + + # Fragments + + def enter_fragment_definition( + self, node: FragmentDefinitionNode, *_args: Any + ) -> Union[None, VisitorActionEnum]: + + if log.isEnabledFor(logging.DEBUG): + log.debug(f"Enter fragment definition {node.name.value}.") + log.debug(f"visit_fragment={self.visit_fragment!s}") + + if self.visit_fragment: + return IDLE + else: + return REMOVE + + @staticmethod + def leave_fragment_definition( + node: FragmentDefinitionNode, *_args: Any + ) -> Dict[str, Any]: + + selections = cast(List[Dict[str, Any]], node.selection_set) + return {k: v for s in selections for k, v in s.items()} + + def leave_fragment_spread( + self, node: FragmentSpreadNode, *_args: Any + ) -> Dict[str, Any]: + + fragment_name = node.name.value + + log.debug(f"Start recursive fragment visit {fragment_name}") + + fragment_node = 
_get_fragment(self.document, fragment_name) + + fragment_result = parse_result_recursive( + self.schema, + self.document, + fragment_node, + self.current_result, + visit_fragment=True, + ) + + log.debug( + f"Result of recursive fragment visit {fragment_name}: {fragment_result}" + ) + + return cast(Dict[str, Any], fragment_result) + + @staticmethod + def leave_inline_fragment(node: InlineFragmentNode, *_args: Any) -> Dict[str, Any]: + + selections = cast(List[Dict[str, Any]], node.selection_set) + return {k: v for s in selections for k, v in s.items()} + + +def parse_result_recursive( + schema: GraphQLSchema, + document: DocumentNode, + node: Node, + result: Optional[Dict[str, Any]], + initial_type: Optional[GraphQLType] = None, + inside_list_level: int = 0, + visit_fragment: bool = False, + operation_name: Optional[str] = None, +) -> Any: + + if result is None: + return None + + type_info = TypeInfo(schema, initial_type=initial_type) + + visited = visit( + node, + TypeInfoVisitor( + type_info, + ParseResultVisitor( + schema, + document, + node, + result, + type_info=type_info, + inside_list_level=inside_list_level, + visit_fragment=visit_fragment, + operation_name=operation_name, + ), + ), + visitor_keys=RESULT_DOCUMENT_KEYS, + ) + + return visited + + +def parse_result( + schema: GraphQLSchema, + document: DocumentNode, + result: Optional[Dict[str, Any]], + operation_name: Optional[str] = None, +) -> Optional[Dict[str, Any]]: + """Unserialize a result received from a GraphQL backend. + + :param schema: the GraphQL schema + :param document: the document representing the query sent to the backend + :param result: the serialized result received from the backend + :param operation_name: the optional operation name + + :returns: a parsed result with scalars and enums parsed depending on + their definition in the schema. + + Given a schema, a query and a serialized result, + provide a new result with parsed values. 
+
+    If the result contains only built-in GraphQL scalars (String, Int, Float, ...)
+    then the parsed result should be unchanged.
+
+    If the result contains custom scalars or enums, then those values
+    will be parsed with the parse_value method of the custom scalar or enum
+    definition in the schema."""
+
+    return parse_result_recursive(
+        schema, document, document, result, operation_name=operation_name
+    )
diff --git a/gql/variable_values.py b/gql/utilities/serialize_variable_values.py
similarity index 86%
rename from gql/variable_values.py
rename to gql/utilities/serialize_variable_values.py
index 7db7091a..833df8bd 100644
--- a/gql/variable_values.py
+++ b/gql/utilities/serialize_variable_values.py
@@ -17,7 +17,7 @@
 from graphql.pyutils import inspect
 
 
-def get_document_operation(
+def _get_document_operation(
     document: DocumentNode, operation_name: Optional[str] = None
 ) -> OperationDefinitionNode:
     """Returns the operation which should be executed in the document.
@@ -53,7 +53,13 @@ def get_document_operation(
 def serialize_value(type_: GraphQLType, value: Any) -> Any:
     """Given a GraphQL type and a Python value, return the serialized value.
 
+    This method will serialize the value recursively, entering into
+    lists and dicts.
+
     Can be used to serialize Enums and/or Custom Scalars in variable values.
+
+    :param type_: the GraphQL type
+    :param value: the provided value
     """
 
     if value is None:
@@ -93,13 +99,19 @@ def serialize_variable_values(
     """Given a GraphQL document and a schema, serialize the
     Dictionary of variable values.
 
-    Useful to serialize Enums and/or Custom Scalars in variable values
+    Useful to serialize Enums and/or Custom Scalars in variable values.
+
+    :param schema: the GraphQL schema
+    :param document: the document representing the query sent to the backend
+    :param variable_values: the dictionary of variable values which needs
+    to be serialized.
+    :param operation_name: the optional operation_name for the query.
"""
 
     parsed_variable_values: Dict[str, Any] = {}
 
     # Find the operation in the document
-    operation = get_document_operation(document, operation_name=operation_name)
+    operation = _get_document_operation(document, operation_name=operation_name)
 
     # Serialize every variable value defined for the operation
     for var_def_node in operation.variable_definitions:
diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py
new file mode 100644
index 00000000..80c73862
--- /dev/null
+++ b/gql/utilities/update_schema_enum.py
@@ -0,0 +1,69 @@
+from enum import Enum
+from typing import Any, Dict, Mapping, Type, Union, cast
+
+from graphql import GraphQLEnumType, GraphQLSchema
+
+
+def update_schema_enum(
+    schema: GraphQLSchema,
+    name: str,
+    values: Union[Dict[str, Any], Type[Enum]],
+    use_enum_values: bool = False,
+):
+    """Update in the schema the GraphQLEnumType corresponding to the given name.
+
+    Example::
+
+        from enum import Enum
+
+        class Color(Enum):
+            RED = 0
+            GREEN = 1
+            BLUE = 2
+
+        update_schema_enum(schema, 'Color', Color)
+
+    :param schema: a GraphQL Schema already containing the GraphQLEnumType type.
+    :param name: the name of the enum in the GraphQL schema
+    :param values: Either a Python Enum or a dict of values. The keys of the provided
+        values should correspond to the keys of the existing enum in the schema.
+    :param use_enum_values: By default, we configure the GraphQLEnumType to serialize
+        to enum instances (ie: .parse_value() returns Color.RED).
+        If use_enum_values is set to True, then .parse_value() returns 0.
+        use_enum_values=True is the default behaviour when passing an Enum
+        to a GraphQLEnumType.
+ """ + + # Convert Enum values to Dict + if isinstance(values, type): + if issubclass(values, Enum): + values = cast(Type[Enum], values) + if use_enum_values: + values = {enum.name: enum.value for enum in values} + else: + values = {enum.name: enum for enum in values} + + if not isinstance(values, Mapping): + raise TypeError(f"Invalid type for enum values: {type(values)}") + + # Find enum type in schema + schema_enum = schema.get_type(name) + + if schema_enum is None: + raise KeyError(f"Enum {name} not found in schema!") + + if not isinstance(schema_enum, GraphQLEnumType): + raise TypeError( + f'The type "{name}" is not a GraphQLEnumType, it is a {type(schema_enum)}' + ) + + # Replace all enum values + for enum_name, enum_value in schema_enum.values.items(): + try: + enum_value.value = values[enum_name] + except KeyError: + raise KeyError(f'Enum key "{enum_name}" not found in provided values!') + + # Delete the _value_lookup cached property + if "_value_lookup" in schema_enum.__dict__: + del schema_enum.__dict__["_value_lookup"] diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index d5434c6b..db3adb17 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -1,32 +1,60 @@ from typing import Iterable, List -from graphql import GraphQLError, GraphQLScalarType, GraphQLSchema +from graphql import GraphQLScalarType, GraphQLSchema + + +def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): + """Update the scalar in a schema with the scalar provided. + + :param schema: the GraphQL schema + :param name: the name of the custom scalar type in the schema + :param scalar: a provided scalar type + + This can be used to update the default Custom Scalar implementation + when the schema has been provided from a text file or from introspection. 
+ """ + + if not isinstance(scalar, GraphQLScalarType): + raise TypeError("Scalars should be instances of GraphQLScalarType.") + + schema_scalar = schema.get_type(name) + + if schema_scalar is None: + raise KeyError(f"Scalar '{name}' not found in schema.") + + if not isinstance(schema_scalar, GraphQLScalarType): + raise TypeError( + f'The type "{name}" is not a GraphQLScalarType,' + f" it is a {type(schema_scalar)}" + ) + + # Update the conversion methods + # Using setattr because mypy has a false positive + # https://github.com/python/mypy/issues/2427 + setattr(schema_scalar, "serialize", scalar.serialize) + setattr(schema_scalar, "parse_value", scalar.parse_value) + setattr(schema_scalar, "parse_literal", scalar.parse_literal) def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): """Update the scalars in a schema with the scalars provided. + :param schema: the GraphQL schema + :param scalars: a list of provided scalar types + This can be used to update the default Custom Scalar implementation when the schema has been provided from a text file or from introspection. + + If the name of the provided scalar is different than the name of + the custom scalar, then you should use the + :func:`update_schema_scalar ` method instead. 
""" if not isinstance(scalars, Iterable): - raise GraphQLError("Scalars argument should be a list of scalars.") + raise TypeError("Scalars argument should be a list of scalars.") for scalar in scalars: if not isinstance(scalar, GraphQLScalarType): - raise GraphQLError("Scalars should be instances of GraphQLScalarType.") - - try: - schema_scalar = schema.type_map[scalar.name] - except KeyError: - raise GraphQLError(f"Scalar '{scalar.name}' not found in schema.") - - assert isinstance(schema_scalar, GraphQLScalarType) + raise TypeError("Scalars should be instances of GraphQLScalarType.") - # Update the conversion methods - # Using setattr because mypy has a false positive - # https://github.com/python/mypy/issues/2427 - setattr(schema_scalar, "serialize", scalar.serialize) - setattr(schema_scalar, "parse_value", scalar.parse_value) - setattr(schema_scalar, "parse_literal", scalar.parse_literal) + update_schema_scalar(schema, scalar.name, scalar) diff --git a/tests/conftest.py b/tests/conftest.py index 004fa9df..519738cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -105,6 +105,7 @@ async def go(app, *, port=None, **kwargs): # type: ignore "gql.transport.websockets", "gql.transport.phoenix_channel_websockets", "gql.dsl", + "gql.utilities.parse_result", ]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) diff --git a/tests/custom_scalars/test_custom_scalar_datetime.py b/tests/custom_scalars/test_datetime.py similarity index 89% rename from tests/custom_scalars/test_custom_scalar_datetime.py rename to tests/custom_scalars/test_datetime.py index 25c6bb31..169ce076 100644 --- a/tests/custom_scalars/test_custom_scalar_datetime.py +++ b/tests/custom_scalars/test_datetime.py @@ -112,7 +112,7 @@ def resolve_seconds(root, _info, interval): ) def test_shift_days(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True, serialize_variables=True) now = 
datetime.fromisoformat("2021-11-12T11:58:13.461161") @@ -122,13 +122,11 @@ def test_shift_days(): "time": now, } - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, variable_values=variable_values) print(result) - assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") @pytest.mark.skipif( @@ -144,11 +142,11 @@ def test_shift_days_serialized_manually_in_query(): }""" ) - result = client.execute(query) + result = client.execute(query, parse_result=True) print(result) - assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") @pytest.mark.skipif( @@ -156,7 +154,7 @@ def test_shift_days_serialized_manually_in_query(): ) def test_shift_days_serialized_manually_in_variables(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") @@ -168,7 +166,7 @@ def test_shift_days_serialized_manually_in_variables(): print(result) - assert result["shiftDays"] == "2021-11-17T11:58:13.461161" + assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") @pytest.mark.skipif( @@ -176,7 +174,7 @@ def test_shift_days_serialized_manually_in_variables(): ) def test_latest(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) now = datetime.fromisoformat("2021-11-12T11:58:13.461161") in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") @@ -193,7 +191,7 @@ def test_latest(): print(result) - assert result["latest"] == in_five_days.isoformat() + assert result["latest"] == in_five_days @pytest.mark.skipif( diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py new file mode 100644 index 00000000..2c7b887c --- /dev/null 
+++ b/tests/custom_scalars/test_enum_colors.py @@ -0,0 +1,325 @@ +from enum import Enum + +import pytest +from graphql import ( + GraphQLArgument, + GraphQLEnumType, + GraphQLField, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, +) + +from gql import Client, gql +from gql.utilities import update_schema_enum + + +class Color(Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + YELLOW = 3 + CYAN = 4 + MAGENTA = 5 + + +RED = Color.RED +GREEN = Color.GREEN +BLUE = Color.BLUE +YELLOW = Color.YELLOW +CYAN = Color.CYAN +MAGENTA = Color.MAGENTA + +ALL_COLORS = [c for c in Color] + +ColorType = GraphQLEnumType("Color", {c.name: c for c in Color}) + + +def resolve_opposite(_root, _info, color): + opposite_colors = { + RED: CYAN, + GREEN: MAGENTA, + BLUE: YELLOW, + YELLOW: BLUE, + CYAN: RED, + MAGENTA: GREEN, + } + + return opposite_colors[color] + + +def resolve_all(_root, _info): + return ALL_COLORS + + +list_of_list_of_list = [[[RED, GREEN], [GREEN, BLUE]], [[YELLOW, CYAN], [MAGENTA, RED]]] + + +def resolve_list_of_list_of_list(_root, _info): + return list_of_list_of_list + + +def resolve_list_of_list(_root, _info): + return list_of_list_of_list[0] + + +def resolve_list(_root, _info): + return list_of_list_of_list[0][0] + + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "all": GraphQLField(GraphQLList(ColorType), resolve=resolve_all,), + "opposite": GraphQLField( + ColorType, + args={"color": GraphQLArgument(ColorType)}, + resolve=resolve_opposite, + ), + "list_of_list_of_list": GraphQLField( + GraphQLNonNull( + GraphQLList( + GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLList(ColorType)))) + ) + ), + resolve=resolve_list_of_list_of_list, + ), + "list_of_list": GraphQLField( + GraphQLNonNull(GraphQLList(GraphQLNonNull(GraphQLList(ColorType)))), + resolve=resolve_list_of_list, + ), + "list": GraphQLField( + GraphQLNonNull(GraphQLList(ColorType)), resolve=resolve_list, + ), + }, +) + +schema = GraphQLSchema(query=queryType) + + +def 
test_parse_value_enum(): + + result = ColorType.parse_value("RED") + + print(result) + + assert isinstance(result, Color) + assert result is RED + + +def test_serialize_enum(): + + result = ColorType.serialize(RED) + + print(result) + + assert result == "RED" + + +def test_get_all_colors(): + + query = gql("{all}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + all_colors = result["all"] + + assert all_colors == ALL_COLORS + + +def test_opposite_color_literal(): + + client = Client(schema=schema, parse_results=True) + + query = gql("{opposite(color: RED)}") + + result = client.execute(query) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN + + +def test_opposite_color_variable_serialized_manually(): + + client = Client(schema=schema, parse_results=True) + + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + + variable_values = { + "color": "RED", + } + + result = client.execute(query, variable_values=variable_values) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN + + +def test_opposite_color_variable_serialized_by_gql(): + + client = Client(schema=schema, parse_results=True) + + query = gql( + """ + query GetOppositeColor($color: Color) { + opposite(color:$color) + }""" + ) + + variable_values = { + "color": RED, + } + + result = client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN + + +def test_list(): + + query = gql("{list}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + big_list = result["list"] + + assert big_list == list_of_list_of_list[0][0] + + 
+def test_list_of_list(): + + query = gql("{list_of_list}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + big_list = result["list_of_list"] + + assert big_list == list_of_list_of_list[0] + + +def test_list_of_list_of_list(): + + query = gql("{list_of_list_of_list}") + + client = Client(schema=schema, parse_results=True) + + result = client.execute(query) + + print(result) + + big_list = result["list_of_list_of_list"] + + assert big_list == list_of_list_of_list + + +def test_update_schema_enum(): + + assert schema.get_type("Color").parse_value("RED") == Color.RED + + # Using values + + update_schema_enum(schema, "Color", Color, use_enum_values=True) + + assert schema.get_type("Color").parse_value("RED") == 0 + assert schema.get_type("Color").serialize(1) == "GREEN" + + update_schema_enum(schema, "Color", Color) + + assert schema.get_type("Color").parse_value("RED") == Color.RED + assert schema.get_type("Color").serialize(Color.RED) == "RED" + + +def test_update_schema_enum_errors(): + + with pytest.raises(KeyError) as exc_info: + update_schema_enum(schema, "Corlo", Color) + + assert "Enum Corlo not found in schema!" in str(exc_info) + + with pytest.raises(TypeError) as exc_info: + update_schema_enum(schema, "Color", 6) + + assert "Invalid type for enum values: " in str(exc_info) + + with pytest.raises(TypeError) as exc_info: + update_schema_enum(schema, "RootQueryType", Color) + + assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str(exc_info) + + with pytest.raises(KeyError) as exc_info: + update_schema_enum(schema, "Color", {"RED": Color.RED}) + + assert 'Enum key "GREEN" not found in provided values!' 
in str(exc_info) + + +def test_parse_results_with_operation_type(): + + client = Client(schema=schema, parse_results=True) + + query = gql( + """ + query GetAll { + all + } + query GetOppositeColor($color: Color) { + opposite(color:$color) + } + query GetOppositeColor2($color: Color) { + other_opposite:opposite(color:$color) + } + query GetOppositeColor3 { + opposite(color: YELLOW) + } + query GetListOfListOfList { + list_of_list_of_list + } + """ + ) + + variable_values = { + "color": "RED", + } + + result = client.execute( + query, variable_values=variable_values, operation_name="GetOppositeColor" + ) + + print(result) + + opposite_color = result["opposite"] + + assert isinstance(opposite_color, Color) + assert opposite_color == CYAN diff --git a/tests/custom_scalars/test_custom_scalar_json.py b/tests/custom_scalars/test_json.py similarity index 98% rename from tests/custom_scalars/test_custom_scalar_json.py rename to tests/custom_scalars/test_json.py index 80f99850..9659d0a5 100644 --- a/tests/custom_scalars/test_custom_scalar_json.py +++ b/tests/custom_scalars/test_json.py @@ -94,7 +94,7 @@ def resolve_add_player(root, _info, player): def test_json_value_output(): - client = Client(schema=schema) + client = Client(schema=schema, parse_results=True) query = gql("query {players}") diff --git a/tests/custom_scalars/test_custom_scalar_money.py b/tests/custom_scalars/test_money.py similarity index 80% rename from tests/custom_scalars/test_custom_scalar_money.py rename to tests/custom_scalars/test_money.py index 238308a9..1b65ec98 100644 --- a/tests/custom_scalars/test_custom_scalar_money.py +++ b/tests/custom_scalars/test_money.py @@ -11,6 +11,7 @@ GraphQLField, GraphQLFloat, GraphQLInt, + GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, @@ -20,8 +21,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportQueryError -from gql.utilities import update_schema_scalars -from gql.variable_values import serialize_value +from 
gql.utilities import serialize_value, update_schema_scalar, update_schema_scalars from ..conftest import MS @@ -82,9 +82,34 @@ def parse_money_literal( parse_literal=parse_money_literal, ) +root_value = { + "balance": Money(42, "DM"), + "friends_balance": [Money(12, "EUR"), Money(24, "EUR"), Money(150, "DM")], + "countries_balance": { + "Belgium": Money(15000, "EUR"), + "Luxembourg": Money(99999, "EUR"), + }, +} + def resolve_balance(root, _info): - return root + return root["balance"] + + +def resolve_friends_balance(root, _info): + return root["friends_balance"] + + +def resolve_countries_balance(root, _info): + return root["countries_balance"] + + +def resolve_belgium_balance(countries_balance, _info): + return countries_balance["Belgium"] + + +def resolve_luxembourg_balance(countries_balance, _info): + return countries_balance["Luxembourg"] def resolve_to_euros(_root, _info, money): @@ -97,6 +122,18 @@ def resolve_to_euros(_root, _info, money): raise ValueError("Cannot convert to euros: " + inspect(money)) +countriesBalance = GraphQLObjectType( + name="CountriesBalance", + fields={ + "Belgium": GraphQLField( + GraphQLNonNull(MoneyScalar), resolve=resolve_belgium_balance + ), + "Luxembourg": GraphQLField( + GraphQLNonNull(MoneyScalar), resolve=resolve_luxembourg_balance + ), + }, +) + queryType = GraphQLObjectType( name="RootQueryType", fields={ @@ -106,6 +143,12 @@ def resolve_to_euros(_root, _info, money): args={"money": GraphQLArgument(MoneyScalar)}, resolve=resolve_to_euros, ), + "friends_balance": GraphQLField( + GraphQLList(MoneyScalar), resolve=resolve_friends_balance + ), + "countries_balance": GraphQLField( + GraphQLNonNull(countriesBalance), resolve=resolve_countries_balance, + ), }, ) @@ -133,14 +176,12 @@ async def subscribe_spend_all(_root, _info, money): }, ) -root_value = Money(42, "DM") - schema = GraphQLSchema(query=queryType, subscription=subscriptionType,) def test_custom_scalar_in_output(): - client = Client(schema=schema) + client = 
Client(schema=schema, parse_results=True) query = gql("{balance}") @@ -148,7 +189,53 @@ def test_custom_scalar_in_output(): print(result) - assert result["balance"] == serialize_money(root_value) + assert result["balance"] == root_value["balance"] + + +def test_custom_scalar_in_output_embedded_fragments(): + + client = Client(schema=schema, parse_results=True) + + query = gql( + """ + fragment LuxMoneyInternal on CountriesBalance { + ... on CountriesBalance { + Luxembourg + } + } + query { + countries_balance { + Belgium + ...LuxMoney + } + } + fragment LuxMoney on CountriesBalance { + ...LuxMoneyInternal + } + """ + ) + + result = client.execute(query, root_value=root_value) + + print(result) + + belgium_money = result["countries_balance"]["Belgium"] + assert belgium_money == Money(15000, "EUR") + luxembourg_money = result["countries_balance"]["Luxembourg"] + assert luxembourg_money == Money(99999, "EUR") + + +def test_custom_scalar_list_in_output(): + + client = Client(schema=schema, parse_results=True) + + query = gql("{friends_balance}") + + result = client.execute(query, root_value=root_value) + + print(result) + + assert result["friends_balance"] == root_value["friends_balance"] def test_custom_scalar_in_input_query(): @@ -301,16 +388,18 @@ def test_custom_scalar_subscribe_in_input_variable_values_serialized(): variable_values = {"money": money_value} - expected_result = {"spend": {"amount": 10, "currency": "DM"}} + expected_result = {"spend": Money(10, "DM")} for result in client.subscribe( query, variable_values=variable_values, root_value=root_value, serialize_variables=True, + parse_result=True, ): print(f"result = {result!r}") - expected_result["spend"]["amount"] = expected_result["spend"]["amount"] - 1 + assert isinstance(result["spend"], Money) + expected_result["spend"] = Money(expected_result["spend"].amount - 1, "DM") assert expected_result == result @@ -385,7 +474,7 @@ async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server 
print(result) - assert result["balance"] == serialize_money(root_value) + assert result["balance"] == serialize_money(root_value["balance"]) @pytest.mark.asyncio @@ -533,7 +622,8 @@ async def test_update_schema_scalars(event_loop, aiohttp_server): # Update the schema MoneyScalar default implementation from # introspection with our provided conversion methods - update_schema_scalars(session.client.schema, [MoneyScalar]) + # update_schema_scalars(session.client.schema, [MoneyScalar]) + update_schema_scalar(session.client.schema, "Money", MoneyScalar) query = gql("query myquery($money: Money) {toEuros(money: $money)}") @@ -549,17 +639,24 @@ async def test_update_schema_scalars(event_loop, aiohttp_server): def test_update_schema_scalars_invalid_scalar(): - with pytest.raises(GraphQLError) as exc_info: + with pytest.raises(TypeError) as exc_info: update_schema_scalars(schema, [int]) exception = exc_info.value assert str(exception) == "Scalars should be instances of GraphQLScalarType." + with pytest.raises(TypeError) as exc_info: + update_schema_scalar(schema, "test", int) + + exception = exc_info.value + + assert str(exception) == "Scalars should be instances of GraphQLScalarType." + def test_update_schema_scalars_invalid_scalar_argument(): - with pytest.raises(GraphQLError) as exc_info: + with pytest.raises(TypeError) as exc_info: update_schema_scalars(schema, MoneyScalar) exception = exc_info.value @@ -571,12 +668,24 @@ def test_update_schema_scalars_scalar_not_found_in_schema(): NotFoundScalar = GraphQLScalarType(name="abcd",) - with pytest.raises(GraphQLError) as exc_info: + with pytest.raises(KeyError) as exc_info: update_schema_scalars(schema, [MoneyScalar, NotFoundScalar]) exception = exc_info.value - assert str(exception) == "Scalar 'abcd' not found in schema." + assert "Scalar 'abcd' not found in schema." 
in str(exception) + + +def test_update_schema_scalars_scalar_type_is_not_a_scalar_in_schema(): + + with pytest.raises(TypeError) as exc_info: + update_schema_scalar(schema, "CountriesBalance", MoneyScalar) + + exception = exc_info.value + + assert 'The type "CountriesBalance" is not a GraphQLScalarType, it is a' in str( + exception + ) @pytest.mark.asyncio @@ -588,7 +697,7 @@ async def test_custom_scalar_serialize_variables_sync_transport( server, transport = await make_sync_money_transport(aiohttp_server) def test_code(): - with Client(schema=schema, transport=transport,) as session: + with Client(schema=schema, transport=transport, parse_results=True) as session: query = gql("query myquery($money: Money) {toEuros(money: $money)}") diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py new file mode 100644 index 00000000..23073839 --- /dev/null +++ b/tests/starwars/test_parse_results.py @@ -0,0 +1,191 @@ +import pytest +from graphql import GraphQLError + +from gql import gql +from gql.utilities import parse_result +from tests.starwars.schema import StarWarsSchema + + +def test_hero_name_and_friends_query(): + query = gql( + """ + query HeroNameAndFriendsQuery { + hero { + id + friends { + name + } + name + } + } + """ + ) + result = { + "hero": { + "id": "2001", + "friends": [ + {"name": "Luke Skywalker"}, + {"name": "Han Solo"}, + {"name": "Leia Organa"}, + ], + "name": "R2-D2", + } + } + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_key_not_found_in_result(): + + query = gql( + """ + { + hero { + id + } + } + """ + ) + + # Backend returned an invalid result without the hero key + # Should be impossible. 
In that case, we ignore the missing key + result = {} + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_invalid_result_raise_error(): + + query = gql( + """ + { + hero { + id + } + } + """ + ) + + result = {"hero": 5} + + with pytest.raises(GraphQLError) as exc_info: + + parse_result(StarWarsSchema, query, result) + + assert "Invalid result for container of field id: 5" in str(exc_info) + + +def test_fragment(): + + query = gql( + """ + query UseFragment { + luke: human(id: "1000") { + ...HumanFragment + } + leia: human(id: "1003") { + ...HumanFragment + } + } + fragment HumanFragment on Human { + name + homePlanet + } + """ + ) + + result = { + "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, + "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, + } + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_fragment_not_found(): + + query = gql( + """ + query UseFragment { + luke: human(id: "1000") { + ...HumanFragment + } + } + """ + ) + + result = { + "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, + } + + with pytest.raises(GraphQLError) as exc_info: + + parse_result(StarWarsSchema, query, result) + + assert 'Fragment "HumanFragment" not found in document!' in str(exc_info) + + +def test_return_none_if_result_is_none(): + + query = gql( + """ + query { + hero { + id + } + } + """ + ) + + result = None + + assert parse_result(StarWarsSchema, query, result) is None + + +def test_null_result_is_allowed(): + + query = gql( + """ + query { + hero { + id + } + } + """ + ) + + result = {"hero": None} + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + +def test_inline_fragment(): + + query = gql( + """ + query UseFragment { + luke: human(id: "1000") { + ... 
on Human { + name + homePlanet + } + } + } + """ + ) + + result = { + "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, + } + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 62890222..520018c1 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -107,7 +107,7 @@ def test_nested_query(client): ], } } - result = client.execute(query) + result = client.execute(query, parse_result=False) assert result == expected diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 3753ab2f..2516701f 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -53,7 +53,9 @@ async def test_subscription_support_using_client(): async with Client(schema=StarWarsSchema) as session: results = [ result["reviewAdded"] - async for result in session.subscribe(subs, variable_values=params) + async for result in session.subscribe( + subs, variable_values=params, parse_result=False + ) ] assert results == expected diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index 1402aa59..107bd6c2 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -112,7 +112,7 @@ async def test_async_client_validation( expected = [] async for result in session.subscribe( - subscription, variable_values=variable_values + subscription, variable_values=variable_values, parse_result=False ): review = result["reviewAdded"] From 6cfab1e45f1cc5db774f910ac1cf279e2be74def Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 22 Nov 2021 21:02:38 +0100 Subject: [PATCH 033/239] Bump version number to 3.0.0b1 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index b3d3a3b4..3996ce87 100644 --- a/gql/__version__.py +++ 
b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.0.0b0" +__version__ = "3.0.0b1" From 5df465b9fd75b1da2411a3576e90ab80aa6095c7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 24 Nov 2021 10:10:03 +0100 Subject: [PATCH 034/239] Feature get_execution_result argument of execute and subscribe (#257) --- docs/usage/extensions.rst | 36 ++++++++++++++++++++++++ docs/usage/index.rst | 1 + gql/client.py | 41 +++++++++++++++++++++++----- gql/transport/exceptions.py | 2 ++ tests/test_aiohttp.py | 2 +- tests/test_requests.py | 2 +- tests/test_websocket_query.py | 2 +- tests/test_websocket_subscription.py | 26 ++++++++++++++++++ 8 files changed, 102 insertions(+), 10 deletions(-) create mode 100644 docs/usage/extensions.rst diff --git a/docs/usage/extensions.rst b/docs/usage/extensions.rst new file mode 100644 index 00000000..ec413656 --- /dev/null +++ b/docs/usage/extensions.rst @@ -0,0 +1,36 @@ +.. _extensions: + +Extensions +---------- + +When you execute (or subscribe) GraphQL requests, the server will send +responses which may have 3 fields: + +- data: the serialized response from the backend +- errors: a list of potential errors +- extensions: an optional field for additional data + +If there are errors in the response, then the +:code:`execute` or :code:`subscribe` methods will +raise a :code:`TransportQueryError`. + +If no errors are present, then only the data from the response is returned by default. + +.. code-block:: python + + result = client.execute(query) + # result is here the content of the data field + +If you need to receive the extensions data too, then you can run the +:code:`execute` or :code:`subscribe` methods with :code:`get_execution_result=True`. + +In that case, the full execution result is returned and you can have access +to the extensions field + +.. 
code-block:: python + + result = client.execute(query, get_execution_result=True) + # result is here an ExecutionResult instance + + # result.data is the content of the data field + # result.extensions is the content of the extensions field diff --git a/docs/usage/index.rst b/docs/usage/index.rst index eebf9fd2..f73ac75a 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -11,3 +11,4 @@ Usage headers file_upload custom_scalars_and_enums + extensions diff --git a/gql/client.py b/gql/client.py index 079bb552..c39da95b 100644 --- a/gql/client.py +++ b/gql/client.py @@ -367,8 +367,9 @@ def execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, + get_execution_result: bool = False, **kwargs, - ) -> Dict: + ) -> Union[Dict[str, Any], ExecutionResult]: """Execute the provided document AST synchronously using the sync transport. @@ -382,6 +383,8 @@ def execute( serialized. Used for custom scalars and/or enums. Default: False. :param parse_result: Whether gql will unserialize the result. By default use the parse_results attribute of the client. + :param get_execution_result: return the full ExecutionResult instance instead of + only the "data" field. Necessary if you want to get the "extensions" field. 
The extra arguments are passed to the transport execute method.""" @@ -399,13 +402,19 @@ def execute( # Raise an error if an error is returned in the ExecutionResult object if result.errors: raise TransportQueryError( - str(result.errors[0]), errors=result.errors, data=result.data + str(result.errors[0]), + errors=result.errors, + data=result.data, + extensions=result.extensions, ) assert ( result.data is not None ), "Transport returned an ExecutionResult without data or errors" + if get_execution_result: + return result + return result.data def fetch_schema(self) -> None: @@ -519,8 +528,9 @@ async def subscribe( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, + get_execution_result: bool = False, **kwargs, - ) -> AsyncGenerator[Dict, None]: + ) -> AsyncGenerator[Union[Dict[str, Any], ExecutionResult], None]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport. @@ -534,6 +544,8 @@ async def subscribe( serialized. Used for custom scalars and/or enums. Default: False. :param parse_result: Whether gql will unserialize the result. By default use the parse_results attribute of the client. + :param get_execution_result: yield the full ExecutionResult instance instead of + only the "data" field. Necessary if you want to get the "extensions" field. 
The extra arguments are passed to the transport subscribe method.""" @@ -554,11 +566,17 @@ async def subscribe( # Raise an error if an error is returned in the ExecutionResult object if result.errors: raise TransportQueryError( - str(result.errors[0]), errors=result.errors, data=result.data + str(result.errors[0]), + errors=result.errors, + data=result.data, + extensions=result.extensions, ) elif result.data is not None: - yield result.data + if get_execution_result: + yield result + else: + yield result.data finally: await inner_generator.aclose() @@ -636,8 +654,9 @@ async def execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, + get_execution_result: bool = False, **kwargs, - ) -> Dict: + ) -> Union[Dict[str, Any], ExecutionResult]: """Coroutine to execute the provided document AST asynchronously using the async transport. @@ -651,6 +670,8 @@ async def execute( serialized. Used for custom scalars and/or enums. Default: False. :param parse_result: Whether gql will unserialize the result. By default use the parse_results attribute of the client. + :param get_execution_result: return the full ExecutionResult instance instead of + only the "data" field. Necessary if you want to get the "extensions" field. 
The extra arguments are passed to the transport execute method.""" @@ -668,13 +689,19 @@ async def execute( # Raise an error if an error is returned in the ExecutionResult object if result.errors: raise TransportQueryError( - str(result.errors[0]), errors=result.errors, data=result.data + str(result.errors[0]), + errors=result.errors, + data=result.data, + extensions=result.extensions, ) assert ( result.data is not None ), "Transport returned an ExecutionResult without data or errors" + if get_execution_result: + return result + return result.data async def fetch_schema(self) -> None: diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 899d5d66..250e7523 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -35,11 +35,13 @@ def __init__( query_id: Optional[int] = None, errors: Optional[List[Any]] = None, data: Optional[Any] = None, + extensions: Optional[Any] = None, ): super().__init__(msg) self.query_id = query_id self.errors = errors self.data = data + self.extensions = extensions class TransportClosed(TransportError): diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index df954f12..50cec3f9 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1070,6 +1070,6 @@ async def handler(request): query = gql(query1_str) - execution_result = await session._execute(query) + execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" diff --git a/tests/test_requests.py b/tests/test_requests.py index d0cc7eb7..c3123d72 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -328,7 +328,7 @@ def test_code(): query = gql(query1_str) - execution_result = session._execute(query) + execution_result = session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index e825c637..4e51f161 100644 --- 
a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -596,6 +596,6 @@ async def test_websocket_simple_query_with_extensions( query = gql(query_str) - execution_result = await session._execute(query) + execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 5300333d..ff484157 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -4,6 +4,7 @@ from typing import List import pytest +from graphql import ExecutionResult from parse import search from gql import Client, gql @@ -142,6 +143,31 @@ async def test_websocket_subscription(event_loop, client_and_server, subscriptio assert count == -1 +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_get_execution_result( + event_loop, client_and_server, subscription_str +): + + session, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription, get_execution_result=True): + + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) From 6b5342201ec417100e355f60350ee61b600a033a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 24 Nov 2021 10:16:07 +0100 Subject: [PATCH 035/239] Feature gql cli print schema (#258) --- docs/gql-cli/intro.rst | 9 ++++++++ docs/usage/validation.rst | 5 +++++ gql/cli.py | 21 ++++++++++++++++-- tests/custom_scalars/test_money.py | 35 
++++++++++++++++++++++++++---- 4 files changed, 64 insertions(+), 6 deletions(-) diff --git a/docs/gql-cli/intro.rst b/docs/gql-cli/intro.rst index 3a25c6df..b4565b01 100644 --- a/docs/gql-cli/intro.rst +++ b/docs/gql-cli/intro.rst @@ -1,3 +1,5 @@ +.. _gql_cli: + gql-cli ======= @@ -69,3 +71,10 @@ Then execute query from the file: $ cat query.gql | gql-cli wss://countries.trevorblades.com/graphql {"continent": {"name": "Africa"}} + +Print the GraphQL schema in a file +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: shell + + $ gql-cli https://countries.trevorblades.com/graphql --print-schema > schema.graphql diff --git a/docs/usage/validation.rst b/docs/usage/validation.rst index df6990bd..18b1cda1 100644 --- a/docs/usage/validation.rst +++ b/docs/usage/validation.rst @@ -21,6 +21,11 @@ The schema can be provided as a String (which is usually stored in a .graphql fi client = Client(schema=schema_str) +.. note:: + You can download a schema from a server by using :ref:`gql-cli ` + + :code:`$ gql-cli https://SERVER_URL/graphql --print-schema > schema.graphql` + OR can be created using python classes: .. 
code-block:: python diff --git a/gql/cli.py b/gql/cli.py index c75ad120..917a4268 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -4,7 +4,7 @@ from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from typing import Any, Dict -from graphql import GraphQLError +from graphql import GraphQLError, print_schema from yarl import URL from gql import Client, __version__, gql @@ -38,6 +38,9 @@ # Execute query saved in a file cat query.gql | gql-cli wss://countries.trevorblades.com/graphql +# Print the schema of the backend +gql-cli https://countries.trevorblades.com/graphql --print-schema + """ @@ -92,6 +95,12 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: help="set the operation_name value", dest="operation_name", ) + parser.add_argument( + "--print-schema", + help="get the schema from instrospection and print it", + action="store_true", + dest="print_schema", + ) return parser @@ -241,7 +250,15 @@ async def main(args: Namespace) -> int: exit_code = 0 # Connect to the backend and provide a session - async with Client(transport=transport) as session: + async with Client( + transport=transport, fetch_schema_from_transport=args.print_schema + ) as session: + + if args.print_schema: + schema_str = print_schema(session.client.schema) + print(schema_str) + + return exit_code while True: diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 1b65ec98..2e30b6b7 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -410,10 +410,8 @@ async def handler(request): data = await request.json() source = data["query"] - print(f"data keys = {data.keys()}") try: variables = data["variables"] - print(f"variables = {variables!r}") except KeyError: variables = None @@ -421,8 +419,6 @@ async def handler(request): schema, source, variable_values=variables, root_value=root_value ) - print(f"backend result = {result!r}") - return 
web.json_response( { "data": result.data, @@ -742,3 +738,34 @@ def test_serialize_value_with_nullable_type(): nullable_int = GraphQLInt assert serialize_value(nullable_int, None) is None + + +@pytest.mark.asyncio +async def test_gql_cli_print_schema(event_loop, aiohttp_server, capsys): + + from gql.cli import get_parser, main + + server = await make_money_backend(aiohttp_server) + + url = str(server.make_url("/")) + + parser = get_parser(with_examples=True) + args = parser.parse_args([url, "--print-schema"]) + + exit_code = await main(args) + + assert exit_code == 0 + + # Check that the result has been printed on stdout + captured = capsys.readouterr() + captured_out = str(captured.out).strip() + + print(captured_out) + assert ( + """ +type Subscription { + spend(money: Money): Money +} +""".strip() + in captured_out + ) From 0856f11b08ee9f5818778a168091ee8ba5d3905a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 24 Nov 2021 21:14:40 +0100 Subject: [PATCH 036/239] DSL meta fields implementation (#259) --- docs/advanced/dsl_module.rst | 10 ++ gql/dsl.py | 105 +++++++++++++--- gql/utilities/__init__.py | 8 +- gql/utilities/get_introspection_query_ast.py | 123 +++++++++++++++++++ tests/starwars/test_dsl.py | 115 ++++++++++++++++- 5 files changed, 335 insertions(+), 26 deletions(-) create mode 100644 gql/utilities/get_introspection_query_ast.py diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index afaa3bc6..f4046f27 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -338,6 +338,16 @@ this can be written in a concise manner:: DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet) ) +Meta-fields +^^^^^^^^^^^ + +To define meta-fields (:code:`__typename`, :code:`__schema` and :code:`__type`), +you can use the :class:`DSLMetaField ` class:: + + query = ds.Query.hero.select( + ds.Character.name, + DSLMetaField("__typename") + ) Executable examples ------------------- diff --git a/gql/dsl.py b/gql/dsl.py 
index 1646d402..f864c316 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -25,6 +25,7 @@ GraphQLNonNull, GraphQLObjectType, GraphQLSchema, + GraphQLString, GraphQLWrappingType, InlineFragmentNode, IntValueNode, @@ -46,6 +47,7 @@ VariableDefinitionNode, VariableNode, assert_named_type, + introspection_types, is_enum_type, is_input_object_type, is_leaf_type, @@ -301,6 +303,7 @@ class DSLExecutable(ABC): variable_definitions: "DSLVariableDefinitions" name: Optional[str] + selection_set: SelectionSetNode @property @abstractmethod @@ -349,11 +352,31 @@ def __init__( f"Received type: {type(field)}" ) ) + valid_type = False if isinstance(self, DSLOperation): - assert field.type_name.upper() == self.operation_type.name, ( - f"Invalid root field for operation {self.operation_type.name}.\n" - f"Received: {field.type_name}" - ) + operation_name = self.operation_type.name + if isinstance(field, DSLMetaField): + if field.name in ["__schema", "__type"]: + valid_type = operation_name == "QUERY" + if field.name == "__typename": + valid_type = operation_name != "SUBSCRIPTION" + else: + valid_type = field.parent_type.name.upper() == operation_name + + else: # Fragments + if isinstance(field, DSLMetaField): + valid_type = field.name == "__typename" + + if not valid_type: + if isinstance(self, DSLOperation): + error_msg = ( + "Invalid root field for operation " + f"{self.operation_type.name}" + ) + else: + error_msg = f"Invalid field for fragment {self.name}" + + raise AssertionError(f"{error_msg}: {field!r}") self.selection_set = SelectionSetNode( selections=FrozenList(DSLSelectable.get_ast_fields(all_fields)) @@ -610,6 +633,11 @@ def select( fields, fields_with_alias ) + # Check that we don't receive an invalid meta-field + for field in added_fields: + if isinstance(field, DSLMetaField) and field.name != "__typename": + raise AssertionError(f"Invalid field for {self!r}: {field!r}") + # Get a list of AST Nodes for each added field added_selections: List[ Union[FieldNode, 
InlineFragmentNode, FragmentSpreadNode] @@ -668,8 +696,8 @@ class DSLField(DSLSelectableWithAlias, DSLSelector): def __init__( self, name: str, - graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], - graphql_field: GraphQLField, + parent_type: Union[GraphQLObjectType, GraphQLInterfaceType], + field: GraphQLField, ): """Initialize the DSLField. @@ -678,15 +706,21 @@ def __init__( Use attributes of the :class:`DSLType` instead. :param name: the name of the field - :param graphql_type: the GraphQL type definition from the schema - :param graphql_field: the GraphQL field definition from the schema + :param parent_type: the GraphQL type definition from the schema of the + parent type of the field + :param field: the GraphQL field definition from the schema """ DSLSelector.__init__(self) - self._type = graphql_type - self.field = graphql_field + self.parent_type = parent_type + self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) log.debug(f"Creating {self!r}") + @property + def name(self): + """:meta private:""" + return self.ast_field.name.value + def __call__(self, **kwargs) -> "DSLField": return self.args(**kwargs) @@ -750,16 +784,49 @@ def select( return self - @property - def type_name(self): - """:meta private:""" - return self._type.name - def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__} {self._type.name}" - f"::{self.ast_field.name.value}>" - ) + return f"<{self.__class__.__name__} {self.parent_type.name}" f"::{self.name}>" + + +class DSLMetaField(DSLField): + """DSLMetaField represents a GraphQL meta-field for the DSL code. + + meta-fields are reserved field in the GraphQL type system prefixed with + "__" two underscores and used for introspection. 
+ """ + + meta_type = GraphQLObjectType( + "meta-field", + fields={ + "__typename": GraphQLField(GraphQLString), + "__schema": GraphQLField( + cast(GraphQLObjectType, introspection_types["__Schema"]) + ), + "__type": GraphQLField( + cast(GraphQLObjectType, introspection_types["__Type"]), + args={"name": GraphQLArgument(type_=GraphQLNonNull(GraphQLString))}, + ), + }, + ) + + def __init__(self, name: str): + """Initialize the meta-field. + + :param name: the name between __typename, __schema or __type + """ + + try: + field = self.meta_type.fields[name] + except KeyError: + raise AssertionError(f'Invalid meta-field "{name}"') + + super().__init__(name, self.meta_type, field) + + def alias(self, alias: str) -> "DSLSelectableWithAlias": + """ + :meta private: + """ + pass class DSLInlineFragment(DSLSelectable, DSLSelector): diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index d17f9b2d..7089d360 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,13 +1,15 @@ +from .get_introspection_query_ast import get_introspection_query_ast from .parse_result import parse_result from .serialize_variable_values import serialize_value, serialize_variable_values from .update_schema_enum import update_schema_enum from .update_schema_scalars import update_schema_scalar, update_schema_scalars __all__ = [ - "update_schema_scalars", - "update_schema_scalar", - "update_schema_enum", "parse_result", + "get_introspection_query_ast", "serialize_variable_values", "serialize_value", + "update_schema_enum", + "update_schema_scalars", + "update_schema_scalar", ] diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py new file mode 100644 index 00000000..bbb07771 --- /dev/null +++ b/gql/utilities/get_introspection_query_ast.py @@ -0,0 +1,123 @@ +from itertools import repeat + +from graphql import DocumentNode, GraphQLSchema + +from gql.dsl import DSLFragment, DSLMetaField, DSLQuery, DSLSchema, dsl_gql + + 
+def get_introspection_query_ast( + descriptions: bool = True, + specified_by_url: bool = False, + directive_is_repeatable: bool = False, + schema_description: bool = False, + type_recursion_level: int = 7, +) -> DocumentNode: + """Get a query for introspection as a document using the DSL module. + + Equivalent to the get_introspection_query function from graphql-core + but using the DSL module and allowing to select the recursion level. + + Optionally, you can exclude descriptions, include specification URLs, + include repeatability of directives, and specify whether to include + the schema description as well. + """ + + ds = DSLSchema(GraphQLSchema()) + + fragment_FullType = DSLFragment("FullType").on(ds.__Type) + fragment_InputValue = DSLFragment("InputValue").on(ds.__InputValue) + fragment_TypeRef = DSLFragment("TypeRef").on(ds.__Type) + + schema = DSLMetaField("__schema") + + if descriptions and schema_description: + schema.select(ds.__Schema.description) + + schema.select( + ds.__Schema.queryType.select(ds.__Type.name), + ds.__Schema.mutationType.select(ds.__Type.name), + ds.__Schema.subscriptionType.select(ds.__Type.name), + ) + + schema.select(ds.__Schema.types.select(fragment_FullType)) + + directives = ds.__Schema.directives.select(ds.__Directive.name) + + if descriptions: + directives.select(ds.__Directive.description) + if directive_is_repeatable: + directives.select(ds.__Directive.isRepeatable) + directives.select( + ds.__Directive.locations, ds.__Directive.args.select(fragment_InputValue), + ) + + schema.select(directives) + + fragment_FullType.select( + ds.__Type.kind, ds.__Type.name, + ) + if descriptions: + fragment_FullType.select(ds.__Type.description) + if specified_by_url: + fragment_FullType.select(ds.__Type.specifiedByUrl) + + fields = ds.__Type.fields(includeDeprecated=True).select(ds.__Field.name) + + if descriptions: + fields.select(ds.__Field.description) + + fields.select( + ds.__Field.args.select(fragment_InputValue), + 
ds.__Field.type.select(fragment_TypeRef), + ds.__Field.isDeprecated, + ds.__Field.deprecationReason, + ) + + enum_values = ds.__Type.enumValues(includeDeprecated=True).select( + ds.__EnumValue.name + ) + + if descriptions: + enum_values.select(ds.__EnumValue.description) + + enum_values.select( + ds.__EnumValue.isDeprecated, ds.__EnumValue.deprecationReason, + ) + + fragment_FullType.select( + fields, + ds.__Type.inputFields.select(fragment_InputValue), + ds.__Type.interfaces.select(fragment_TypeRef), + enum_values, + ds.__Type.possibleTypes.select(fragment_TypeRef), + ) + + fragment_InputValue.select(ds.__InputValue.name) + + if descriptions: + fragment_InputValue.select(ds.__InputValue.description) + + fragment_InputValue.select( + ds.__InputValue.type.select(fragment_TypeRef), ds.__InputValue.defaultValue, + ) + + fragment_TypeRef.select( + ds.__Type.kind, ds.__Type.name, + ) + + if type_recursion_level >= 1: + current_field = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name) + fragment_TypeRef.select(current_field) + + for _ in repeat(None, type_recursion_level - 1): + new_oftype = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name) + current_field.select(new_oftype) + current_field = new_oftype + + query = DSLQuery(schema) + + query.name = "IntrospectionQuery" + + dsl_query = dsl_gql(query, fragment_FullType, fragment_InputValue, fragment_TypeRef) + + return dsl_query diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index d18bb37d..a86ceff9 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -14,11 +14,13 @@ Undefined, print_ast, ) +from graphql.utilities import get_introspection_query -from gql import Client +from gql import Client, gql from gql.dsl import ( DSLFragment, DSLInlineFragment, + DSLMetaField, DSLMutation, DSLQuery, DSLSchema, @@ -29,6 +31,7 @@ ast_from_value, dsl_gql, ) +from gql.utilities import get_introspection_query_ast from .schema import StarWarsSchema @@ -616,9 +619,9 @@ def 
test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): with pytest.raises(AssertionError) as excinfo: DSLQuery(ds.Character.name) - assert ("Invalid root field for operation QUERY.\n" "Received: Character") in str( - excinfo.value - ) + assert ( + "Invalid root field for operation QUERY: " "" + ) in str(excinfo.value) def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): @@ -638,3 +641,107 @@ def test_invalid_type(ds): AttributeError, match="Type 'invalid_type' not found in the schema!" ): ds.invalid_type + + +def test_hero_name_query_with_typename(ds): + query = """ +hero { + name + __typename +} + """.strip() + query_dsl = ds.Query.hero.select(ds.Character.name, DSLMetaField("__typename")) + assert query == str(query_dsl) + + +def test_type_hero_query(ds): + query = """{ + __type(name: "Hero") { + kind + name + ofType { + kind + name + } + } +}""" + + type_hero = DSLMetaField("__type")(name="Hero") + type_hero.select( + ds.__Type.kind, + ds.__Type.name, + ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name), + ) + query_dsl = DSLQuery(type_hero) + + assert query == str(print_ast(dsl_gql(query_dsl))).strip() + + +def test_invalid_meta_field_selection(ds): + + DSLQuery(DSLMetaField("__typename")) + DSLQuery(DSLMetaField("__schema")) + DSLQuery(DSLMetaField("__type")) + + metafield = DSLMetaField("__typename") + assert metafield.name == "__typename" + + # alias does not work + metafield.alias("test") + + assert metafield.name == "__typename" + + with pytest.raises(AssertionError): + DSLMetaField("__invalid_meta_field") + + DSLMutation(DSLMetaField("__typename")) + + with pytest.raises(AssertionError): + DSLMutation(DSLMetaField("__schema")) + + with pytest.raises(AssertionError): + DSLMutation(DSLMetaField("__type")) + + with pytest.raises(AssertionError): + DSLSubscription(DSLMetaField("__typename")) + + with pytest.raises(AssertionError): + DSLSubscription(DSLMetaField("__schema")) + + with pytest.raises(AssertionError): + 
DSLSubscription(DSLMetaField("__type")) + + DSLFragment("blah", DSLMetaField("__typename")) + + with pytest.raises(AssertionError): + DSLFragment("blah", DSLMetaField("__schema")) + + with pytest.raises(AssertionError): + DSLFragment("blah", DSLMetaField("__type")) + + ds.Query.hero.select(DSLMetaField("__typename")) + + with pytest.raises(AssertionError): + ds.Query.hero.select(DSLMetaField("__schema")) + + with pytest.raises(AssertionError): + ds.Query.hero.select(DSLMetaField("__type")) + + +@pytest.mark.parametrize("option", [True, False]) +def test_get_introspection_query_ast(option): + + introspection_query = get_introspection_query( + descriptions=option, + specified_by_url=option, + directive_is_repeatable=option, + schema_description=option, + ) + dsl_introspection_query = get_introspection_query_ast( + descriptions=option, + specified_by_url=option, + directive_is_repeatable=option, + schema_description=option, + ) + + assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) From 99511006072d06144d37504fe042c54aea462061 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 24 Nov 2021 22:50:14 +0100 Subject: [PATCH 037/239] Refactor dsl select (#261) --- gql/dsl.py | 332 ++++++++++++++++++++++--------------- tests/starwars/test_dsl.py | 120 +++++++++++--- 2 files changed, 292 insertions(+), 160 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index f864c316..0cadef8b 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,3 +1,7 @@ +""" +.. 
image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa + :alt: UML diagram +""" import logging import re from abc import ABC, abstractmethod @@ -47,6 +51,7 @@ VariableDefinitionNode, VariableNode, assert_named_type, + get_named_type, introspection_types, is_enum_type, is_input_object_type, @@ -292,7 +297,77 @@ def __getattr__(self, name: str) -> "DSLType": return DSLType(type_def) -class DSLExecutable(ABC): +class DSLSelector(ABC): + """DSLSelector is an abstract class which defines the + :meth:`select ` method to select + children fields in the query. + + Inherited by + :class:`DSLRootFieldSelector `, + :class:`DSLFieldSelector ` + :class:`DSLFragmentSelector ` + """ + + selection_set: SelectionSetNode + + def __init__( + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", + ): + """:meta private:""" + self.selection_set = SelectionSetNode(selections=FrozenList([])) + + if fields or fields_with_alias: + self.select(*fields, **fields_with_alias) + + @abstractmethod + def is_valid_field(self, field: "DSLSelectable") -> bool: + raise NotImplementedError( + "Any DSLSelector subclass must have a is_valid_field method" + ) # pragma: no cover + + def select( + self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", + ): + r"""Select the fields which should be added. 
+ + :param \*fields: fields or fragments + :type \*fields: DSLSelectable + :param \**fields_with_alias: fields or fragments with alias as key + :type \**fields_with_alias: DSLSelectable + + :raises TypeError: if an argument is not an instance of :class:`DSLSelectable` + :raises GraphQLError: if an argument is not a valid field + """ + # Concatenate fields without and with alias + added_fields: Tuple["DSLSelectable", ...] = DSLField.get_aliased_fields( + fields, fields_with_alias + ) + + # Check that each field is valid + for field in added_fields: + if not isinstance(field, DSLSelectable): + raise TypeError( + "Fields should be instances of DSLSelectable. " + f"Received: {type(field)}" + ) + + if not self.is_valid_field(field): + raise GraphQLError(f"Invalid field for {self!r}: {field!r}") + + # Get a list of AST Nodes for each added field + added_selections: List[ + Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] + ] = [field.ast_field for field in added_fields] + + # Update the current selection list with new selections + self.selection_set.selections = FrozenList( + self.selection_set.selections + added_selections + ) + + log.debug(f"Added fields: {added_fields} in {self!r}") + + +class DSLExecutable(DSLSelector): """Interface for the root elements which can be executed in the :func:`dsl_gql ` function @@ -316,20 +391,21 @@ def executable_ast(self): def __init__( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", ): - r"""Given arguments of type :class:`DSLField` containing GraphQL requests, + r"""Given arguments of type :class:`DSLSelectable` containing GraphQL requests, generate an operation which can be converted to a Document using the :func:`dsl_gql `. - The fields arguments should be fields of root GraphQL types + The fields arguments should be either be fragments or + fields of root GraphQL types (Query, Mutation or Subscription) and correspond to the operation_type of this operation. 
- :param \*fields: root instances of the dynamically generated requests - :type \*fields: DSLField - :param \**fields_with_alias: root instances fields with alias as key - :type \**fields_with_alias: DSLField + :param \*fields: root fields or fragments + :type \*fields: DSLSelectable + :param \**fields_with_alias: root fields or fragments with alias as key + :type \**fields_with_alias: DSLSelectable - :raises TypeError: if an argument is not an instance of :class:`DSLField` + :raises TypeError: if an argument is not an instance of :class:`DSLSelectable` :raises AssertionError: if an argument is not a field which correspond to the operation type """ @@ -337,53 +413,46 @@ def __init__( self.name = None self.variable_definitions = DSLVariableDefinitions() - # Concatenate fields without and with alias - all_fields: Tuple["DSLSelectable", ...] = DSLField.get_aliased_fields( - fields, fields_with_alias - ) + DSLSelector.__init__(self, *fields, **fields_with_alias) - # Check that we receive only arguments of type DSLField - # And that the root type correspond to the operation - for field in all_fields: - if not isinstance(field, DSLField): - raise TypeError( - ( - "fields must be instances of DSLField. 
" - f"Received type: {type(field)}" - ) - ) - valid_type = False - if isinstance(self, DSLOperation): - operation_name = self.operation_type.name - if isinstance(field, DSLMetaField): - if field.name in ["__schema", "__type"]: - valid_type = operation_name == "QUERY" - if field.name == "__typename": - valid_type = operation_name != "SUBSCRIPTION" - else: - valid_type = field.parent_type.name.upper() == operation_name - - else: # Fragments - if isinstance(field, DSLMetaField): - valid_type = field.name == "__typename" - - if not valid_type: - if isinstance(self, DSLOperation): - error_msg = ( - "Invalid root field for operation " - f"{self.operation_type.name}" - ) - else: - error_msg = f"Invalid field for fragment {self.name}" - - raise AssertionError(f"{error_msg}: {field!r}") - - self.selection_set = SelectionSetNode( - selections=FrozenList(DSLSelectable.get_ast_fields(all_fields)) - ) +class DSLRootFieldSelector(DSLSelector): + """Class used to define the + :meth:`is_valid_field ` method + for root fields for the :meth:`select ` method. -class DSLOperation(DSLExecutable): + Inherited by + :class:`DSLOperation ` + """ + + def is_valid_field(self, field: "DSLSelectable") -> bool: + """Check that a field is valid for a root field. + + For operations, the fields arguments should be fields of root GraphQL types + (Query, Mutation or Subscription) and correspond to the + operation_type of this operation. + + the :code:`__typename` field can only be added to Query or Mutation. + the :code:`__schema` and :code:`__type` field can only be added to Query. 
+ """ + + assert isinstance(self, DSLOperation) + + operation_name = self.operation_type.name + + if isinstance(field, DSLMetaField): + if field.name in ["__schema", "__type"]: + return operation_name == "QUERY" + if field.name == "__typename": + return operation_name != "SUBSCRIPTION" + + elif isinstance(field, DSLField): + return field.parent_type.name.upper() == operation_name + + return False + + +class DSLOperation(DSLExecutable, DSLRootFieldSelector): """Interface for GraphQL operations. Inherited by @@ -407,6 +476,9 @@ def executable_ast(self) -> OperationDefinitionNode: **({"name": NameNode(value=self.name)} if self.name else {}), ) + def __repr__(self) -> str: + return f"<{self.__class__.__name__}>" + class DSLQuery(DSLOperation): operation_type = OperationType.QUERY @@ -427,10 +499,11 @@ class DSLVariable: of the :class:`DSLVariableDefinitions` The type of the variable is set by the :class:`DSLField` instance that receives it - in the `args` method. + in the :meth:`args ` method. """ def __init__(self, name: str): + """:meta private:""" self.type: Optional[TypeNode] = None self.name = name self.ast_variable = VariableNode(name=NameNode(value=self.name)) @@ -462,11 +535,12 @@ class DSLVariableDefinitions: Attributes of the DSLVariableDefinitions class are generated automatically with the `__getattr__` dunder method in order to generate - instances of :class:`DSLVariable`, that can then be used as values in the - `DSLField.args` method + instances of :class:`DSLVariable`, that can then be used as values + in the :meth:`args ` method. 
""" def __init__(self): + """:meta private:""" self.variables: Dict[str, DSLVariable] = {} def __getattr__(self, name: str) -> "DSLVariable": @@ -549,28 +623,6 @@ class DSLSelectable(ABC): ast_field: Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] - @staticmethod - def get_ast_fields( - fields: Iterable["DSLSelectable"], - ) -> List[Union[FieldNode, InlineFragmentNode, FragmentSpreadNode]]: - """ - :meta private: - - Equivalent to: :code:`[field.ast_field for field in fields]` - But with a type check for each field in the list. - - :raises TypeError: if any of the provided fields are not instances - of the :class:`DSLSelectable` class. - """ - ast_fields = [] - for field in fields: - if isinstance(field, DSLSelectable): - ast_fields.append(field.ast_field) - else: - raise TypeError(f'Received incompatible field: "{field}".') - - return ast_fields - @staticmethod def get_aliased_fields( fields: Iterable["DSLSelectable"], @@ -593,64 +645,70 @@ def __str__(self) -> str: return print_ast(self.ast_field) -class DSLSelector(ABC): - """DSLSelector is an abstract class which defines the - :meth:`select ` method to select - children fields in the query. +class DSLFragmentSelector(DSLSelector): + """Class used to define the + :meth:`is_valid_field ` method + for fragments for the :meth:`select ` method. Inherited by - :class:`DSLField `, :class:`DSLFragment `, :class:`DSLInlineFragment ` """ - selection_set: SelectionSetNode + def is_valid_field(self, field: DSLSelectable) -> bool: + """Check that a field is valid.""" - def __init__(self): - self.selection_set = SelectionSetNode(selections=FrozenList([])) + assert isinstance(self, (DSLFragment, DSLInlineFragment)) - def select( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" - ) -> "DSLSelector": - r"""Select the new children fields - that we want to receive in the request. 
+ if isinstance(field, (DSLFragment, DSLInlineFragment)): + return True - If used multiple times, we will add the new children fields - to the existing children fields. + assert isinstance(field, DSLField) - :param \*fields: new children fields - :type \*fields: DSLSelectable (DSLField, DSLFragment or DSLInlineFragment) - :param \**fields_with_alias: new children fields with alias as key - :type \**fields_with_alias: DSLField - :return: itself + if isinstance(field, DSLMetaField): + return field.name == "__typename" - :raises TypeError: if any of the provided fields are not instances - of the :class:`DSLSelectable` class. - """ + fragment_type = self._type - # Concatenate fields without and with alias - added_fields: Tuple["DSLSelectable", ...] = DSLSelectable.get_aliased_fields( - fields, fields_with_alias - ) + assert fragment_type is not None - # Check that we don't receive an invalid meta-field - for field in added_fields: - if isinstance(field, DSLMetaField) and field.name != "__typename": - raise AssertionError(f"Invalid field for {self!r}: {field!r}") + if field.name in fragment_type.fields.keys(): + return fragment_type.fields[field.name].type == field.field.type - # Get a list of AST Nodes for each added field - added_selections: List[ - Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] - ] = DSLSelectable.get_ast_fields(added_fields) + return False - # Update the current selection list with new selections - self.selection_set.selections = FrozenList( - self.selection_set.selections + added_selections - ) - log.debug(f"Added fields: {added_fields} in {self!r}") +class DSLFieldSelector(DSLSelector): + """Class used to define the + :meth:`is_valid_field ` method + for fields for the :meth:`select ` method. 
- return self + Inherited by + :class:`DSLField `, + """ + + def is_valid_field(self, field: DSLSelectable) -> bool: + """Check that a field is valid.""" + + assert isinstance(self, DSLField) + + if isinstance(field, (DSLFragment, DSLInlineFragment)): + return True + + assert isinstance(field, DSLField) + + if isinstance(field, DSLMetaField): + return field.name == "__typename" + + parent_type = get_named_type(self.field.type) + + if not isinstance(parent_type, (GraphQLInterfaceType, GraphQLObjectType)): + return False + + if field.name in parent_type.fields.keys(): + return parent_type.fields[field.name].type == field.field.type + + return False class DSLSelectableWithAlias(DSLSelectable): @@ -678,7 +736,7 @@ def alias(self, alias: str) -> "DSLSelectableWithAlias": return self -class DSLField(DSLSelectableWithAlias, DSLSelector): +class DSLField(DSLSelectableWithAlias, DSLFieldSelector): """The DSLField represents a GraphQL field for the DSL code. Instances of this class are generated for you automatically as attributes @@ -710,12 +768,14 @@ def __init__( parent type of the field :param field: the GraphQL field definition from the schema """ - DSLSelector.__init__(self) self.parent_type = parent_type self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) + log.debug(f"Creating {self!r}") + DSLSelector.__init__(self) + @property def name(self): """:meta private:""" @@ -818,7 +878,7 @@ def __init__(self, name: str): try: field = self.meta_type.fields[name] except KeyError: - raise AssertionError(f'Invalid meta-field "{name}"') + raise GraphQLError(f'Invalid meta-field "{name}"') super().__init__(name, self.meta_type, field) @@ -826,10 +886,10 @@ def alias(self, alias: str) -> "DSLSelectableWithAlias": """ :meta private: """ - pass + return self -class DSLInlineFragment(DSLSelectable, DSLSelector): +class DSLInlineFragment(DSLSelectable, DSLFragmentSelector): """DSLInlineFragment represents an inline fragment for the DSL 
code.""" _type: Union[GraphQLObjectType, GraphQLInterfaceType] @@ -846,11 +906,12 @@ def __init__( :type \**fields_with_alias: DSLField """ - DSLSelector.__init__(self) - self.ast_field = InlineFragmentNode() - self.select(*fields, **fields_with_alias) log.debug(f"Creating {self!r}") + self.ast_field = InlineFragmentNode() + + DSLSelector.__init__(self, *fields, **fields_with_alias) + def select( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" ) -> "DSLInlineFragment": @@ -882,7 +943,7 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__}{type_info}>" -class DSLFragment(DSLSelectable, DSLSelector, DSLExecutable): +class DSLFragment(DSLSelectable, DSLFragmentSelector, DSLExecutable): """DSLFragment represents a named GraphQL fragment for the DSL code.""" _type: Optional[Union[GraphQLObjectType, GraphQLInterfaceType]] @@ -890,23 +951,15 @@ class DSLFragment(DSLSelectable, DSLSelector, DSLExecutable): name: str def __init__( - self, - name: str, - *fields: "DSLSelectable", - **fields_with_alias: "DSLSelectableWithAlias", + self, name: str, ): r"""Initialize the DSLFragment. :param name: the name of the fragment :type name: str - :param \*fields: new children fields - :type \*fields: DSLSelectable (DSLField, DSLFragment or DSLInlineFragment) - :param \**fields_with_alias: new children fields with alias as key - :type \**fields_with_alias: DSLField """ - DSLSelector.__init__(self) - DSLExecutable.__init__(self, *fields, **fields_with_alias) + DSLExecutable.__init__(self) self.name = name self._type = None @@ -933,6 +986,11 @@ def select( """Calling :meth:`select ` method with corrected typing hints """ + if self._type is None: + raise AttributeError( + "Missing type condition. 
Please use .on(type_condition) method" + ) + super().select(*fields, **fields_with_alias) return self diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index a86ceff9..0335d721 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -189,9 +189,12 @@ def test_invalid_field_on_type_query(ds): def test_incompatible_field(ds): - with pytest.raises(Exception) as exc_info: + with pytest.raises(TypeError) as exc_info: ds.Query.hero.select("not_a_DSL_FIELD") - assert "Received incompatible field" in str(exc_info.value) + assert ( + "Fields should be instances of DSLSelectable. Received: " + in str(exc_info.value) + ) def test_hero_name_query(ds): @@ -378,6 +381,22 @@ def test_subscription(ds): ) +def test_field_does_not_exit_in_type(ds): + with pytest.raises( + GraphQLError, + match="Invalid field for : ", + ): + ds.Query.hero.select(ds.Query.hero) + + +def test_try_to_select_on_scalar_field(ds): + with pytest.raises( + GraphQLError, + match="Invalid field for : ", + ): + ds.Human.name.select(ds.Query.hero) + + def test_invalid_arg(ds): with pytest.raises( KeyError, match="Argument invalid_arg does not exist in Field: Character." @@ -480,6 +499,24 @@ def test_inline_fragments(ds): assert query == str(query_dsl) +def test_inline_fragments_nested(ds): + query = """hero(episode: JEDI) { + name + ... on Human { + ... 
on Human { + homePlanet + } + } +}""" + query_dsl = ds.Query.hero.args(episode=6).select( + ds.Character.name, + DSLInlineFragment() + .on(ds.Human) + .select(DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet)), + ) + assert query == str(query_dsl) + + def test_fragments_repr(ds): assert repr(DSLInlineFragment()) == "" @@ -519,9 +556,7 @@ def test_fragments(ds): def test_fragment_without_type_condition_error(ds): # We create a fragment without using the .on(type_condition) method - name_and_appearances = DSLFragment("NameAndAppearances").select( - ds.Character.name, ds.Character.appearsIn - ) + name_and_appearances = DSLFragment("NameAndAppearances") # If we try to use this fragment, gql generates an error with pytest.raises( @@ -530,6 +565,26 @@ def test_fragment_without_type_condition_error(ds): ): dsl_gql(name_and_appearances) + with pytest.raises( + AttributeError, + match=r"Missing type condition. Please use .on\(type_condition\) method", + ): + DSLFragment("NameAndAppearances").select( + ds.Character.name, ds.Character.appearsIn + ) + + +def test_inline_fragment_in_dsl_gql(ds): + + inline_fragment = DSLInlineFragment() + + query = DSLQuery() + + with pytest.raises( + GraphQLError, match=r"Invalid field for : ", + ): + query.select(inline_fragment) + def test_fragment_with_name_changed(ds): @@ -542,6 +597,17 @@ def test_fragment_with_name_changed(ds): assert str(fragment) == "...DEF" +def test_fragment_select_field_not_in_fragment(ds): + + fragment = DSLFragment("test").on(ds.Character) + + with pytest.raises( + GraphQLError, + match="Invalid field for : ", + ): + fragment.select(ds.Droid.primaryFunction) + + def test_dsl_nested_query_with_fragment(ds): query = """fragment NameAndAppearances on Character { name @@ -610,18 +676,19 @@ def test_dsl_nested_query_with_fragment(ds): def test_dsl_query_all_fields_should_be_instances_of_DSLField(): with pytest.raises( - TypeError, match="fields must be instances of DSLField. 
Received type:" + TypeError, + match="Fields should be instances of DSLSelectable. Received: ", ): DSLQuery("I am a string") def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): - with pytest.raises(AssertionError) as excinfo: + with pytest.raises(GraphQLError) as excinfo: DSLQuery(ds.Character.name) - assert ( - "Invalid root field for operation QUERY: " "" - ) in str(excinfo.value) + assert ("Invalid field for : ") in str( + excinfo.value + ) def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): @@ -691,40 +758,47 @@ def test_invalid_meta_field_selection(ds): assert metafield.name == "__typename" - with pytest.raises(AssertionError): + with pytest.raises(GraphQLError): DSLMetaField("__invalid_meta_field") DSLMutation(DSLMetaField("__typename")) - with pytest.raises(AssertionError): + with pytest.raises(GraphQLError): DSLMutation(DSLMetaField("__schema")) - with pytest.raises(AssertionError): + with pytest.raises(GraphQLError): DSLMutation(DSLMetaField("__type")) - with pytest.raises(AssertionError): + with pytest.raises(GraphQLError): DSLSubscription(DSLMetaField("__typename")) - with pytest.raises(AssertionError): + with pytest.raises(GraphQLError): DSLSubscription(DSLMetaField("__schema")) - with pytest.raises(AssertionError): + with pytest.raises(GraphQLError): DSLSubscription(DSLMetaField("__type")) - DSLFragment("blah", DSLMetaField("__typename")) + fragment = DSLFragment("blah") + + with pytest.raises(AttributeError): + fragment.select(DSLMetaField("__typename")) + + fragment.on(ds.Character) + + fragment.select(DSLMetaField("__typename")) - with pytest.raises(AssertionError): - DSLFragment("blah", DSLMetaField("__schema")) + with pytest.raises(GraphQLError): + fragment.select(DSLMetaField("__schema")) - with pytest.raises(AssertionError): - DSLFragment("blah", DSLMetaField("__type")) + with pytest.raises(GraphQLError): + fragment.select(DSLMetaField("__type")) ds.Query.hero.select(DSLMetaField("__typename")) - with 
pytest.raises(AssertionError): + with pytest.raises(GraphQLError): ds.Query.hero.select(DSLMetaField("__schema")) - with pytest.raises(AssertionError): + with pytest.raises(GraphQLError): ds.Query.hero.select(DSLMetaField("__type")) From eb986dff5a368d71d7004050d339d54acd155c11 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 25 Nov 2021 14:58:51 +0100 Subject: [PATCH 038/239] Remove Client type_def obsolete argument (#262) --- gql/client.py | 15 --------------- tests/test_async_client_validation.py | 21 +++------------------ 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/gql/client.py b/gql/client.py index c39da95b..fc686c16 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,5 +1,4 @@ import asyncio -import warnings from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union from graphql import ( @@ -45,7 +44,6 @@ def __init__( self, schema: Optional[Union[str, GraphQLSchema]] = None, introspection=None, - type_def: Optional[str] = None, transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, execute_timeout: Optional[Union[int, float]] = 10, @@ -67,19 +65,6 @@ def __init__( :param parse_results: Whether gql will try to parse the serialized output sent by the backend. Can be used to unserialize custom scalars or enums. """ - assert not ( - type_def and introspection - ), "Cannot provide introspection and type definition at the same time." - - if type_def: - assert ( - not schema - ), "Cannot provide type definition and schema at the same time." 
- warnings.warn( - "type_def is deprecated; use schema instead", - category=DeprecationWarning, - ) - schema = type_def if introspection: assert ( diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index 107bd6c2..b588e6ba 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -1,6 +1,5 @@ import asyncio import json -import warnings import graphql import pytest @@ -83,7 +82,6 @@ async def server_starwars(ws, path): [ {"schema": StarWarsSchema}, {"introspection": StarWarsIntrospection}, - {"type_def": StarWarsTypeDef}, {"schema": StarWarsTypeDef}, ], ) @@ -97,11 +95,7 @@ async def test_async_client_validation( sample_transport = WebsocketsTransport(url=url) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="type_def is deprecated; use schema instead" - ) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=sample_transport, **client_params) async with client as session: @@ -135,7 +129,6 @@ async def test_async_client_validation( [ {"schema": StarWarsSchema}, {"introspection": StarWarsIntrospection}, - {"type_def": StarWarsTypeDef}, {"schema": StarWarsTypeDef}, ], ) @@ -149,11 +142,7 @@ async def test_async_client_validation_invalid_query( sample_transport = WebsocketsTransport(url=url) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="type_def is deprecated; use schema instead" - ) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=sample_transport, **client_params) async with client as session: @@ -174,11 +163,7 @@ async def test_async_client_validation_invalid_query( @pytest.mark.parametrize("subscription_str", [starwars_invalid_subscription_str]) @pytest.mark.parametrize( "client_params", - [ - {"schema": StarWarsSchema, "introspection": StarWarsIntrospection}, - {"schema": StarWarsSchema, "type_def": StarWarsTypeDef}, - {"introspection": 
StarWarsIntrospection, "type_def": StarWarsTypeDef}, - ], + [{"schema": StarWarsSchema, "introspection": StarWarsIntrospection}], ) async def test_async_client_validation_different_schemas_parameters_forbidden( event_loop, server, subscription_str, client_params From 7e45815f83b54ea0b3533fd493bab37c26a23649 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 25 Nov 2021 16:51:45 +0100 Subject: [PATCH 039/239] Support python 3.10 (#264) Add Python 3.10 tests on tox and on GitHub actions Update websockets to 10 for Python versions >= 3.6 Update pytest related dependencies Update mypy dependency + add new required types dependencies Ignore new "There is no current event loop" warnings --- .github/workflows/tests.yml | 4 +++- gql/client.py | 19 ++++++++++++++++--- gql/transport/websockets.py | 7 ++++++- setup.py | 18 ++++++++++++------ tests/conftest.py | 4 +++- tests/test_graphqlws_subscription.py | 11 ++++++++--- tests/test_websocket_subscription.py | 11 ++++++++--- tox.ini | 5 +++-- 8 files changed, 59 insertions(+), 20 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8326a645..870493aa 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "pypy3"] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "pypy3"] os: [ubuntu-latest, windows-latest] exclude: - os: windows-latest @@ -17,6 +17,8 @@ jobs: python-version: "3.7" - os: windows-latest python-version: "3.9" + - os: windows-latest + python-version: "3.10" - os: windows-latest python-version: "pypy3" diff --git a/gql/client.py b/gql/client.py index fc686c16..124cf34b 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,4 +1,5 @@ import asyncio +import warnings from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union from graphql import ( @@ -151,7 +152,11 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: # Get 
the current asyncio event loop # Or create a new event loop if there isn't one (in a new Thread) try: - loop = asyncio.get_event_loop() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + loop = asyncio.get_event_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -194,7 +199,11 @@ def subscribe( # Get the current asyncio event loop # Or create a new event loop if there isn't one (in a new Thread) try: - loop = asyncio.get_event_loop() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + loop = asyncio.get_event_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -211,7 +220,11 @@ def subscribe( # Note: we need to create a task here in order to be able to close # the async generator properly on python 3.8 # See https://bugs.python.org/issue38559 - generator_task = asyncio.ensure_future(async_generator.__anext__()) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + generator_task = asyncio.ensure_future(async_generator.__anext__()) result = loop.run_until_complete(generator_task) yield result diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 06552d2f..6d03f08e 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import warnings from contextlib import suppress from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast @@ -165,7 +166,11 @@ def __init__( # We need to set an event loop here if there is none # Or else we will not be able to create an asyncio.Event() try: - self._loop = asyncio.get_event_loop() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no 
current event loop" + ) + self._loop = asyncio.get_event_loop() except RuntimeError: self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) diff --git a/setup.py b/setup.py index 94f3a9ee..c7e6dd7f 100644 --- a/setup.py +++ b/setup.py @@ -13,9 +13,9 @@ tests_requires = [ "parse==1.15.0", - "pytest==5.4.2", - "pytest-asyncio==0.11.0", - "pytest-cov==2.8.1", + "pytest==6.2.5", + "pytest-asyncio==0.16.0", + "pytest-cov==3.0.0", "mock==4.0.2", "vcrpy==4.0.2", "aiofiles", @@ -26,10 +26,13 @@ "check-manifest>=0.42,<1", "flake8==3.8.1", "isort==4.3.21", - "mypy==0.770", + "mypy==0.910", "sphinx>=3.0.0,<4", "sphinx_rtd_theme>=0.4,<1", "sphinx-argparse==0.2.5", + "types-aiofiles", + "types-mock", + "types-requests", ] + tests_requires install_aiohttp_requires = [ @@ -43,7 +46,8 @@ ] install_websockets_requires = [ - "websockets>=9,<10", + "websockets>=9,<10;python_version<='3.6'", + "websockets>=10,<11;python_version>'3.6'", ] install_all_requires = ( @@ -67,14 +71,16 @@ author_email="me@syrusakbary.com", license="MIT", classifiers=[ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay gql client", diff --git a/tests/conftest.py b/tests/conftest.py index 519738cc..6fd9fc44 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -102,8 +102,10 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests for name in [ "websockets.legacy.server", - "gql.transport.websockets", + "gql.transport.aiohttp", 
"gql.transport.phoenix_channel_websockets", + "gql.transport.requests", + "gql.transport.websockets", "gql.dsl", "gql.utilities.parse_result", ]: diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 2c7cff23..7826aca1 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -1,6 +1,7 @@ import asyncio import json import sys +import warnings from typing import List import pytest @@ -742,9 +743,13 @@ def test_graphqlws_subscription_sync_graceful_shutdown( if count == 5: # Simulate a KeyboardInterrupt in the generator - asyncio.ensure_future( - client.session._generator.athrow(KeyboardInterrupt) - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) count -= 1 diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index ff484157..43795b14 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -1,6 +1,7 @@ import asyncio import json import sys +import warnings from typing import List import pytest @@ -531,9 +532,13 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) if count == 5: # Simulate a KeyboardInterrupt in the generator - asyncio.ensure_future( - client.session._generator.athrow(KeyboardInterrupt) - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) count -= 1 diff --git a/tox.ini b/tox.ini index 414f083b..2699744c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{36,37,38,39,py3} + py{36,37,38,39,310,py3} [pytest] markers = asyncio @@ -12,6 +12,7 @@ python = 3.7: py37 3.8: py38 3.9: py39 + 3.10: py310 pypy3: pypy3 
[testenv] @@ -30,7 +31,7 @@ deps = -e.[test] commands = pip install -U setuptools ; run "tox -- tests -s" to show output for debugging - py{36,37,39,py3}: pytest {posargs:tests} + py{36,37,39,310,py3}: pytest {posargs:tests} py{38}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] From 75a771d746d22cd4ff5102fb818d0885fe600327 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 25 Nov 2021 17:03:12 +0100 Subject: [PATCH 040/239] Documentation small changes (#263) --- README.md | 2 +- docs/transports/phoenix.rst | 5 +++- gql/client.py | 30 ++++++++++++--------- gql/transport/phoenix_channel_websockets.py | 2 +- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index a85761e1..2fa37978 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ The main features of GQL are: * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) * Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html) -* [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries from the command line +* [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries or download schemas from the command line * [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically ## Installation diff --git 
a/docs/transports/phoenix.rst b/docs/transports/phoenix.rst index 3b5b9f53..7fb4a90c 100644 --- a/docs/transports/phoenix.rst +++ b/docs/transports/phoenix.rst @@ -3,10 +3,13 @@ PhoenixChannelWebsocketsTransport ================================= -The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport which allows you +The PhoenixChannelWebsocketsTransport is an async transport which allows you to execute queries and subscriptions against an `Absinthe`_ backend using the `Phoenix`_ framework `channels`_. +Reference: +:class:`gql.transport.phoenix_channel_websockets.PhoenixChannelWebsocketsTransport` + .. _Absinthe: http://absinthe-graphql.org .. _Phoenix: https://www.phoenixframework.org .. _channels: https://hexdocs.pm/phoenix/Phoenix.Channel.html#content diff --git a/gql/client.py b/gql/client.py index 124cf34b..111a3dd7 100644 --- a/gql/client.py +++ b/gql/client.py @@ -315,9 +315,10 @@ def _execute( :param variable_values: Dictionary of input parameters. :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be - serialized. Used for custom scalars and/or enums. Default: False. + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. - By default use the parse_results attribute of the client. + By default use the parse_results argument of the client. The extra arguments are passed to the transport execute method.""" @@ -378,9 +379,10 @@ def execute( :param variable_values: Dictionary of input parameters. :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be - serialized. Used for custom scalars and/or enums. Default: False. + serialized. 
Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. - By default use the parse_results attribute of the client. + By default use the parse_results argument of the client. :param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -462,9 +464,10 @@ async def _subscribe( :param variable_values: Dictionary of input parameters. :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be - serialized. Used for custom scalars and/or enums. Default: False. + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. - By default use the parse_results attribute of the client. + By default use the parse_results argument of the client. The extra arguments are passed to the transport subscribe method.""" @@ -539,9 +542,10 @@ async def subscribe( :param variable_values: Dictionary of input parameters. :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be - serialized. Used for custom scalars and/or enums. Default: False. + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. - By default use the parse_results attribute of the client. + By default use the parse_results argument of the client. :param get_execution_result: yield the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -598,9 +602,10 @@ async def _execute( :param variable_values: Dictionary of input parameters. 
:param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be - serialized. Used for custom scalars and/or enums. Default: False. + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. - By default use the parse_results attribute of the client. + By default use the parse_results argument of the client. The extra arguments are passed to the transport execute method.""" @@ -665,9 +670,10 @@ async def execute( :param variable_values: Dictionary of input parameters. :param operation_name: Name of the operation that shall be executed. :param serialize_variables: whether the variable values should be - serialized. Used for custom scalars and/or enums. Default: False. + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. - By default use the parse_results attribute of the client. + By default use the parse_results argument of the client. :param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 56d35f8b..b750c39c 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -25,7 +25,7 @@ def __init__(self, query_id: int) -> None: class PhoenixChannelWebsocketsTransport(WebsocketsTransport): - """The PhoenixChannelWebsocketsTransport is an **EXPERIMENTAL** async transport + """The PhoenixChannelWebsocketsTransport is an async transport which allows you to execute queries and subscriptions against an `Absinthe`_ backend using the `Phoenix`_ framework `channels`_. 
From 7f402c80f46291e83651e47fe2ae5f849055f26b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 26 Nov 2021 15:01:25 +0100 Subject: [PATCH 041/239] Invert the order of proposed subprotocols apollo and graphql-ws (#265) Should fix make all_tests with countries.trevorblades.com --- gql/transport/websockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 6d03f08e..779a3608 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -203,8 +203,8 @@ def __init__( self.close_exception: Optional[Exception] = None self.supported_subprotocols = [ - self.GRAPHQLWS_SUBPROTOCOL, self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, ] async def _send(self, message: str) -> None: From 9b85b6c1f01da4aa3c840b36302c1bd27cebb53b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 26 Nov 2021 20:25:44 +0100 Subject: [PATCH 042/239] Fix aiohttp wait for closed ssl connections (#153) --- gql/transport/aiohttp.py | 64 ++++++++++++++++++++++++++++++++++++++++ tests/conftest.py | 54 ++++++++++++++++++++++----------- tests/test_aiohttp.py | 35 ++++++++++++++++++++++ 3 files changed, 136 insertions(+), 17 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 090463e9..f34a0066 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,3 +1,5 @@ +import asyncio +import functools import io import json import logging @@ -44,6 +46,7 @@ def __init__( auth: Optional[BasicAuth] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -53,6 +56,8 @@ def __init__( :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed :param ssl: ssl_context of the connection. 
Use ssl=False to disable encryption + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ @@ -65,6 +70,7 @@ def __init__( self.auth: Optional[BasicAuth] = auth self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args self.session: Optional[aiohttp.ClientSession] = None @@ -100,6 +106,59 @@ async def connect(self) -> None: else: raise TransportAlreadyConnected("Transport is already connected") + @staticmethod + def create_aiohttp_closed_event(session) -> asyncio.Event: + """Work around aiohttp issue that doesn't properly close transports on exit. + + See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 + + Returns: + An event that will be set once all transports have been properly closed. + """ + + ssl_transports = 0 + all_is_lost = asyncio.Event() + + def connection_lost(exc, orig_lost): + nonlocal ssl_transports + + try: + orig_lost(exc) + finally: + ssl_transports -= 1 + if ssl_transports == 0: + all_is_lost.set() + + def eof_received(orig_eof_received): + try: + orig_eof_received() + except AttributeError: # pragma: no cover + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. 
+ pass + + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + + ssl_transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + proto.connection_lost = functools.partial( + connection_lost, orig_lost=orig_lost + ) + proto.eof_received = functools.partial( + eof_received, orig_eof_received=orig_eof_received + ) + + if ssl_transports == 0: + all_is_lost.set() + + return all_is_lost + async def close(self) -> None: """Coroutine which will close the aiohttp session. @@ -108,7 +167,12 @@ async def close(self) -> None: when you exit the async context manager. """ if self.session is not None: + closed_event = self.create_aiohttp_closed_event(self.session) await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass self.session = None async def execute( diff --git a/tests/conftest.py b/tests/conftest.py index 6fd9fc44..c0101241 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,8 +77,7 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_transport) -@pytest.fixture -async def aiohttp_server(): +async def aiohttp_server_base(with_ssl=False): """Factory to create a TestServer instance, given an app. 
aiohttp_server(app, **kwargs) @@ -89,7 +88,13 @@ async def aiohttp_server(): async def go(app, *, port=None, **kwargs): # type: ignore server = AIOHTTPTestServer(app, port=port) - await server.start_server(**kwargs) + + start_server_args = {**kwargs} + if with_ssl: + testcert, ssl_context = get_localhost_ssl_context() + start_server_args["ssl"] = ssl_context + + await server.start_server(**start_server_args) servers.append(server) return server @@ -99,6 +104,18 @@ async def go(app, *, port=None, **kwargs): # type: ignore await servers.pop().close() +@pytest.fixture +async def aiohttp_server(): + async for server in aiohttp_server_base(): + yield server + + +@pytest.fixture +async def ssl_aiohttp_server(): + async for server in aiohttp_server_base(with_ssl=True): + yield server + + # Adding debug logs to websocket tests for name in [ "websockets.legacy.server", @@ -121,6 +138,22 @@ async def go(app, *, port=None, **kwargs): # type: ignore MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1)) +def get_localhost_ssl_context(): + # This is a copy of certificate from websockets tests folder + # + # Generate TLS certificate with: + # $ openssl req -x509 -config test_localhost.cnf \ + # -days 15340 -newkey rsa:2048 \ + # -out test_localhost.crt -keyout test_localhost.key + # $ cat test_localhost.key test_localhost.crt > test_localhost.pem + # $ rm test_localhost.key test_localhost.crt + testcert = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(testcert) + + return (testcert, ssl_context) + + class WebSocketServer: """Websocket server on localhost on a free port. 
@@ -141,20 +174,7 @@ async def start(self, handler, extra_serve_args=None): extra_serve_args = {} if self.with_ssl: - # This is a copy of certificate from websockets tests folder - # - # Generate TLS certificate with: - # $ openssl req -x509 -config test_localhost.cnf \ - # -days 15340 -newkey rsa:2048 \ - # -out test_localhost.crt -keyout test_localhost.key - # $ cat test_localhost.key test_localhost.crt > test_localhost.pem - # $ rm test_localhost.key test_localhost.crt - self.testcert = bytes( - pathlib.Path(__file__).with_name("test_localhost.pem") - ) - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain(self.testcert) - + self.testcert, ssl_context = get_localhost_ssl_context() extra_serve_args["ssl"] = ssl_context # Start a server with a random open port diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 50cec3f9..6dbe46ae 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1073,3 +1073,38 @@ async def handler(request): execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + assert str(url).startswith("https://") + + sample_transport = AIOHTTPTransport( + url=url, timeout=10, ssl_close_timeout=ssl_close_timeout + ) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + 
continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" From 4554e4b4927b9a772329bdcd627a1b8065e74d50 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 28 Nov 2021 10:28:10 +0100 Subject: [PATCH 043/239] Documentation Adding error handling doc (#266) --- docs/advanced/error_handling.rst | 73 +++++++++++++++++++++++++++ docs/advanced/index.rst | 1 + docs/modules/gql.rst | 1 + docs/modules/transport_exceptions.rst | 7 +++ gql/transport/exceptions.py | 11 +++- 5 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 docs/advanced/error_handling.rst create mode 100644 docs/modules/transport_exceptions.rst diff --git a/docs/advanced/error_handling.rst b/docs/advanced/error_handling.rst new file mode 100644 index 00000000..2fd1e39b --- /dev/null +++ b/docs/advanced/error_handling.rst @@ -0,0 +1,73 @@ +Error Handing +============= + +Local errors +------------ + +If gql detects locally that something does not correspond to the GraphQL specification, +then gql may raise a **GraphQLError** from graphql-core. + +This may happen for example: + +- if your query is not valid +- if your query does not correspond to your schema +- if the result received from the backend does not correspond to the schema + if :code:`parse_results` is set to True + +Transport errors +---------------- + +If an error happens with the transport, then gql may raise a +:class:`TransportError ` + +Here are the possible Transport Errors: + +- :class:`TransportProtocolError `: + Should never happen if the backend is a correctly configured GraphQL server. + It means that the answer received from the server does not correspond + to the transport protocol. + +- :class:`TransportServerError `: + There was an error communicating with the server. If this error is received, + then the connection with the server will be closed. This may happen if the server + returned a 404 http header for example. 
+ The http error code is available in the exception :code:`code` attribute. + +- :class:`TransportQueryError `: + There was a specific error returned from the server for your query. + The message you receive in this error has been created by the backend, not gql! + In that case, the connection to the server is still available and you are + free to try to send other queries using the same connection. + The message of the exception contains the first error returned by the backend. + All the errors messages are available in the exception :code:`errors` attribute. + +- :class:`TransportClosed `: + This exception is generated when the client is trying to use the transport + while the transport was previously closed. + +- :class:`TransportAlreadyConnected `: + Exception generated when the client is trying to connect to the transport + while the transport is already connected. + +HTTP +^^^^ + +For HTTP transports, we should get a json response which contain +:code:`data` or :code:`errors` fields. +If that is not the case, then the returned error depends whether the http return code +is below 400 or not. 
+ +- json response: + - with data or errors keys: + - no errors key -> no exception + - errors key -> raise **TransportQueryError** + - no data or errors keys: + - http code < 400: + raise **TransportProtocolError** + - http code >= 400: + raise **TransportServerError** +- not a json response: + - http code < 400: + raise **TransportProtocolError** + - http code >= 400: + raise **TransportServerError** diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst index 637a8ea4..8005b381 100644 --- a/docs/advanced/index.rst +++ b/docs/advanced/index.rst @@ -6,5 +6,6 @@ Advanced async_advanced_usage logging + error_handling local_schema dsl_module diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 06a89a96..6730e07b 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -20,5 +20,6 @@ Sub-Packages client transport + transport_exceptions dsl utilities diff --git a/docs/modules/transport_exceptions.rst b/docs/modules/transport_exceptions.rst new file mode 100644 index 00000000..1c13ed00 --- /dev/null +++ b/docs/modules/transport_exceptions.rst @@ -0,0 +1,7 @@ +gql.transport.exceptions +======================== + +.. currentmodule:: gql.transport.exceptions + +.. automodule:: gql.transport.exceptions + :member-order: bysource diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 250e7523..89ae992b 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -2,6 +2,8 @@ class TransportError(Exception): + """Base class for all the Transport exceptions""" + pass @@ -18,7 +20,9 @@ class TransportServerError(TransportError): This exception will close the transport connection. """ - def __init__(self, message=None, code=None): + code: Optional[int] + + def __init__(self, message: str, code: Optional[int] = None): super(TransportServerError, self).__init__(message) self.code = code @@ -29,6 +33,11 @@ class TransportQueryError(Exception): This exception should not close the transport connection. 
""" + query_id: Optional[int] + errors: Optional[List[Any]] + data: Optional[Any] + extensions: Optional[Any] + def __init__( self, msg: str, From 47aaf92c995bcdfbdda38ffc9997bae6d64da479 Mon Sep 17 00:00:00 2001 From: Connor Brinton Date: Mon, 6 Dec 2021 12:14:12 -0500 Subject: [PATCH 044/239] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Expand=20aiohttp?= =?UTF-8?q?=20version=20range=20(#274)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c7e6dd7f..776be66e 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ ] + tests_requires install_aiohttp_requires = [ - "aiohttp>=3.7.1,<3.8.0", + "aiohttp>=3.7.1,<3.9.0", ] install_requests_requires = [ From 7eedbbc832a9ea94559056bc4085de9162b4b664 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 22:39:07 +0100 Subject: [PATCH 045/239] Enable mypy to discover type hints as specified in PEP 561 (#275) --- MANIFEST.in | 2 ++ gql/py.typed | 1 + setup.py | 2 ++ 3 files changed, 5 insertions(+) create mode 100644 gql/py.typed diff --git a/MANIFEST.in b/MANIFEST.in index 4d7eaef4..73d59a18 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -13,6 +13,8 @@ include tox.ini include scripts/gql-cli +include gql/py.typed + recursive-include tests *.py *.graphql *.cnf *.yaml *.pem recursive-include docs *.txt *.rst conf.py Makefile make.bat *.jpg *.png *.gif recursive-include docs/code_examples *.py diff --git a/gql/py.typed b/gql/py.typed new file mode 100644 index 00000000..82879f7d --- /dev/null +++ b/gql/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The gql package uses inline types. 
diff --git a/setup.py b/setup.py index 776be66e..266fbb0c 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,8 @@ ], keywords="api graphql protocol rest relay gql client", packages=find_packages(include=["gql*"]), + # PEP-561: https://www.python.org/dev/peps/pep-0561/ + package_data={"gql": ["py.typed"]}, install_requires=install_requires, tests_require=install_all_requires + tests_requires, extras_require={ From c7d65c74ca017291d576e4341093cb0d1b0aec8d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 8 Dec 2021 15:56:31 +0100 Subject: [PATCH 046/239] gql-cli add signal handlers to catch ctrl-c and close cleanly (#276) Now the gql-cli script works on Python 3.6 too --- scripts/gql-cli | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/scripts/gql-cli b/scripts/gql-cli index 055919ff..b2a079a3 100755 --- a/scripts/gql-cli +++ b/scripts/gql-cli @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import asyncio import sys +from signal import SIGINT, SIGTERM from gql.cli import get_parser, main @@ -9,8 +10,23 @@ parser = get_parser(with_examples=True) args = parser.parse_args() try: - # Execute the script - exit_code = asyncio.run(main(args)) + # Create a new asyncio event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a gql-cli task with the supplied arguments + main_task = asyncio.ensure_future(main(args), loop=loop) + + # Add signal handlers to close gql-cli cleanly on Control-C + for signal in [SIGINT, SIGTERM]: + loop.add_signal_handler(signal, main_task.cancel) + + # Run the asyncio loop to execute the task + exit_code = 0 + try: + exit_code = loop.run_until_complete(main_task) + finally: + loop.close() # Return with the correct exit code sys.exit(exit_code) From 2be6aaa52812d5b948eaa921a4b40f03aa6547ae Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 9 Dec 2021 10:18:21 +0100 Subject: [PATCH 047/239] Better way to support Python 3.10 for 
ensure_future (#277) --- gql/client.py | 8 +++----- tests/test_websocket_subscription.py | 7 ++++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/gql/client.py b/gql/client.py index 111a3dd7..f5f6872d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -220,11 +220,9 @@ def subscribe( # Note: we need to create a task here in order to be able to close # the async generator properly on python 3.8 # See https://bugs.python.org/issue38559 - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - generator_task = asyncio.ensure_future(async_generator.__anext__()) + generator_task = asyncio.ensure_future( + async_generator.__anext__(), loop=loop + ) result = loop.run_until_complete(generator_task) yield result diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 43795b14..14ffe0a2 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -521,6 +521,8 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) count = 10 subscription = gql(subscription_str.format(count=count)) + interrupt_task = None + with pytest.raises(KeyboardInterrupt): for result in client.subscribe(subscription): @@ -536,7 +538,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) warnings.filterwarnings( "ignore", message="There is no current event loop" ) - asyncio.ensure_future( + interrupt_task = asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -544,6 +546,9 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) assert count == 4 + # Catch interrupt_task exception to remove warning + interrupt_task.exception() + # Check that the server received a connection_terminate message last assert logged_messages.pop() == '{"type": "connection_terminate"}' From 
09e4b306988942216b3469591da2efb67bea5c2c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 10 Dec 2021 08:08:13 +0100 Subject: [PATCH 048/239] Add skip and include directive in introspection schema (#279) --- gql/client.py | 2 +- gql/utilities/__init__.py | 2 + gql/utilities/build_client_schema.py | 89 ++++++++++++++++++++++++++++ tests/starwars/test_validation.py | 71 +++++++++++++++++++++- 4 files changed, 162 insertions(+), 2 deletions(-) create mode 100644 gql/utilities/build_client_schema.py diff --git a/gql/client.py b/gql/client.py index f5f6872d..2236189d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -7,7 +7,6 @@ ExecutionResult, GraphQLSchema, build_ast_schema, - build_client_schema, get_introspection_query, parse, validate, @@ -17,6 +16,7 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport +from .utilities import build_client_schema from .utilities import parse_result as parse_result_fn from .utilities import serialize_variable_values diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index 7089d360..3d29dfe3 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,3 +1,4 @@ +from .build_client_schema import build_client_schema from .get_introspection_query_ast import get_introspection_query_ast from .parse_result import parse_result from .serialize_variable_values import serialize_value, serialize_variable_values @@ -5,6 +6,7 @@ from .update_schema_scalars import update_schema_scalar, update_schema_scalars __all__ = [ + "build_client_schema", "parse_result", "get_introspection_query_ast", "serialize_variable_values", diff --git a/gql/utilities/build_client_schema.py b/gql/utilities/build_client_schema.py new file mode 100644 index 00000000..78fb7586 --- /dev/null +++ b/gql/utilities/build_client_schema.py @@ -0,0 +1,89 @@ +from typing import Dict + +from graphql import GraphQLSchema +from graphql import 
build_client_schema as build_client_schema_orig +from graphql.pyutils import inspect + +__all__ = ["build_client_schema"] + + +INCLUDE_DIRECTIVE_JSON = { + "name": "include", + "description": ( + "Directs the executor to include this field or fragment " + "only when the `if` argument is true." + ), + "locations": ["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"], + "args": [ + { + "name": "if", + "description": "Included when true.", + "type": { + "kind": "NON_NULL", + "name": "None", + "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": "None"}, + }, + "defaultValue": "None", + } + ], +} + +SKIP_DIRECTIVE_JSON = { + "name": "skip", + "description": ( + "Directs the executor to skip this field or fragment " + "when the `if` argument is true." + ), + "locations": ["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"], + "args": [ + { + "name": "if", + "description": "Skipped when true.", + "type": { + "kind": "NON_NULL", + "name": "None", + "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": "None"}, + }, + "defaultValue": "None", + } + ], +} + + +def build_client_schema(introspection: Dict) -> GraphQLSchema: + """This is an alternative to the graphql-core function + :code:`build_client_schema` but with default include and skip directives + added to the schema to fix + `issue #278 `_ + + .. warning:: + This function will be removed once the issue + `graphql-js#3419 `_ + has been fixed and ported to graphql-core so don't use it + outside gql. + """ + + if not isinstance(introspection, dict) or not isinstance( + introspection.get("__schema"), dict + ): + raise TypeError( + "Invalid or incomplete introspection result. Ensure that you" + " are passing the 'data' attribute of an introspection response" + f" and no 'errors' were returned alongside: {inspect(introspection)}." 
+ ) + + schema_introspection = introspection["__schema"] + + directives = schema_introspection.get("directives", None) + + if directives is None: + directives = [] + schema_introspection["directives"] = directives + + if not any(directive["name"] == "skip" for directive in directives): + directives.append(SKIP_DIRECTIVE_JSON) + + if not any(directive["name"] == "include" for directive in directives): + directives.append(INCLUDE_DIRECTIVE_JSON) + + return build_client_schema_orig(introspection, assume_valid=False) diff --git a/tests/starwars/test_validation.py b/tests/starwars/test_validation.py index 468bb553..1ca8a2bb 100644 --- a/tests/starwars/test_validation.py +++ b/tests/starwars/test_validation.py @@ -60,7 +60,35 @@ def introspection_schema(): return Client(introspection=StarWarsIntrospection) -@pytest.fixture(params=["local_schema", "typedef_schema", "introspection_schema"]) +@pytest.fixture +def introspection_schema_empty_directives(): + introspection = StarWarsIntrospection + + # Simulate an empty dictionary for directives + introspection["__schema"]["directives"] = [] + + return Client(introspection=introspection) + + +@pytest.fixture +def introspection_schema_no_directives(): + introspection = StarWarsIntrospection + + # Simulate no directives key + del introspection["__schema"]["directives"] + + return Client(introspection=introspection) + + +@pytest.fixture( + params=[ + "local_schema", + "typedef_schema", + "introspection_schema", + "introspection_schema_empty_directives", + "introspection_schema_no_directives", + ] +) def client(request): return request.getfixturevalue(request.param) @@ -187,3 +215,44 @@ def test_allows_object_fields_in_inline_fragments(client): } """ assert not validation_errors(client, query) + + +def test_include_directive(client): + query = """ + query fetchHero($with_friends: Boolean!) 
{ + hero { + name + friends @include(if: $with_friends) { + name + } + } + } + """ + assert not validation_errors(client, query) + + +def test_skip_directive(client): + query = """ + query fetchHero($without_friends: Boolean!) { + hero { + name + friends @skip(if: $without_friends) { + name + } + } + } + """ + assert not validation_errors(client, query) + + +def test_build_client_schema_invalid_introspection(): + from gql.utilities import build_client_schema + + with pytest.raises(TypeError) as exc_info: + build_client_schema("blah") + + assert ( + "Invalid or incomplete introspection result. Ensure that you are passing the " + "'data' attribute of an introspection response and no 'errors' were returned " + "alongside: 'blah'." + ) in str(exc_info.value) From 3760f5b117c7247315ab780bc26f47e146f0c57d Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Fri, 10 Dec 2021 02:14:34 -0500 Subject: [PATCH 049/239] Aws appsync websocket transport (#239) Also refactor the websockets transport with WebsocketsTransportBase --- README.md | 7 +- .../code_examples/appsync/mutation_api_key.py | 54 ++ docs/code_examples/appsync/mutation_iam.py | 53 ++ .../appsync/subscription_api_key.py | 53 ++ .../code_examples/appsync/subscription_iam.py | 44 ++ docs/intro.rst | 30 +- docs/modules/gql.rst | 7 + docs/modules/transport.rst | 10 +- docs/modules/transport_aiohttp.rst | 7 + docs/modules/transport_appsync_auth.rst | 7 + docs/modules/transport_appsync_websockets.rst | 7 + .../transport_phoenix_channel_websockets.rst | 7 + docs/modules/transport_requests.rst | 7 + docs/modules/transport_websockets.rst | 7 + docs/modules/transport_websockets_base.rst | 7 + docs/transports/appsync.rst | 156 ++++ docs/transports/async_transports.rst | 1 + gql/client.py | 6 + gql/transport/aiohttp.py | 16 +- gql/transport/appsync_auth.py | 221 ++++++ gql/transport/appsync_websockets.py | 209 ++++++ gql/transport/websockets.py | 628 ++-------------- gql/transport/websockets_base.py | 666 +++++++++++++++++ 
setup.py | 7 +- tests/conftest.py | 17 +- tests/fixtures/__init__.py | 0 tests/fixtures/aws/__init__.py | 0 tests/fixtures/aws/fake_credentials.py | 28 + tests/fixtures/aws/fake_request.py | 22 + tests/fixtures/aws/fake_session.py | 24 + tests/fixtures/aws/fake_signer.py | 27 + tests/test_appsync_auth.py | 189 +++++ tests/test_appsync_http.py | 78 ++ tests/test_appsync_websockets.py | 702 ++++++++++++++++++ tox.ini | 3 - 35 files changed, 2704 insertions(+), 603 deletions(-) create mode 100644 docs/code_examples/appsync/mutation_api_key.py create mode 100644 docs/code_examples/appsync/mutation_iam.py create mode 100644 docs/code_examples/appsync/subscription_api_key.py create mode 100644 docs/code_examples/appsync/subscription_iam.py create mode 100644 docs/modules/transport_aiohttp.rst create mode 100644 docs/modules/transport_appsync_auth.rst create mode 100644 docs/modules/transport_appsync_websockets.rst create mode 100644 docs/modules/transport_phoenix_channel_websockets.rst create mode 100644 docs/modules/transport_requests.rst create mode 100644 docs/modules/transport_websockets.rst create mode 100644 docs/modules/transport_websockets_base.rst create mode 100644 docs/transports/appsync.rst create mode 100644 gql/transport/appsync_auth.py create mode 100644 gql/transport/appsync_websockets.py create mode 100644 gql/transport/websockets_base.py create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/aws/__init__.py create mode 100644 tests/fixtures/aws/fake_credentials.py create mode 100644 tests/fixtures/aws/fake_request.py create mode 100644 tests/fixtures/aws/fake_session.py create mode 100644 tests/fixtures/aws/fake_signer.py create mode 100644 tests/test_appsync_auth.py create mode 100644 tests/test_appsync_http.py create mode 100644 tests/test_appsync_websockets.py diff --git a/README.md b/README.md index 2fa37978..0962c80e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,12 @@ The complete documentation for GQL can be found at 
The main features of GQL are: -* Execute GraphQL queries using [different protocols](https://gql.readthedocs.io/en/latest/transports/index.html) (http, websockets, ...) +* Execute GraphQL queries using [different protocols](https://gql.readthedocs.io/en/latest/transports/index.html): + * http + * websockets: + * apollo or graphql-ws protocol + * Phoenix channels + * AWS AppSync realtime protocol (experimental) * Possibility to [validate the queries locally](https://gql.readthedocs.io/en/latest/usage/validation.html) using a GraphQL schema provided locally or fetched from the backend using an instrospection query * Supports GraphQL queries, mutations and [subscriptions](https://gql.readthedocs.io/en/latest/usage/subscriptions.html) * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) diff --git a/docs/code_examples/appsync/mutation_api_key.py b/docs/code_examples/appsync/mutation_api_key.py new file mode 100644 index 00000000..052da850 --- /dev/null +++ b/docs/code_examples/appsync/mutation_api_key.py @@ -0,0 +1,54 @@ +import asyncio +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = 
os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + api_key = os.environ.get("AWS_GRAPHQL_API_KEY") + + if url is None or api_key is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncApiKeyAuthentication(host=host, api_key=api_key) + + transport = AIOHTTPTransport(url=url, auth=auth) + + async with Client( + transport=transport, fetch_schema_from_transport=False, + ) as session: + + query = gql( + """ +mutation createMessage($message: String!) { + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + result = await session.execute(query, variable_values=variable_values) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/appsync/mutation_iam.py b/docs/code_examples/appsync/mutation_iam.py new file mode 100644 index 00000000..327e0d91 --- /dev/null +++ b/docs/code_examples/appsync/mutation_iam.py @@ -0,0 +1,53 @@ +import asyncio +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.appsync_auth import AppSyncIAMAuthentication + +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + + if url is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncIAMAuthentication(host=host) + + transport = AIOHTTPTransport(url=url, auth=auth) + + async with Client( + transport=transport, fetch_schema_from_transport=False, + ) as session: + + query = gql( + """ +mutation createMessage($message: String!) 
{ + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + result = await session.execute(query, variable_values=variable_values) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/appsync/subscription_api_key.py b/docs/code_examples/appsync/subscription_api_key.py new file mode 100644 index 00000000..87bb3611 --- /dev/null +++ b/docs/code_examples/appsync/subscription_api_key.py @@ -0,0 +1,53 @@ +import asyncio +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.appsync_auth import AppSyncApiKeyAuthentication +from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + api_key = os.environ.get("AWS_GRAPHQL_API_KEY") + + if url is None or api_key is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + print(f"Host: {host}") + + auth = AppSyncApiKeyAuthentication(host=host, api_key=api_key) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + async with Client(transport=transport) as session: + + subscription = gql( + """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + ) + + print("Waiting for messages...") + + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/appsync/subscription_iam.py b/docs/code_examples/appsync/subscription_iam.py new file mode 100644 index 00000000..1bb540d0 --- /dev/null +++ b/docs/code_examples/appsync/subscription_iam.py @@ -0,0 +1,44 @@ 
+import asyncio +import os +import sys + +from gql import Client, gql +from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + + if url is None: + print("Missing environment variables") + sys.exit() + + # Using implicit auth (IAM) + transport = AppSyncWebsocketsTransport(url=url) + + async with Client(transport=transport) as session: + + subscription = gql( + """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + ) + + print("Waiting for messages...") + + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/intro.rst b/docs/intro.rst index e377c56e..1cd3f5c8 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -35,19 +35,23 @@ which needs the :code:`aiohttp` dependency, then you can install GQL with:: pip install --pre gql[aiohttp] -The corresponding between extra dependencies required and the GQL transports is: - -+-------------------+----------------------------------------------------------------+ -| Extra dependency | Transports | -+===================+================================================================+ -| aiohttp | :ref:`AIOHTTPTransport ` | -+-------------------+----------------------------------------------------------------+ -| websockets | :ref:`WebsocketsTransport ` | -| | | -| | :ref:`PhoenixChannelWebsocketsTransport ` | -+-------------------+----------------------------------------------------------------+ -| requests | :ref:`RequestsHTTPTransport ` | -+-------------------+----------------------------------------------------------------+ +The corresponding between extra dependencies required and the GQL 
classes is: + ++---------------------+----------------------------------------------------------------+ +| Extra dependencies | Classes | ++=====================+================================================================+ +| aiohttp | :ref:`AIOHTTPTransport ` | ++---------------------+----------------------------------------------------------------+ +| websockets | :ref:`WebsocketsTransport ` | +| | | +| | :ref:`PhoenixChannelWebsocketsTransport ` | +| | | +| | :ref:`AppSyncWebsocketsTransport ` | ++---------------------+----------------------------------------------------------------+ +| requests | :ref:`RequestsHTTPTransport ` | ++---------------------+----------------------------------------------------------------+ +| botocore | :ref:`AppSyncIAMAuthentication ` | ++---------------------+----------------------------------------------------------------+ .. note:: diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 6730e07b..be6f904b 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -20,6 +20,13 @@ Sub-Packages client transport + transport_aiohttp + transport_appsync_auth + transport_appsync_websockets transport_exceptions + transport_phoenix_channel_websockets + transport_requests + transport_websockets + transport_websockets_base dsl utilities diff --git a/docs/modules/transport.rst b/docs/modules/transport.rst index 1b250d7a..d03dbf1f 100644 --- a/docs/modules/transport.rst +++ b/docs/modules/transport.rst @@ -5,14 +5,6 @@ gql.transport .. autoclass:: gql.transport.transport.Transport -.. autoclass:: gql.transport.local_schema.LocalSchemaTransport - -.. autoclass:: gql.transport.requests.RequestsHTTPTransport - .. autoclass:: gql.transport.async_transport.AsyncTransport -.. autoclass:: gql.transport.aiohttp.AIOHTTPTransport - -.. autoclass:: gql.transport.websockets.WebsocketsTransport - -.. autoclass:: gql.transport.phoenix_channel_websockets.PhoenixChannelWebsocketsTransport +.. 
autoclass:: gql.transport.local_schema.LocalSchemaTransport diff --git a/docs/modules/transport_aiohttp.rst b/docs/modules/transport_aiohttp.rst new file mode 100644 index 00000000..41cebd99 --- /dev/null +++ b/docs/modules/transport_aiohttp.rst @@ -0,0 +1,7 @@ +gql.transport.aiohttp +===================== + +.. currentmodule:: gql.transport.aiohttp + +.. automodule:: gql.transport.aiohttp + :member-order: bysource diff --git a/docs/modules/transport_appsync_auth.rst b/docs/modules/transport_appsync_auth.rst new file mode 100644 index 00000000..b8ac42c0 --- /dev/null +++ b/docs/modules/transport_appsync_auth.rst @@ -0,0 +1,7 @@ +gql.transport.appsync_auth +========================== + +.. currentmodule:: gql.transport.appsync_auth + +.. automodule:: gql.transport.appsync_auth + :member-order: bysource diff --git a/docs/modules/transport_appsync_websockets.rst b/docs/modules/transport_appsync_websockets.rst new file mode 100644 index 00000000..f0d9523d --- /dev/null +++ b/docs/modules/transport_appsync_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.appsync_websockets +================================ + +.. currentmodule:: gql.transport.appsync_websockets + +.. automodule:: gql.transport.appsync_websockets + :member-order: bysource diff --git a/docs/modules/transport_phoenix_channel_websockets.rst b/docs/modules/transport_phoenix_channel_websockets.rst new file mode 100644 index 00000000..5f412a33 --- /dev/null +++ b/docs/modules/transport_phoenix_channel_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.phoenix_channel_websockets +======================================== + +.. currentmodule:: gql.transport.phoenix_channel_websockets + +.. automodule:: gql.transport.phoenix_channel_websockets + :member-order: bysource diff --git a/docs/modules/transport_requests.rst b/docs/modules/transport_requests.rst new file mode 100644 index 00000000..78a07a02 --- /dev/null +++ b/docs/modules/transport_requests.rst @@ -0,0 +1,7 @@ +gql.transport.requests +====================== + +.. 
currentmodule:: gql.transport.requests + +.. automodule:: gql.transport.requests + :member-order: bysource diff --git a/docs/modules/transport_websockets.rst b/docs/modules/transport_websockets.rst new file mode 100644 index 00000000..9a924afd --- /dev/null +++ b/docs/modules/transport_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.websockets +======================== + +.. currentmodule:: gql.transport.websockets + +.. automodule:: gql.transport.websockets + :member-order: bysource diff --git a/docs/modules/transport_websockets_base.rst b/docs/modules/transport_websockets_base.rst new file mode 100644 index 00000000..548351eb --- /dev/null +++ b/docs/modules/transport_websockets_base.rst @@ -0,0 +1,7 @@ +gql.transport.websockets_base +============================= + +.. currentmodule:: gql.transport.websockets_base + +.. automodule:: gql.transport.websockets_base + :member-order: bysource diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst new file mode 100644 index 00000000..7ceb7480 --- /dev/null +++ b/docs/transports/appsync.rst @@ -0,0 +1,156 @@ +.. _appsync_transport: + +AppSyncWebsocketsTransport +========================== + +AWS AppSync allows you to execute GraphQL subscriptions on its realtime GraphQL endpoint. + +See `Building a real-time websocket client`_ for an explanation. + +GQL provides the :code:`AppSyncWebsocketsTransport` transport which implements this +for you to allow you to execute subscriptions. + +.. note:: + It is only possible to execute subscriptions with this transport. + For queries or mutations, See :ref:`AppSync GraphQL Queries and mutations ` + +How to use it: + + * choose one :ref:`authentication method ` (API key, IAM, Cognito user pools or OIDC) + * instantiate a :code:`AppSyncWebsocketsTransport` with your GraphQL endpoint as url and your auth method + +.. note:: + It is also possible to instantiate the transport without an auth argument. 
In that case, + gql will use by default the :class:`IAM auth ` + which will try to authenticate with environment variables or from your aws credentials file. + +.. note:: + All the examples in this documentation are based on the sample app created + by following `this AWS blog post`_ + +Full example with API key authentication from environment variables: + +.. literalinclude:: ../code_examples/appsync/subscription_api_key.py + +Reference: :class:`gql.transport.appsync_websockets.AppSyncWebsocketsTransport` + +.. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html +.. _this AWS blog post: https://aws.amazon.com/fr/blogs/mobile/appsync-realtime/ + + +.. _appsync_authentication_methods: + +Authentication methods +---------------------- + +.. _appsync_api_key_auth: + +API key +^^^^^^^ + +Use the :code:`AppSyncApiKeyAuthentication` class to provide your API key: + +.. code-block:: python + + auth = AppSyncApiKeyAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + api_key="YOUR_API_KEY", + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + +Reference: :class:`gql.transport.appsync_auth.AppSyncApiKeyAuthentication` + +.. _appsync_iam_auth: + +IAM +^^^ + +For the IAM authentication, you can simply create your transport without +an auth argument. + +The region name will be autodetected from the url or from your AWS configuration +(:code:`.aws/config`) or the environment variable: + +- AWS_DEFAULT_REGION + +The credentials will be detected from your AWS configuration file +(:code:`.aws/credentials`) or from the environment variables: + +- AWS_ACCESS_KEY_ID +- AWS_SECRET_ACCESS_KEY +- AWS_SESSION_TOKEN (optional) + +.. 
code-block:: python + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + ) + +OR You can also provide the credentials manually by creating the +:code:`AppSyncIAMAuthentication` class yourself: + +.. code-block:: python + + from botocore.credentials import Credentials + + credentials = Credentials( + access_key = os.environ.get("AWS_ACCESS_KEY_ID"), + secret_key= os.environ.get("AWS_SECRET_ACCESS_KEY"), + token=os.environ.get("AWS_SESSION_TOKEN", None), # Optional + ) + + auth = AppSyncIAMAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + credentials=credentials, + region_name="your region" + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + +Reference: :class:`gql.transport.appsync_auth.AppSyncIAMAuthentication` + +.. _appsync_jwt_auth: + +Json Web Tokens (jwt) +^^^^^^^^^^^^^^^^^^^^^ + +AWS provides json web tokens (jwt) for the authentication methods: + +- Amazon Cognito user pools +- OpenID Connect (OIDC) + +For these authentication methods, you can use the :code:`AppSyncJWTAuthentication` class: + +.. code-block:: python + + auth = AppSyncJWTAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + jwt="YOUR_JWT_STRING", + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + +Reference: :class:`gql.transport.appsync_auth.AppSyncJWTAuthentication` + +.. _appsync_http: + +AppSync GraphQL Queries and mutations +------------------------------------- + +Queries and mutations are not allowed on the realtime websockets endpoint. 
+But you can use the :ref:`AIOHTTPTransport ` to create +a normal http session and reuse the authentication classes to create the headers for you. + +Full example with API key authentication from environment variables: + +.. literalinclude:: ../code_examples/appsync/mutation_api_key.py diff --git a/docs/transports/async_transports.rst b/docs/transports/async_transports.rst index 9fb1b017..df8c23cf 100644 --- a/docs/transports/async_transports.rst +++ b/docs/transports/async_transports.rst @@ -12,3 +12,4 @@ Async transports are transports which are using an underlying async library. The aiohttp websockets phoenix + appsync diff --git a/gql/client.py b/gql/client.py index 2236189d..e10f7509 100644 --- a/gql/client.py +++ b/gql/client.py @@ -82,6 +82,12 @@ def __init__( not schema ), "Cannot fetch the schema from transport if is already provided." + assert not type(transport).__name__ == "AppSyncWebsocketsTransport", ( + "fetch_schema_from_transport=True is not allowed " + "for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." 
+ ) + if schema and not transport: transport = LocalSchemaTransport(schema) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index f34a0066..12c57068 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -14,6 +14,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast from ..utils import extract_files +from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport from .exceptions import ( TransportAlreadyConnected, @@ -43,7 +44,7 @@ def __init__( url: str, headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, @@ -55,6 +56,7 @@ def __init__( :param headers: Dict of HTTP Headers. :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed + Or Appsync Authentication class :param ssl: ssl_context of the connection. 
Use ssl=False to disable encryption :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection to close properly @@ -67,7 +69,7 @@ def __init__( self.url: str = url self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies - self.auth: Optional[BasicAuth] = auth + self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout @@ -89,7 +91,9 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": self.auth, + "auth": None + if isinstance(self.auth, AppSyncAuthentication) + else self.auth, } if self.timeout is not None: @@ -266,6 +270,12 @@ async def execute( if extra_args: post_args.update(extra_args) + # Add headers for AppSync if requested + if isinstance(self.auth, AppSyncAuthentication): + post_args["headers"] = self.auth.get_headers( + json.dumps(payload), {"content-type": "application/json"}, + ) + if self.session is None: raise TransportClosed("Transport is not connected") diff --git a/gql/transport/appsync_auth.py b/gql/transport/appsync_auth.py new file mode 100644 index 00000000..04c07c10 --- /dev/null +++ b/gql/transport/appsync_auth.py @@ -0,0 +1,221 @@ +import json +import logging +import re +from abc import ABC, abstractmethod +from base64 import b64encode +from typing import Any, Callable, Dict, Optional + +try: + import botocore +except ImportError: # pragma: no cover + # botocore is only needed for the IAM AppSync authentication method + pass + +log = logging.getLogger("gql.transport.appsync") + + +class AppSyncAuthentication(ABC): + """AWS authentication abstract base class + + All AWS authentication class should have a + :meth:`get_headers ` + method which defines the headers used in the authentication process.""" + + def get_auth_url(self, url: str) 
-> str: + """ + :return: a url with base64 encoded headers used to establish + a websocket connection to the appsync-realtime-api. + """ + headers = self.get_headers() + + encoded_headers = b64encode( + json.dumps(headers, separators=(",", ":")).encode() + ).decode() + + url_base = url.replace("https://", "wss://").replace( + "appsync-api", "appsync-realtime-api" + ) + + return f"{url_base}?header={encoded_headers}&payload=e30=" + + @abstractmethod + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + raise NotImplementedError() # pragma: no cover + + +class AppSyncApiKeyAuthentication(AppSyncAuthentication): + """AWS authentication class using an API key""" + + def __init__(self, host: str, api_key: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param api_key: the API key + """ + self._host = host + self.api_key = api_key + + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return {"host": self._host, "x-api-key": self.api_key} + + +class AppSyncJWTAuthentication(AppSyncAuthentication): + """AWS authentication class using a JWT access token. + + It can be used either for: + - Amazon Cognito user pools + - OpenID Connect (OIDC) + """ + + def __init__(self, host: str, jwt: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param jwt: the JWT Access Token + """ + self._host = host + self.jwt = jwt + + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return {"host": self._host, "Authorization": self.jwt} + + +class AppSyncIAMAuthentication(AppSyncAuthentication): + """AWS authentication class using IAM. + + .. 
note:: + There is no need for you to use this class directly, you could instead + intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport` + without an auth argument. + + During initialization, this class will use botocore to attempt to + find your IAM credentials, either from environment variables or + from your AWS credentials file. + """ + + def __init__( + self, + host: str, + region_name: Optional[str] = None, + signer: Optional["botocore.auth.BaseSigner"] = None, + request_creator: Optional[ + Callable[[Dict[str, Any]], "botocore.awsrequest.AWSRequest"] + ] = None, + credentials: Optional["botocore.credentials.Credentials"] = None, + session: Optional["botocore.session.Session"] = None, + ) -> None: + """Initialize itself, saving the found credentials used + to sign the headers later. + + if no credentials are found, then a NoCredentialsError is raised. + """ + + from botocore.auth import SigV4Auth + from botocore.awsrequest import create_request_object + from botocore.session import get_session + + self._host = host + self._session = session if session else get_session() + self._credentials = ( + credentials if credentials else self._session.get_credentials() + ) + self._service_name = "appsync" + self._region_name = region_name or self._detect_region_name() + self._signer = ( + signer + if signer + else SigV4Auth(self._credentials, self._service_name, self._region_name) + ) + self._request_creator = ( + request_creator if request_creator else create_request_object + ) + + def _detect_region_name(self): + """Try to detect the correct region_name. + + First try to extract the region_name from the host. + + If that does not work, then try to get the region_name from + the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION + environment variable. 
+ + If no region_name was found, then raise a NoRegionError exception.""" + + from botocore.exceptions import NoRegionError + + # Regular expression from botocore.utils.validate_region + m = re.search( + r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(? Dict[str, Any]: + + from botocore.exceptions import NoCredentialsError + + # Default headers for a websocket connection + headers = headers or { + "accept": "application/json, text/javascript", + "content-encoding": "amz-1.0", + "content-type": "application/json; charset=UTF-8", + } + + request: "botocore.awsrequest.AWSRequest" = self._request_creator( + { + "method": "POST", + "url": f"https://{self._host}/graphql{'' if data else '/connect'}", + "headers": headers, + "context": {}, + "body": data or "{}", + } + ) + + try: + self._signer.add_auth(request) + except NoCredentialsError: + log.warning( + "Credentials not found. " + "Do you have default AWS credentials configured?", + ) + raise + + headers = dict(request.headers) + + headers["host"] = self._host + + if log.isEnabledFor(logging.DEBUG): + headers_log = [] + headers_log.append("\n\nSigned headers:") + for key, value in headers.items(): + headers_log.append(f" {key}: {value}") + headers_log.append("\n") + log.debug("\n".join(headers_log)) + + return headers diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py new file mode 100644 index 00000000..c7e05a09 --- /dev/null +++ b/gql/transport/appsync_websockets.py @@ -0,0 +1,209 @@ +import json +import logging +from ssl import SSLContext +from typing import Any, Dict, Optional, Tuple, Union, cast +from urllib.parse import urlparse + +from graphql import DocumentNode, ExecutionResult, print_ast + +from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication +from .exceptions import TransportProtocolError, TransportServerError +from .websockets import WebsocketsTransport, WebsocketsTransportBase + +log = logging.getLogger("gql.transport.appsync") + +try: + import 
botocore +except ImportError: # pragma: no cover + # botocore is only needed for the IAM AppSync authentication method + pass + + +class AppSyncWebsocketsTransport(WebsocketsTransportBase): + """:ref:`Async Transport ` used to execute GraphQL subscription on + AWS appsync realtime endpoint. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. + """ + + auth: Optional[AppSyncAuthentication] + + def __init__( + self, + url: str, + auth: Optional[AppSyncAuthentication] = None, + session: Optional["botocore.session.Session"] = None, + ssl: Union[SSLContext, bool] = False, + connect_timeout: int = 10, + close_timeout: int = 10, + ack_timeout: int = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL endpoint URL. Example: + https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + :param auth: Optional AWS authentication class which will provide the + necessary headers to be correctly authenticated. If this + argument is not provided, then we will try to authenticate + using IAM. + :param ssl: ssl_context of the connection. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. 
+ :param connect_args: Other parameters forwarded to websockets.connect + """ + + if not auth: + + # Extract host from url + host = str(urlparse(url).netloc) + + # May raise NoRegionError or NoCredentialsError or ImportError + auth = AppSyncIAMAuthentication(host=host, session=session) + + self.auth = auth + + url = self.auth.get_auth_url(url) + + super().__init__( + url, + ssl=ssl, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + connect_args=connect_args, + ) + + # Using the same 'graphql-ws' protocol as the apollo protocol + self.supported_subprotocols = [ + WebsocketsTransport.APOLLO_SUBPROTOCOL, + ] + self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server. + + Difference between apollo protocol and aws protocol: + + - aws protocol can return an error without an id + - aws protocol will send start_ack messages + + Returns a list consisting of: + - the answer_type: + - 'connection_ack', + - 'connection_error', + - 'start_ack', + - 'ka', + - 'data', + - 'error', + - 'complete' + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + + try: + json_answer = json.loads(answer) + + answer_type = str(json_answer.get("type")) + + if answer_type == "start_ack": + return ("start_ack", None, None) + + elif answer_type == "error" and "id" not in json_answer: + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{error_payload!r}'") + + else: + + return WebsocketsTransport._parse_answer_apollo( + cast(WebsocketsTransport, self), json_answer + ) + + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + async def _send_query( + self, + document: DocumentNode, + 
variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + + query_id = self.next_query_id + + self.next_query_id += 1 + + data: Dict = {"query": print_ast(document)} + + if variable_values: + data["variables"] = variable_values + + if operation_name: + data["operationName"] = operation_name + + serialized_data = json.dumps(data, separators=(",", ":")) + + payload = {"data": serialized_data} + + message: Dict = { + "id": str(query_id), + "type": "start", + "payload": payload, + } + + assert self.auth is not None + + message["payload"]["extensions"] = { + "authorization": self.auth.get_headers(serialized_data) + } + + await self._send(json.dumps(message, separators=(",", ":"),)) + + return query_id + + subscribe = WebsocketsTransportBase.subscribe + """Send a subscription query and receive the results using + a python async generator. + + Only subscriptions are supported, queries and mutations are forbidden. + + The results are sent as an ExecutionResult object. + """ + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """This method is not available. + + Only subscriptions are supported on the AWS realtime endpoint. + + :raise: AssertionError""" + raise AssertionError( + "execute method is not allowed for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." 
+ ) + + _initialize = WebsocketsTransport._initialize + _stop_listener = WebsocketsTransport._send_stop_message # type: ignore + _send_init_message_and_wait_ack = ( + WebsocketsTransport._send_init_message_and_wait_ack + ) + _wait_ack = WebsocketsTransport._wait_ack diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 779a3608..41478daf 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,85 +1,25 @@ import asyncio import json import logging -import warnings from contextlib import suppress from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast +from typing import Any, Dict, Optional, Tuple, Union, cast -import websockets from graphql import DocumentNode, ExecutionResult, print_ast -from websockets.client import WebSocketClientProtocol from websockets.datastructures import HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Data, Subprotocol +from websockets.typing import Subprotocol -from .async_transport import AsyncTransport from .exceptions import ( - TransportAlreadyConnected, - TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except 
asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - - -class WebsocketsTransport(AsyncTransport): +class WebsocketsTransport(WebsocketsTransportBase): """:ref:`Async Transport ` used to execute GraphQL queries on remote servers with websocket connection. 
@@ -133,15 +73,18 @@ def __init__( :param connect_args: Other parameters forwarded to websockets.connect """ - self.url: str = url - self.ssl: Union[SSLContext, bool] = ssl - self.headers: Optional[HeadersLike] = headers - self.init_payload: Dict[str, Any] = init_payload + super().__init__( + url, + headers, + ssl, + init_payload, + connect_timeout, + close_timeout, + ack_timeout, + keep_alive_timeout, + connect_args, + ) - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout self.ping_interval: Optional[Union[int, float]] = ping_interval self.pong_timeout: Optional[Union[int, float]] self.answer_pings: bool = answer_pings @@ -152,38 +95,7 @@ def __init__( else: self.pong_timeout = pong_timeout - self.connect_args = connect_args - - self.websocket: Optional[WebSocketClientProtocol] = None - self.next_query_id: int = 1 - self.listeners: Dict[int, ListenerQueue] = {} - - self.receive_data_task: Optional[asyncio.Future] = None - self.check_keep_alive_task: Optional[asyncio.Future] = None self.send_ping_task: Optional[asyncio.Future] = None - self.close_task: Optional[asyncio.Future] = None - - # We need to set an event loop here if there is none - # Or else we will not be able to create an asyncio.Event() - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - self._loop = asyncio.get_event_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._wait_closed: asyncio.Event = asyncio.Event() - self._wait_closed.set() - - self._no_more_listeners: asyncio.Event = asyncio.Event() - self._no_more_listeners.set() - - if self.keep_alive_timeout is not None: - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - 
self._next_keep_alive_message.set() self.ping_received: asyncio.Event = asyncio.Event() """ping_received is an asyncio Event which will fire each time @@ -193,56 +105,11 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - with the graphql-ws protocol. - Possible keys are: 'ping', 'pong', 'connection_ack'""" - - self._connecting: bool = False - - self.close_exception: Optional[Exception] = None - self.supported_subprotocols = [ self.APOLLO_SUBPROTOCOL, self.GRAPHQLWS_SUBPROTOCOL, ] - async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message""" - - if not self.websocket: - raise TransportClosed( - "Transport is not connected" - ) from self.close_exception - - try: - await self.websocket.send(message) - log.info(">>> %s", message) - except ConnectionClosed as e: - await self._fail(e, clean_close=False) - raise e - - async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - # Wait for the next websocket frame. Can raise ConnectionClosed - data: Data = await self.websocket.recv() - - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise TransportProtocolError("Binary data received in the websocket") - - answer: str = data - - log.info("<<< %s", answer) - - return answer - async def _wait_ack(self) -> None: """Wait for the connection_ack message. 
Keep alive messages are ignored""" @@ -274,6 +141,9 @@ async def _send_init_message_and_wait_ack(self) -> None: # Wait for the connection_ack message or raise a TimeoutError await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + async def _initialize(self): + await self._send_init_message_and_wait_ack() + async def send_ping(self, payload: Optional[Any] = None) -> None: """Send a ping message for the graphql-ws protocol """ @@ -316,7 +186,7 @@ async def _send_complete_message(self, query_id: int) -> None: await self._send(complete_message) - async def _stop_listener(self, query_id: int) -> None: + async def _stop_listener(self, query_id: int): """Stop the listener corresponding to the query_id depending on the detected backend protocol. @@ -326,6 +196,8 @@ async def _stop_listener(self, query_id: int) -> None: For graphql-ws: send a "complete" message and simulate the reception of a "complete" message from the backend """ + log.debug(f"stop listener {query_id}") + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: await self._send_complete_message(query_id) await self.listeners[query_id].put(("complete", None)) @@ -377,8 +249,12 @@ async def _send_query( return query_id + async def _connection_terminate(self): + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + def _parse_answer_graphqlws( - self, answer: str + self, json_answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the graphql-ws protocol. 
@@ -403,8 +279,6 @@ def _parse_answer_graphqlws( execution_result: Optional[ExecutionResult] = None try: - json_answer = json.loads(answer) - answer_type = str(json_answer.get("type")) if answer_type in ["next", "error", "complete"]: @@ -450,13 +324,13 @@ def _parse_answer_graphqlws( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" + f"Server did not return a GraphQL result: {json_answer}" ) from e return answer_type, answer_id, execution_result def _parse_answer_apollo( - self, answer: str + self, json_answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the apollo websockets protocol. @@ -473,8 +347,6 @@ def _parse_answer_apollo( execution_result: Optional[ExecutionResult] = None try: - json_answer = json.loads(answer) - answer_type = str(json_answer.get("type")) if answer_type in ["data", "error", "complete"]: @@ -520,7 +392,7 @@ def _parse_answer_apollo( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" + f"Server did not return a GraphQL result: {json_answer}" ) from e return answer_type, answer_id, execution_result @@ -531,44 +403,17 @@ def _parse_answer( """Parse the answer received from the server depending on the detected subprotocol. 
""" - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(answer) - - return self._parse_answer_apollo(answer) - - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) - - # Reset for the next iteration - self._next_keep_alive_message.clear() - - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, - ) + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass + return self._parse_answer_apollo(json_answer) async def _send_ping_coro(self) -> None: """Coroutine to periodically send a ping from the client to the backend. 
@@ -603,52 +448,6 @@ async def _send_ping_coro(self) -> None: clean_close=False, ) - async def _receive_data_loop(self) -> None: - try: - while True: - - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed: - break - - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass - - continue - - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. - await self._fail(e, clean_close=False) - break - - await self._handle_answer(answer_type, answer_id, execution_result) - - finally: - log.debug("Exiting _receive_data_loop()") - async def _handle_answer( self, answer_type: str, @@ -656,13 +455,8 @@ async def _handle_answer( execution_result: Optional[ExecutionResult], ) -> None: - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. 
- pass + # Put the answer in the queue + await super()._handle_answer(answer_type, answer_id, execution_result) # Answer pong to ping for graphql-ws protocol if answer_type == "ping": @@ -673,334 +467,34 @@ async def _handle_answer( elif answer_type == "pong": self.pong_received.set() - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. - - The query can be a graphql query, mutation or subscription. - - The results are sent as an ExecutionResult object. - """ - - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name - ) - - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener - - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() - - try: - # Loop over the received answers - while True: - - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. 
- answer_type, execution_result = await listener.get() - - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result - - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break - - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) - - async def execute( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - - Send a query but close the async generator as soon as we have the first answer. - - The result is sent as an ExecutionResult object. 
- """ - first_result = None - - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False - ) - - async for result in generator: - first_result = result - - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - await generator.aclose() - - break - - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) - - return first_result - - async def connect(self) -> None: - """Coroutine which will: - - - connect to the websocket address - - send the init message - - wait for the connection acknowledge from the server - - create an asyncio task which will be used to receive - and parse the websocket answers - - Should be cleaned with a call to the close coroutine - """ - - log.debug("connect: starting") - - if self.websocket is None and not self._connecting: - - # Set connecting to True to avoid a race condition if user is trying - # to connect twice using the same client at the same time - self._connecting = True - - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None - - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } - - # Adding custom parameters passed from init - connect_args.update(self.connect_args) - - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), - self.connect_timeout, - ) - finally: - self._connecting = False - - self.websocket = 
cast(WebSocketClientProtocol, self.websocket) - - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket.response_headers - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() - - # Send the init message and wait for the ack from the server - # Note: This will generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._send_init_message_and_wait_ack() - except ConnectionClosed as e: - raise e - except (TransportProtocolError, asyncio.TimeoutError) as e: - await self._fail(e, clean_close=False) - raise e - - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - - else: - raise TransportAlreadyConnected("Transport is already connected") - - log.debug("connect: done") - - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. 
- """ - if query_id in self.listeners: - del self.listeners[query_id] - - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") - - if remaining == 0: - self._no_more_listeners.set() - - async def _clean_close(self, e: Exception) -> None: - """Coroutine which will: - - - send stop messages for each active subscription to the server - - send the connection terminate message - """ - - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): - - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") - - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - # Finally send the 'connection_terminate' message - await self._send_connection_terminate_message() - - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: - - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ - - log.debug("_close_coro: starting") + async def _after_connect(self): + # Find the backend subprotocol returned in the response headers + response_headers = self.websocket.response_headers try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL - # We should always have an active websocket connection here - assert self.websocket is not None - - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: - 
# More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task - - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): - await self.send_ping_task - - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e - - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close(e) - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) - - log.debug("_close_coro: sending exception to listeners") - - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) - - log.debug("_close_coro: close websocket connection") - - await self.websocket.close() + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - log.debug("_close_coro: websocket connection closed") + async def _after_initialize(self): - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): - finally: + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - log.debug("_close_coro: start cleanup") + async def _close_hook(self): - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError): + await self.send_ping_task self.send_ping_task = None - - self._wait_closed.set() - - log.debug("_close_coro: 
exiting") - - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) - - if self.close_task is None: - - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) - - async def close(self) -> None: - log.debug("close: starting") - - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() - - log.debug("close: done") - - async def wait_closed(self) -> None: - log.debug("wait_close: starting") - - await self._wait_closed.wait() - - log.debug("wait_close: done") diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py new file mode 100644 index 00000000..151e444e --- /dev/null +++ b/gql/transport/websockets_base.py @@ -0,0 +1,666 @@ +import asyncio +import logging +import warnings +from abc import abstractmethod +from contextlib import suppress +from ssl import SSLContext +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast + +import websockets +from graphql import DocumentNode, ExecutionResult +from websockets.client import WebSocketClientProtocol +from websockets.datastructures import HeadersLike +from websockets.exceptions import ConnectionClosed +from websockets.typing import Data, Subprotocol + +from .async_transport import AsyncTransport +from .exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +log = logging.getLogger("gql.transport.websockets") + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server 
answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True + + +class WebsocketsTransportBase(AsyncTransport): + """abstract :ref:`Async Transport ` used to implement + different websockets protocols. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. 
+ """ + + def __init__( + self, + url: str, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + init_payload: Dict[str, Any] = {}, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. 
+ :param connect_args: Other parameters forwarded to websockets.connect + """ + + self.url: str = url + self.headers: Optional[HeadersLike] = headers + self.ssl: Union[SSLContext, bool] = ssl + self.init_payload: Dict[str, Any] = init_payload + + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + + self.connect_args = connect_args + + self.websocket: Optional[WebSocketClientProtocol] = None + self.next_query_id: int = 1 + self.listeners: Dict[int, ListenerQueue] = {} + + self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None + self.close_task: Optional[asyncio.Future] = None + + # We need to set an event loop here if there is none + # Or else we will not be able to create an asyncio.Event() + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._wait_closed: asyncio.Event = asyncio.Event() + self._wait_closed.set() + + self._no_more_listeners: asyncio.Event = asyncio.Event() + self._no_more_listeners.set() + + if self.keep_alive_timeout is not None: + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + self._connecting: bool = False + + self.close_exception: Optional[Exception] = None + + # The list of supported subprotocols should be defined in the subclass + self.supported_subprotocols: List[Subprotocol] = [] + + async def 
_initialize(self): + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ + pass # pragma: no cover + + async def _stop_listener(self, query_id: int): + """Hook to stop to listen to a specific query. + Will send a stop message in some subclasses. + """ + pass # pragma: no cover + + async def _after_connect(self): + """Hook to add custom code for subclasses after the connection + has been established. + """ + pass # pragma: no cover + + async def _after_initialize(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _close_hook(self): + """Hook to add custom code for subclasses for the connection close + """ + pass # pragma: no cover + + async def _connection_terminate(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _send(self, message: str) -> None: + """Send the provided message to the websocket connection and log the message""" + + if not self.websocket: + raise TransportClosed( + "Transport is not connected" + ) from self.close_exception + + try: + await self.websocket.send(message) + log.info(">>> %s", message) + except ConnectionClosed as e: + await self._fail(e, clean_close=False) + raise e + + async def _receive(self) -> str: + """Wait the next message from the websocket connection and log the answer""" + + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") + + # Wait for the next websocket frame. 
Can raise ConnectionClosed + data: Data = await self.websocket.recv() + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + log.info("<<< %s", answer) + + return answer + + @abstractmethod + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + raise NotImplementedError # pragma: no cover + + @abstractmethod + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + raise NotImplementedError # pragma: no cover + + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages + """ + + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) + + # Reset for the next iteration + self._next_keep_alive_message.clear() + + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) + + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _receive_data_loop(self) -> None: + """Main asyncio task which will listen to the incoming messages and will + call the parse_answer and handle_answer methods of the subclass.""" + try: + while True: + + # Wait the next 
answer from the websocket server + try: + answer = await self._receive() + except (ConnectionClosed, TransportProtocolError) as e: + await self._fail(e, clean_close=False) + break + except TransportClosed: + break + + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query, + # but the transport is not closed. + assert isinstance( + e.query_id, int + ), "TransportQueryError should have a query_id defined here" + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass + + continue + + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries. + await self._fail(e, clean_close=False) + break + + await self._handle_answer(answer_type, answer_id, execution_result) + + finally: + log.debug("Exiting _receive_data_loop()") + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass + + async def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using a python async generator. + + The query can be a graphql query, mutation or subscription. + + The results are sent as an ExecutionResult object. 
+ """ + + # Send the query and receive the id + query_id: int = await self._send_query( + document, variable_values, operation_name + ) + + # Create a queue to receive the answers for this query_id + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener + + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() + + try: + # Loop over the received answers + while True: + + # Wait for the answer from the queue of this query_id + # This can raise a TransportError or ConnectionClosed exception. + answer_type, execution_result = await listener.get() + + # If the received answer contains data, + # Then we will yield the results back as an ExecutionResult object + if execution_result is not None: + yield execution_result + + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output without errors + elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) + break + + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug(f"Exception in subscribe: {e!r}") + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + finally: + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + + Send a query but close the async generator as soon as we have the first answer. + + The result is sent as an ExecutionResult object. 
+ """ + first_result = None + + generator = self.subscribe( + document, variable_values, operation_name, send_stop=False + ) + + async for result in generator: + first_result = result + + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() + + break + + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) + + return first_result + + async def connect(self) -> None: + """Coroutine which will: + + - connect to the websocket address + - send the init message + - wait for the connection acknowledge from the server + - create an asyncio task which will be used to receive + and parse the websocket answers + + Should be cleaned with a call to the close coroutine + """ + + log.debug("connect: starting") + + if self.websocket is None and not self._connecting: + + # Set connecting to True to avoid a race condition if user is trying + # to connect twice using the same client at the same time + self._connecting = True + + # If the ssl parameter is not provided, + # generate the ssl value depending on the url + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + "subprotocols": self.supported_subprotocols, + } + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + # Connection to the specified url + # Generate a TimeoutError if taking more than connect_timeout seconds + # Set the _connecting flag to False after in all cases + try: + self.websocket = await asyncio.wait_for( + websockets.client.connect(self.url, **connect_args), + self.connect_timeout, + ) + finally: + self._connecting = False + + self.websocket = 
cast(WebSocketClientProtocol, self.websocket) + + # Run the after_connect hook of the subclass + await self._after_connect() + + self.next_query_id = 1 + self.close_exception = None + self._wait_closed.clear() + + # Send the init message and wait for the ack from the server + # Note: This should generate a TimeoutError + # if no ACKs are received within the ack_timeout + try: + await self._initialize() + except ConnectionClosed as e: + raise e + except (TransportProtocolError, asyncio.TimeoutError) as e: + await self._fail(e, clean_close=False) + raise e + + # Run the after_init hook of the subclass + await self._after_initialize() + + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + + else: + raise TransportAlreadyConnected("Transport is already connected") + + log.debug("connect: done") + + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. 
+ """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + + async def _clean_close(self, e: Exception) -> None: + """Coroutine which will: + + - send stop messages for each active subscription to the server + - send the connection terminate message + """ + + # Send 'stop' message for all current queries + for query_id, listener in self.listeners.items(): + + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) + except asyncio.TimeoutError: # pragma: no cover + log.debug("Timer close_timeout fired") + + # Calling the subclass hook + await self._connection_terminate() + + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: + + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners + """ + + log.debug("_close_coro: starting") + + try: + + # We should always have an active websocket connection here + assert self.websocket is not None + + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + + # Calling the subclass close hook + await self._close_hook() + + # Saving exception to raise it later if trying to use the transport + # after it has already closed. 
+ self.close_exception = e + + if clean_close: + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close(e) + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) + + log.debug("_close_coro: sending exception to listeners") + + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): + await listener.set_exception(e) + + log.debug("_close_coro: close websocket connection") + + await self.websocket.close() + + log.debug("_close_coro: websocket connection closed") + + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) + + finally: + + log.debug("_close_coro: start cleanup") + + self.websocket = None + self.close_task = None + self.check_keep_alive_task = None + self._wait_closed.set() + + log.debug("_close_coro: exiting") + + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + + if self.close_task is None: + + if self.websocket is None: + log.debug("_fail started with self.websocket == None -> already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + else: + log.debug( + "close_task is not None in _fail. 
Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) + ) + + async def close(self) -> None: + log.debug("close: starting") + + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() + + log.debug("close: done") + + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + + await self._wait_closed.wait() + + log.debug("wait_close: done") diff --git a/setup.py b/setup.py index 266fbb0c..7e97f8bc 100644 --- a/setup.py +++ b/setup.py @@ -50,8 +50,12 @@ "websockets>=10,<11;python_version>'3.6'", ] +install_botocore_requires = [ + "botocore>=1.21,<2", +] + install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_botocore_requires ) # Get version from __version__.py file @@ -97,6 +101,7 @@ "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, "websockets": install_websockets_requires, + "botocore": install_botocore_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index c0101241..d433c1ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,11 +14,7 @@ from gql import Client -all_transport_dependencies = [ - "aiohttp", - "requests", - "websockets", -] +all_transport_dependencies = ["aiohttp", "requests", "websockets", "botocore"] def pytest_addoption(parser): @@ -116,10 +112,11 @@ async def ssl_aiohttp_server(): yield server -# Adding debug logs to websocket tests +# Adding debug logs for name in [ "websockets.legacy.server", "gql.transport.aiohttp", + "gql.transport.appsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", "gql.transport.websockets", @@ -492,3 +489,11 @@ async def run_sync_test_inner(event_loop, server, test_function): await server.close() return run_sync_test_inner + + +pytest_plugins = [ + 
"tests.fixtures.aws.fake_credentials", + "tests.fixtures.aws.fake_request", + "tests.fixtures.aws.fake_session", + "tests.fixtures.aws.fake_signer", +] diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/__init__.py b/tests/fixtures/aws/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py new file mode 100644 index 00000000..d8eac834 --- /dev/null +++ b/tests/fixtures/aws/fake_credentials.py @@ -0,0 +1,28 @@ +import pytest + + +class FakeCredentials(object): + def __init__( + self, access_key=None, secret_key=None, method=None, token=None, region=None + ): + self.region = region if region else "us-east-1a" + self.access_key = access_key if access_key else "fake-access-key" + self.secret_key = secret_key if secret_key else "fake-secret-key" + self.method = method if method else "shared-credentials-file" + self.token = token if token else "fake-token" + + +@pytest.fixture +def fake_credentials_factory(): + def _fake_credentials_factory( + access_key=None, secret_key=None, method=None, token=None, region=None + ): + return FakeCredentials( + access_key=access_key, + secret_key=secret_key, + method=method, + token=token, + region=region, + ) + + yield _fake_credentials_factory diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py new file mode 100644 index 00000000..615bc095 --- /dev/null +++ b/tests/fixtures/aws/fake_request.py @@ -0,0 +1,22 @@ +import pytest + + +class FakeRequest(object): + headers = None + + def __init__(self, request_props=None): + if not isinstance(request_props, dict): + return + self.method = request_props.get("method") + self.url = request_props.get("url") + self.headers = request_props.get("headers") + self.context = request_props.get("context") + self.body = request_props.get("body") + + +@pytest.fixture +def 
fake_request_factory(): + def _fake_request_factory(request_props=None): + return FakeRequest(request_props=request_props) + + yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py new file mode 100644 index 00000000..78e1511a --- /dev/null +++ b/tests/fixtures/aws/fake_session.py @@ -0,0 +1,24 @@ +import pytest + + +class FakeSession(object): + def __init__(self, credentials, region_name): + self._credentials = credentials + self._region_name = region_name + + def get_default_client_config(self): + return + + def get_credentials(self): + return self._credentials + + def _resolve_region_name(self, region_name, client_config): + return region_name if region_name else self._region_name + + +@pytest.fixture +def fake_session_factory(fake_credentials_factory): + def _fake_session_factory(credentials=fake_credentials_factory()): + return FakeSession(credentials=credentials, region_name="fake-region") + + yield _fake_session_factory diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py new file mode 100644 index 00000000..ff096745 --- /dev/null +++ b/tests/fixtures/aws/fake_signer.py @@ -0,0 +1,27 @@ +import pytest + + +@pytest.fixture +def fake_signer_factory(fake_request_factory): + def _fake_signer_factory(request=None): + if not request: + request = fake_request_factory() + return FakeSigner(request=request) + + yield _fake_signer_factory + + +class FakeSigner(object): + def __init__(self, request=None) -> None: + self.request = request + + def add_auth(self, request) -> None: + """ + A fake for getting a request object that + :return: + """ + request.headers = {"FakeAuthorization": "a", "FakeTime": "today"} + + def get_headers(self): + self.add_auth(self.request) + return self.request.headers diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py new file mode 100644 index 00000000..546e0e6f --- /dev/null +++ b/tests/test_appsync_auth.py @@ -0,0 +1,189 @@ +import 
pytest + +mock_transport_host = "appsyncapp.awsgateway.com.example.org" +mock_transport_url = f"https://{mock_transport_host}/graphql" + + +@pytest.mark.botocore +def test_appsync_init_with_minimal_args(fake_session_factory): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, session=fake_session_factory() + ) + assert isinstance(sample_transport.auth, AppSyncIAMAuthentication) + assert sample_transport.connect_timeout == 10 + assert sample_transport.close_timeout == 10 + assert sample_transport.ack_timeout == 10 + assert sample_transport.ssl is False + assert sample_transport.connect_args == {} + + +@pytest.mark.botocore +def test_appsync_init_with_no_credentials(caplog, fake_session_factory): + import botocore.exceptions + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + with pytest.raises(botocore.exceptions.NoCredentialsError): + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, session=fake_session_factory(credentials=None), + ) + assert sample_transport.auth is None + + expected_error = "Credentials not found" + + print(f"Captured log: {caplog.text}") + + assert expected_error in caplog.text + + +@pytest.mark.websockets +def test_appsync_init_with_jwt_auth(): + from gql.transport.appsync_auth import AppSyncJWTAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncJWTAuthentication(host=mock_transport_host, jwt="some-jwt") + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth + + assert auth.get_headers() == { + "host": mock_transport_host, + "Authorization": "some-jwt", + } + + +@pytest.mark.websockets +def test_appsync_init_with_apikey_auth(): + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from 
gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncApiKeyAuthentication(host=mock_transport_host, api_key="some-api-key") + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth + + assert auth.get_headers() == { + "host": mock_transport_host, + "x-api-key": "some-api-key", + } + + +@pytest.mark.botocore +def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): + import botocore.exceptions + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncIAMAuthentication( + host=mock_transport_host, session=fake_session_factory(credentials=None), + ) + with pytest.raises(botocore.exceptions.NoCredentialsError): + AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + + +@pytest.mark.botocore +def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncIAMAuthentication( + host=mock_transport_host, + credentials=fake_credentials_factory(), + region_name="us-east-1", + ) + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth + + +@pytest.mark.botocore +def test_appsync_init_with_iam_auth_and_no_region( + caplog, fake_credentials_factory, fake_session_factory +): + """ + + WARNING: this test will fail if: + - you have a default region set in ~/.aws/config + - you have the AWS_DEFAULT_REGION environment variable set + + """ + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.exceptions import NoRegionError + import logging + + caplog.set_level(logging.WARNING) + + with pytest.raises(NoRegionError): + session = fake_session_factory(credentials=fake_credentials_factory()) + 
session._region_name = None + session._credentials.region = None + transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=session) + + # prints the region name in case the test fails + print(f"Region found: {transport.auth._region_name}") + + print(f"Captured: {caplog.text}") + + expected_error = ( + "Region name not found. " + "It was not possible to detect your region either from the host " + "or from your default AWS configuration." + ) + + assert expected_error in caplog.text + + +@pytest.mark.botocore +def test_munge_url(fake_signer_factory, fake_request_factory): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + test_url = "https://appsync-api.aws.example.org/some-other-params" + + auth = AppSyncIAMAuthentication( + host=test_url, + signer=fake_signer_factory(), + request_creator=fake_request_factory, + ) + sample_transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) + + header_string = ( + "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" + "IiwiaG9zdCI6Imh0dHBzOi8vYXBwc3luYy1hcGkuYXdzLmV4YW1wbGUu" + "b3JnL3NvbWUtb3RoZXItcGFyYW1zIn0=" + ) + expected_url = ( + "wss://appsync-realtime-api.aws.example.org/" + f"some-other-params?header={header_string}&payload=e30=" + ) + assert sample_transport.url == expected_url + + +@pytest.mark.botocore +def test_munge_url_format( + fake_signer_factory, + fake_request_factory, + fake_credentials_factory, + fake_session_factory, +): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + + test_url = "https://appsync-api.aws.example.org/some-other-params" + + auth = AppSyncIAMAuthentication( + host=test_url, + signer=fake_signer_factory(), + session=fake_session_factory(), + request_creator=fake_request_factory, + credentials=fake_credentials_factory(), + ) + + header_string = ( + 
"eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" + "IiwiaG9zdCI6Imh0dHBzOi8vYXBwc3luYy1hcGkuYXdzLmV4YW1wbGUu" + "b3JnL3NvbWUtb3RoZXItcGFyYW1zIn0=" + ) + expected_url = ( + "wss://appsync-realtime-api.aws.example.org/" + f"some-other-params?header={header_string}&payload=e30=" + ) + assert auth.get_auth_url(test_url) == expected_url diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py new file mode 100644 index 00000000..1f787a68 --- /dev/null +++ b/tests/test_appsync_http.py @@ -0,0 +1,78 @@ +import json + +import pytest + +from gql import Client, gql + + +@pytest.mark.asyncio +@pytest.mark.aiohttp +@pytest.mark.botocore +async def test_appsync_iam_mutation( + event_loop, aiohttp_server, fake_credentials_factory +): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from urllib.parse import urlparse + + async def handler(request): + data = { + "createMessage": { + "id": "4b436192-aab2-460c-8bdf-4f2605eb63da", + "message": "Hello world!", + "createdAt": "2021-12-06T14:49:55.087Z", + } + } + payload = { + "data": data, + "extensions": {"received_headers": dict(request.headers)}, + } + + return web.Response( + text=json.dumps(payload, separators=(",", ":")), + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncIAMAuthentication( + host=host, credentials=fake_credentials_factory(), region_name="us-east-1", + ) + + sample_transport = AIOHTTPTransport(url=url, auth=auth) + + async with Client(transport=sample_transport) as session: + + query = gql( + """ +mutation createMessage($message: String!) 
{ + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + # Execute query asynchronously + execution_result = await session.execute(query, get_execution_result=True) + + result = execution_result.data + message = result["createMessage"]["message"] + + assert message == "Hello world!" + + sent_headers = execution_result.extensions["received_headers"] + + assert sent_headers["X-Amz-Security-Token"] == "fake-token" + assert sent_headers["Authorization"].startswith( + "AWS4-HMAC-SHA256 Credential=fake-access-key/" + ) diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py new file mode 100644 index 00000000..f510d4a7 --- /dev/null +++ b/tests/test_appsync_websockets.py @@ -0,0 +1,702 @@ +import asyncio +import json +from base64 import b64decode +from typing import List +from urllib import parse + +import pytest + +from gql import Client, gql + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +SEND_MESSAGE_DELAY = 20 * MS +NB_MESSAGES = 10 + +DUMMY_API_KEY = "da2-thisisadummyapikey01234567" +DUMMY_ACCESS_KEY_ID = "DUMMYACCESSKEYID0123" +DUMMY_ACCESS_KEY_ID_NOT_ALLOWED = "DUMMYACCESSKEYID!ALL" +DUMMY_ACCESS_KEY_IDS = [DUMMY_ACCESS_KEY_ID, DUMMY_ACCESS_KEY_ID_NOT_ALLOWED] +DUMMY_SECRET_ACCESS_KEY = "ThisIsADummySecret0123401234012340123401" +DUMMY_SECRET_SESSION_TOKEN = ( + "FwoREDACTEDzEREDACTED+YREDACTEDJLREDACTEDz2REDACTEDH5RE" + "DACTEDbVREDACTEDqwREDACTEDHJREDACTEDxFREDACTEDtMREDACTED5kREDACTEDSwREDACTED0BRED" + "ACTEDuDREDACTEDm4REDACTEDSBREDACTEDaoREDACTEDP2REDACTEDCBREDACTED0wREDACTEDmdREDA" + "CTEDyhREDACTEDSKREDACTEDYbREDACTEDfeREDACTED3UREDACTEDaKREDACTEDi1REDACTEDGEREDAC" + "TED4VREDACTEDjmREDACTEDYcREDACTEDkQREDACTEDyI=" +) +REGION_NAME = "eu-west-3" + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def realtime_appsync_server_factory( + 
keepalive=False, not_json_answer=False, error_without_id=False
+):
+    def verify_headers(headers, in_query=False):
+        """Returns an error or None if all is ok"""
+
+        if "x-api-key" in headers:
+            print("API KEY Authentication detected!")
+
+            if headers["x-api-key"] == DUMMY_API_KEY:
+                return None
+
+        elif "Authorization" in headers:
+            if "X-Amz-Security-Token" in headers:
+                with_token = True
+                print("IAM Authentication with token detected!")
+            else:
+                with_token = False
+                # no X-Amz-Security-Token header present
+                print("IAM Authentication without token detected!")
+
+            assert headers["accept"] == "application/json, text/javascript"
+            assert headers["content-encoding"] == "amz-1.0"
+            assert headers["content-type"] == "application/json; charset=UTF-8"
+            assert "X-Amz-Date" in headers
+
+            authorization_fields = headers["Authorization"].split(" ")
+
+            assert authorization_fields[0] == "AWS4-HMAC-SHA256"
+
+            credential_field = authorization_fields[1][:-1].split("=")
+            assert credential_field[0] == "Credential"
+            credential_content = credential_field[1].split("/")
+            assert credential_content[0] in DUMMY_ACCESS_KEY_IDS
+
+            if in_query:
+                if credential_content[0] == DUMMY_ACCESS_KEY_ID_NOT_ALLOWED:
+                    return {
+                        "errorType": "UnauthorizedException",
+                        "message": "Permission denied",
+                    }
+
+            # assert credential_content[1]== date
+            # assert credential_content[2]== region
+            assert credential_content[3] == "appsync"
+            assert credential_content[4] == "aws4_request"
+
+            signed_headers_field = authorization_fields[2][:-1].split("=")
+
+            assert signed_headers_field[0] == "SignedHeaders"
+            signed_headers = signed_headers_field[1].split(";")
+
+            assert "accept" in signed_headers
+            assert "content-encoding" in signed_headers
+            assert "content-type" in signed_headers
+            assert "host" in signed_headers
+            assert "x-amz-date" in signed_headers
+
+            if with_token:
+                assert "x-amz-security-token" in signed_headers
+
+            signature_field = authorization_fields[3].split("=")
+
+            assert
signature_field[0] == "Signature" + + return None + + return { + "errorType": "com.amazonaws.deepdish.graphql.auth#UnauthorizedException", + "message": "You are not authorized to make this call.", + "errorCode": 400, + } + + async def realtime_appsync_server_template(ws, path): + import websockets + + logged_messages.clear() + + try: + if not_json_answer: + await ws.send("Something not json") + return + + if error_without_id: + await ws.send( + json.dumps( + { + "type": "error", + "payload": { + "errors": [ + { + "errorType": "Error without id", + "message": ( + "Sometimes AppSync will send you " + "an error without an id" + ), + } + ] + }, + }, + separators=(",", ":"), + ) + ) + return + + print(f"path = {path}") + + path_base, parameters_str = path.split("?") + + assert path_base == "/graphql" + + parameters = parse.parse_qs(parameters_str) + + header_param = parameters["header"][0] + payload_param = parameters["payload"][0] + + assert payload_param == "e30=" + + headers = json.loads(b64decode(header_param).decode()) + + print("\nHeaders received in URL:") + for key, value in headers.items(): + print(f" {key}: {value}") + print("\n") + + error = verify_headers(headers) + + if error is not None: + await ws.send( + json.dumps( + {"payload": {"errors": [error]}, "type": "connection_error"}, + separators=(",", ":"), + ) + ) + return + + await WebSocketServerHelper.send_connection_ack( + ws, payload='{"connectionTimeoutMs":300000}' + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + + query_id = json_result["id"] + assert json_result["type"] == "start" + + payload = json_result["payload"] + + # With appsync, the data field is serialized to string + data_str = payload["data"] + extensions = payload["extensions"] + + data = json.loads(data_str) + + query = data["query"] + variables = data.get("variables", None) + operation_name = data.get("operationName", None) + print(f"Received query: {query}") + print(f"Received 
variables: {variables}") + print(f"Received operation_name: {operation_name}") + + authorization = extensions["authorization"] + print("\nHeaders received in the extensions of the query:") + for key, value in authorization.items(): + print(f" {key}: {value}") + print("\n") + + error = verify_headers(headers, in_query=True) + + if error is not None: + await ws.send( + json.dumps( + { + "id": str(query_id), + "type": "error", + "payload": {"errors": [error]}, + }, + separators=(",", ":"), + ) + ) + return + + await ws.send( + json.dumps( + {"id": str(query_id), "type": "start_ack"}, separators=(",", ":") + ) + ) + + async def send_message_coro(): + print(" Server: send message task started") + try: + for number in range(NB_MESSAGES): + payload = { + "data": { + "onCreateMessage": {"message": f"Hello world {number}!"} + } + } + + if operation_name or variables: + + payload["extensions"] = {} + + if operation_name: + payload["extensions"]["operation_name"] = operation_name + if variables: + payload["extensions"]["variables"] = variables + + await ws.send( + json.dumps( + { + "id": str(query_id), + "type": "data", + "payload": payload, + }, + separators=(",", ":"), + ) + ) + await asyncio.sleep(SEND_MESSAGE_DELAY) + finally: + print(" Server: send message task ended") + + print(" Server: starting send message task") + send_message_task = asyncio.ensure_future(send_message_coro()) + + async def keepalive_coro(): + while True: + await asyncio.sleep(5 * MS) + try: + await WebSocketServerHelper.send_keepalive(ws) + except websockets.exceptions.ConnectionClosed: + break + + if keepalive: + print(" Server: starting keepalive task") + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + print(" Server: receiving task started") + try: + nonlocal send_message_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + finally: + print(" Server: receiving 
task ended") + if keepalive: + keepalive_task.cancel() + + print(" Server: starting receiving task") + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + print( + " Server: waiting for sending message task to complete" + ) + await send_message_task + except asyncio.CancelledError: + print(" Server: Now sending message task is cancelled") + + print(" Server: sending complete message") + await WebSocketServerHelper.send_complete(ws, query_id) + + if keepalive: + print(" Server: cancelling keepalive task") + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print(" Server: Now keepalive task is cancelled") + + print(" Server: waiting for client to close the connection") + try: + await asyncio.wait_for(receiving_task, 1000 * MS) + except asyncio.TimeoutError: + pass + + print(" Server: cancelling receiving task") + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print(" Server: Now receiving task is cancelled") + + except websockets.exceptions.ConnectionClosedOK: + pass + except AssertionError as e: + print(f"\n Server: Assertion failed: {e!s}\n") + except Exception as e: + print(f"\n Server: Exception received: {e!s}\n") + finally: + print(" Server: waiting for websocket connection to close") + await ws.wait_closed() + print(" Server: connection closed") + + return realtime_appsync_server_template + + +async def realtime_appsync_server(ws, path): + + server = realtime_appsync_server_factory() + await server(ws, path) + + +async def realtime_appsync_server_keepalive(ws, path): + + server = realtime_appsync_server_factory(keepalive=True) + await server(ws, path) + + +async def realtime_appsync_server_not_json_answer(ws, path): + + server = realtime_appsync_server_factory(not_json_answer=True) + await server(ws, path) + + +async def realtime_appsync_server_error_without_id(ws, path): + + server = realtime_appsync_server_factory(error_without_id=True) + await server(ws, 
path) + + +on_create_message_subscription_str = """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + + +async def default_transport_test(transport): + client = Client(transport=transport) + + expected_messages = [f"Hello world {number}!" for number in range(NB_MESSAGES)] + received_messages = [] + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + async for result in session.subscribe(subscription): + + message = result["onCreateMessage"]["message"] + print(f"Message received: '{message}'") + + received_messages.append(message) + + assert expected_messages == received_messages + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) +async def test_appsync_subscription_api_key(event_loop, server): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport( + url=url, auth=auth, keep_alive_timeout=(5 * SEND_MESSAGE_DELAY) + ) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_with_token(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, + secret_key=DUMMY_SECRET_ACCESS_KEY, + token=DUMMY_SECRET_SESSION_TOKEN, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, 
credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_without_token(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_execute_method_not_allowed(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + async with client as session: + query = gql( + """ +mutation createMessage($message: String!) 
{ + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + with pytest.raises(AssertionError) as exc_info: + await session.execute(query, variable_values=variable_values) + + assert ( + "execute method is not allowed for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." + ) in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.botocore +async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication( + host="something", credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url="https://something", auth=auth) + + with pytest.raises(AssertionError) as exc_info: + Client(transport=transport, fetch_schema_from_transport=True) + + assert ( + "fetch_schema_from_transport=True is not allowed for AppSyncWebsocketsTransport" + " because only subscriptions are allowed on the realtime endpoint." 
+ ) in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_api_key_unauthorized(event_loop, server): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportServerError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key="invalid") + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportServerError) as exc_info: + async with client as _: + pass + + assert "You are not authorized to make this call." in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_not_allowed(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportQueryError + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID_NOT_ALLOWED, + secret_key=DUMMY_SECRET_ACCESS_KEY, + token=DUMMY_SECRET_SESSION_TOKEN, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + with pytest.raises(TransportQueryError) as exc_info: + + async for result in session.subscribe(subscription): + pass + + assert "Permission denied" in 
str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [realtime_appsync_server_not_json_answer], indirect=True +) +async def test_appsync_subscription_server_sending_a_not_json_answer( + event_loop, server +): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportProtocolError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportProtocolError) as exc_info: + async with client as _: + pass + + assert "Server did not return a GraphQL result: Something not json" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [realtime_appsync_server_error_without_id], indirect=True +) +async def test_appsync_subscription_server_sending_an_error_without_an_id( + event_loop, server +): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportServerError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportServerError) as exc_info: + async with client as _: + pass + + assert "Sometimes AppSync will send you an error without an id" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) +async def test_appsync_subscription_variable_values_and_operation_name( + event_loop, server +): + + from 
gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport( + url=url, auth=auth, keep_alive_timeout=(5 * SEND_MESSAGE_DELAY) + ) + + client = Client(transport=transport) + + expected_messages = [f"Hello world {number}!" for number in range(NB_MESSAGES)] + received_messages = [] + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + async for execution_result in session.subscribe( + subscription, + operation_name="onCreateMessage", + variable_values={"key1": "val1"}, + get_execution_result=True, + ): + + result = execution_result.data + message = result["onCreateMessage"]["message"] + print(f"Message received: '{message}'") + + received_messages.append(message) + + print(f"extensions received: {execution_result.extensions}") + + assert execution_result.extensions["operation_name"] == "onCreateMessage" + variables = execution_result.extensions["variables"] + assert variables["key1"] == "val1" + + assert expected_messages == received_messages diff --git a/tox.ini b/tox.ini index 2699744c..e75b8fac 100644 --- a/tox.ini +++ b/tox.ini @@ -3,9 +3,6 @@ envlist = black,flake8,import-order,mypy,manifest, py{36,37,38,39,310,py3} -[pytest] -markers = asyncio - [gh-actions] python = 3.6: py36 From e2d208518effd4386db9cbed6da85da83c240d20 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 10 Dec 2021 08:17:05 +0100 Subject: [PATCH 050/239] README.md remove unnecessary line --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 0962c80e..08484c2f 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,6 @@ The complete documentation for GQL can be found at ## Features -The main features of GQL are: - * Execute GraphQL queries 
using [different protocols](https://gql.readthedocs.io/en/latest/transports/index.html): * http * websockets: From ec37cb0e48ca7fb4272651e6a1b4fda1ef3dd220 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 10 Dec 2021 08:39:55 +0100 Subject: [PATCH 051/239] PhoenixChannelWebsocketsTransport inherits WebsocketsBase (#280) --- gql/transport/phoenix_channel_websockets.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index b750c39c..b8226234 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -11,7 +11,7 @@ TransportQueryError, TransportServerError, ) -from .websockets import WebsocketsTransport +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -24,7 +24,7 @@ def __init__(self, query_id: int) -> None: self.unsubscribe_id: Optional[int] = None -class PhoenixChannelWebsocketsTransport(WebsocketsTransport): +class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): """The PhoenixChannelWebsocketsTransport is an async transport which allows you to execute queries and subscriptions against an `Absinthe`_ backend using the `Phoenix`_ framework `channels`_. @@ -54,7 +54,7 @@ def __init__( self.subscriptions: Dict[str, Subscription] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) - async def _send_init_message_and_wait_ack(self) -> None: + async def _initialize(self) -> None: """Join the specified channel and wait for the connection ACK. If the answer is not a connection_ack message, we will return an Exception. 
@@ -131,6 +131,9 @@ async def _send_stop_message(self, query_id: int) -> None: await self._send(unsubscribe_message) + async def _stop_listener(self, query_id: int) -> None: + await self._send_stop_message(query_id) + async def _send_connection_terminate_message(self) -> None: """Send a phx_leave message to disconnect from the provided channel.""" @@ -148,6 +151,9 @@ async def _send_connection_terminate_message(self) -> None: await self._send(connection_terminate_message) + async def _connection_terminate(self): + await self._send_connection_terminate_message() + async def _send_query( self, document: DocumentNode, From 7cc1002dfaec196a2e49807d17abf727bb9cfe81 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 10 Dec 2021 08:45:56 +0100 Subject: [PATCH 052/239] gql-cli add --transport argument (#281) --- docs/transports/appsync.rst | 28 ++++++ gql/cli.py | 140 ++++++++++++++++++++++++++---- gql/transport/appsync_auth.py | 8 +- tests/test_aiohttp.py | 6 +- tests/test_cli.py | 157 +++++++++++++++++++++++++++++++++- 5 files changed, 316 insertions(+), 23 deletions(-) diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index 7ceb7480..e7e413cd 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -154,3 +154,31 @@ a normal http session and reuse the authentication classes to create the headers Full example with API key authentication from environment variables: .. literalinclude:: ../code_examples/appsync/mutation_api_key.py + +From the command line +--------------------- + +Using :ref:`gql-cli `, it is possible to execute GraphQL queries and subscriptions +from the command line on an AppSync endpoint. + +- For queries and mutations, use the :code:`--transport appsync_http` argument:: + + # Put the request in a file + $ echo 'mutation createMessage($message: String!) 
{ + createMessage(input: {message: $message}) { + id + message + createdAt + } + }' > mutation.graphql + + # Execute the request using gql-cli with --transport appsync_http + $ cat mutation.graphql | gql-cli $AWS_GRAPHQL_API_ENDPOINT --transport appsync_http -V message:"Hello world!" + +- For subscriptions, use the :code:`--transport appsync_websockets` argument:: + + echo "subscription{onCreateMessage{message}}" | gql-cli $AWS_GRAPHQL_API_ENDPOINT --transport appsync_websockets + +- You can also get the full GraphQL schema from the backend from introspection:: + + $ gql-cli $AWS_GRAPHQL_API_ENDPOINT --transport appsync_http --print-schema > schema.graphql diff --git a/gql/cli.py b/gql/cli.py index 917a4268..1e248081 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -2,7 +2,7 @@ import logging import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter -from typing import Any, Dict +from typing import Any, Dict, Optional from graphql import GraphQLError, print_schema from yarl import URL @@ -101,6 +101,43 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: action="store_true", dest="print_schema", ) + parser.add_argument( + "--transport", + default="auto", + choices=[ + "auto", + "aiohttp", + "phoenix", + "websockets", + "appsync_http", + "appsync_websockets", + ], + help=( + "select the transport. 'auto' by default: " + "aiohttp or websockets depending on url scheme" + ), + dest="transport", + ) + + appsync_description = """ +By default, for an AppSync backend, the IAM authentication is chosen. 
+ +If you want API key or JWT authentication, you can provide one of the +following arguments:""" + + appsync_group = parser.add_argument_group( + "AWS AppSync options", description=appsync_description + ) + + appsync_auth_group = appsync_group.add_mutually_exclusive_group() + + appsync_auth_group.add_argument( + "--api-key", help="Provide an API key for authentication", dest="api_key", + ) + + appsync_auth_group.add_argument( + "--jwt", help="Provide an JSON Web token for authentication", dest="jwt", + ) return parser @@ -191,7 +228,20 @@ def get_execute_args(args: Namespace) -> Dict[str, Any]: return execute_args -def get_transport(args: Namespace) -> AsyncTransport: +def autodetect_transport(url: URL) -> str: + """Detects which transport should be used depending on url.""" + + if url.scheme in ["ws", "wss"]: + transport_name = "websockets" + + else: + assert url.scheme in ["http", "https"] + transport_name = "aiohttp" + + return transport_name + + +def get_transport(args: Namespace) -> Optional[AsyncTransport]: """Instantiate a transport from the parsed command line arguments :param args: parsed command line arguments @@ -199,28 +249,85 @@ def get_transport(args: Namespace) -> AsyncTransport: # Get the url scheme from server parameter url = URL(args.server) - scheme = url.scheme + + # Validate scheme + if url.scheme not in ["http", "https", "ws", "wss"]: + raise ValueError("URL protocol should be one of: http, https, ws, wss") # Get extra transport parameters from command line arguments # (headers) transport_args = get_transport_args(args) - # Instantiate transport depending on url scheme - transport: AsyncTransport - if scheme in ["ws", "wss"]: - from gql.transport.websockets import WebsocketsTransport + # Either use the requested transport or autodetect it + if args.transport == "auto": + transport_name = autodetect_transport(url) + else: + transport_name = args.transport - transport = WebsocketsTransport( - url=args.server, ssl=(scheme == "wss"), 
**transport_args - ) - elif scheme in ["http", "https"]: + # Import the correct transport class depending on the transport name + if transport_name == "aiohttp": from gql.transport.aiohttp import AIOHTTPTransport - transport = AIOHTTPTransport(url=args.server, **transport_args) + return AIOHTTPTransport(url=args.server, **transport_args) + + elif transport_name == "phoenix": + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + return PhoenixChannelWebsocketsTransport(url=args.server, **transport_args) + + elif transport_name == "websockets": + from gql.transport.websockets import WebsocketsTransport + + transport_args["ssl"] = url.scheme == "wss" + + return WebsocketsTransport(url=args.server, **transport_args) + else: - raise ValueError("URL protocol should be one of: http, https, ws, wss") - return transport + from gql.transport.appsync_auth import AppSyncAuthentication + + assert transport_name in ["appsync_http", "appsync_websockets"] + assert url.host is not None + + auth: AppSyncAuthentication + + if args.api_key: + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + + auth = AppSyncApiKeyAuthentication(host=url.host, api_key=args.api_key) + + elif args.jwt: + from gql.transport.appsync_auth import AppSyncJWTAuthentication + + auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt) + + else: + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from botocore.exceptions import NoRegionError + + try: + auth = AppSyncIAMAuthentication(host=url.host) + except NoRegionError: + # A warning message has been printed in the console + return None + + transport_args["auth"] = auth + + if transport_name == "appsync_http": + from gql.transport.aiohttp import AIOHTTPTransport + + return AIOHTTPTransport(url=args.server, **transport_args) + + else: + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + try: + return AppSyncWebsocketsTransport(url=args.server, 
**transport_args) + except Exception: + # This is for the NoCredentialsError but we cannot import it here + return None async def main(args: Namespace) -> int: @@ -238,13 +345,16 @@ async def main(args: Namespace) -> int: # Instantiate transport from command line arguments transport = get_transport(args) + if transport is None: + return 1 + # Get extra execute parameters from command line arguments # (variables, operation_name) execute_args = get_execute_args(args) except ValueError as e: print(f"Error: {e}", file=sys.stderr) - sys.exit(1) + return 1 # By default, the exit_code is 0 (everything is ok) exit_code = 0 diff --git a/gql/transport/appsync_auth.py b/gql/transport/appsync_auth.py index 04c07c10..5ce93d4e 100644 --- a/gql/transport/appsync_auth.py +++ b/gql/transport/appsync_auth.py @@ -54,7 +54,7 @@ def __init__(self, host: str, api_key: str) -> None: XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com :param api_key: the API key """ - self._host = host + self._host = host.replace("appsync-realtime-api", "appsync-api") self.api_key = api_key def get_headers( @@ -77,7 +77,7 @@ def __init__(self, host: str, jwt: str) -> None: XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com :param jwt: the JWT Access Token """ - self._host = host + self._host = host.replace("appsync-realtime-api", "appsync-api") self.jwt = jwt def get_headers( @@ -120,7 +120,7 @@ def __init__( from botocore.awsrequest import create_request_object from botocore.session import get_session - self._host = host + self._host = host.replace("appsync-realtime-api", "appsync-api") self._session = session if session else get_session() self._credentials = ( credentials if credentials else self._session.get_credentials() @@ -201,7 +201,7 @@ def get_headers( self._signer.add_auth(request) except NoCredentialsError: log.warning( - "Credentials not found. " + "Credentials not found for the IAM auth. 
" "Do you have default AWS credentials configured?", ) raise diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 6dbe46ae..682cea0d 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -994,9 +994,9 @@ async def handler(request): # via the standard input monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) - # Checking that sys.exit() is called - with pytest.raises(SystemExit): - await main(args) + # Check that the exit_code is an error + exit_code = await main(args) + assert exit_code == 1 # Check that the error has been printed on stdout captured = capsys.readouterr() diff --git a/tests/test_cli.py b/tests/test_cli.py index 3329bb5d..8df47a63 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,7 +2,13 @@ import pytest -from gql.cli import get_execute_args, get_parser, get_transport, get_transport_args +from gql.cli import ( + get_execute_args, + get_parser, + get_transport, + get_transport_args, + main, +) @pytest.fixture @@ -177,6 +183,155 @@ def test_cli_get_transport_websockets(parser, url): assert isinstance(transport, WebsocketsTransport) +@pytest.mark.websockets +@pytest.mark.parametrize( + "url", ["ws://your_server.com", "wss://your_server.com"], +) +def test_cli_get_transport_phoenix(parser, url): + + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + args = parser.parse_args([url, "--transport", "phoenix"]) + + transport = get_transport(args) + + assert isinstance(transport, PhoenixChannelWebsocketsTransport) + + +@pytest.mark.websockets +@pytest.mark.botocore +@pytest.mark.parametrize( + "url", + [ + "wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql", + "wss://noregion.amazonaws.com/graphql", + ], +) +def test_cli_get_transport_appsync_websockets_iam(parser, url): + + args = parser.parse_args([url, "--transport", "appsync_websockets"]) + + transport = get_transport(args) + + # In the tests, the AWS Appsync credentials are not set + # So the transport is None 
+ assert transport is None + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.botocore +@pytest.mark.parametrize( + "url", ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], +) +async def test_cli_main_appsync_websockets_iam(parser, url): + + args = parser.parse_args([url, "--transport", "appsync_websockets"]) + + exit_code = await main(args) + + # In the tests, the AWS Appsync credentials are not set + # So the transport is None and the main returns + # an exit_code of 1 + assert exit_code == 1 + + +@pytest.mark.websockets +@pytest.mark.parametrize( + "url", ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], +) +def test_cli_get_transport_appsync_websockets_api_key(parser, url): + + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + + args = parser.parse_args( + [url, "--transport", "appsync_websockets", "--api-key", "test-api-key"] + ) + + transport = get_transport(args) + + assert isinstance(transport, AppSyncWebsocketsTransport) + assert isinstance(transport.auth, AppSyncApiKeyAuthentication) + assert transport.auth.api_key == "test-api-key" + + +@pytest.mark.websockets +@pytest.mark.parametrize( + "url", ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], +) +def test_cli_get_transport_appsync_websockets_jwt(parser, url): + + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.appsync_auth import AppSyncJWTAuthentication + + args = parser.parse_args( + [url, "--transport", "appsync_websockets", "--jwt", "test-jwt"] + ) + + transport = get_transport(args) + + assert isinstance(transport, AppSyncWebsocketsTransport) + assert isinstance(transport.auth, AppSyncJWTAuthentication) + assert transport.auth.jwt == "test-jwt" + + +@pytest.mark.aiohttp +@pytest.mark.botocore +@pytest.mark.parametrize( + "url", 
["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], +) +def test_cli_get_transport_appsync_http_iam(parser, url): + + from gql.transport.aiohttp import AIOHTTPTransport + + args = parser.parse_args([url, "--transport", "appsync_http"]) + + transport = get_transport(args) + + assert isinstance(transport, AIOHTTPTransport) + + +@pytest.mark.aiohttp +@pytest.mark.parametrize( + "url", ["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], +) +def test_cli_get_transport_appsync_http_api_key(parser, url): + + from gql.transport.aiohttp import AIOHTTPTransport + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + + args = parser.parse_args( + [url, "--transport", "appsync_http", "--api-key", "test-api-key"] + ) + + transport = get_transport(args) + + assert isinstance(transport, AIOHTTPTransport) + assert isinstance(transport.auth, AppSyncApiKeyAuthentication) + assert transport.auth.api_key == "test-api-key" + + +@pytest.mark.aiohttp +@pytest.mark.parametrize( + "url", ["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], +) +def test_cli_get_transport_appsync_http_jwt(parser, url): + + from gql.transport.aiohttp import AIOHTTPTransport + from gql.transport.appsync_auth import AppSyncJWTAuthentication + + args = parser.parse_args([url, "--transport", "appsync_http", "--jwt", "test-jwt"]) + + transport = get_transport(args) + + assert isinstance(transport, AIOHTTPTransport) + assert isinstance(transport.auth, AppSyncJWTAuthentication) + assert transport.auth.jwt == "test-jwt" + + def test_cli_get_transport_no_protocol(parser): args = parser.parse_args(["your_server.com"]) From c64037f2c5b52b214f8daef25f92814bc6622c24 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 10 Dec 2021 09:47:43 +0100 Subject: [PATCH 053/239] Bump version number to 3.0.0rc0 --- gql/__version__.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 3996ce87..ce971bbf 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.0.0b1" +__version__ = "3.0.0rc0" From 9b131a1f2b3f07aeb03cc597bbc337947cd5677b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 16 Dec 2021 09:26:29 +0100 Subject: [PATCH 054/239] Update issue templates --- .github/ISSUE_TEMPLATE/bug_report.md | 30 +++++++++++++++++++++++ .github/ISSUE_TEMPLATE/feature_request.md | 10 ++++++++ 2 files changed, 40 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..f89a2238 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,30 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Common problems** +- If you receive a TransportQueryError, it means the error is coming from the backend (See [Error Handling](https://gql.readthedocs.io/en/latest/advanced/error_handling.html)) and has probably nothing to do with gql +- If you use IPython (Jupyter, Spyder), then [you need to use the async version](https://gql.readthedocs.io/en/latest/async/async_usage.html#ipython) +- Before sending a bug report, please consider [activating debug logs](https://gql.readthedocs.io/en/latest/advanced/logging.html) to see the messages exchanged between the client and the backend + +**Describe the bug** +A clear and concise description of what the bug is. +Please provide a full stack trace if you have one. +If you can, please provide the backend URL, the GraphQL schema, the code you used. 
+ +**To Reproduce** +Steps to reproduce the behavior: + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**System info (please complete the following information):** + - OS: + - Python version: + - gql version: + - graphql-core version: diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..e46a4c01 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,10 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + + From 6b51742611b122d7aef1f84645b9464181feac30 Mon Sep 17 00:00:00 2001 From: Johnathan August Fisher Date: Fri, 17 Dec 2021 14:45:12 -0500 Subject: [PATCH 055/239] Update type hint to align with docs and usage. (#285) --- gql/transport/requests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 31b52809..32e57478 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -40,7 +40,7 @@ def __init__( auth: Optional[AuthBase] = None, use_json: bool = True, timeout: Optional[int] = None, - verify: bool = True, + verify: Union[bool, str] = True, retries: int = 0, method: str = "POST", **kwargs: Any, From 4d16502510188e2fc32bd5eed52bdbb08a72e4a1 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 5 Jan 2022 10:11:48 +0100 Subject: [PATCH 056/239] Mark online tests as online (#288) --- tests/test_client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index 1521eac7..c8df40ee 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -85,6 +85,7 @@ def test_no_schema_exception(): ) +@pytest.mark.online @pytest.mark.requests def test_execute_result_error(): @@ -111,6 +112,7 @@ def test_execute_result_error(): assert 'Cannot query field "id" on type "Continent".' 
in str(exc_info.value) +@pytest.mark.online @pytest.mark.requests def test_http_transport_raise_for_status_error(http_transport_query): from gql.transport.requests import RequestsHTTPTransport @@ -127,6 +129,7 @@ def test_http_transport_raise_for_status_error(http_transport_query): assert "400 Client Error: Bad Request for url" in str(exc_info.value) +@pytest.mark.online @pytest.mark.requests def test_http_transport_verify_error(http_transport_query): from gql.transport.requests import RequestsHTTPTransport @@ -142,6 +145,7 @@ def test_http_transport_verify_error(http_transport_query): assert "Unverified HTTPS request is being made to host" in str(record[0].message) +@pytest.mark.online @pytest.mark.requests def test_http_transport_specify_method_valid(http_transport_query): from gql.transport.requests import RequestsHTTPTransport @@ -155,6 +159,7 @@ def test_http_transport_specify_method_valid(http_transport_query): assert result is not None +@pytest.mark.online @pytest.mark.requests def test_http_transport_specify_method_invalid(http_transport_query): from gql.transport.requests import RequestsHTTPTransport From 5440c6c14b74f0414551e0ebebeed187bdf4ae5a Mon Sep 17 00:00:00 2001 From: Connor Brinton Date: Wed, 5 Jan 2022 12:11:50 -0500 Subject: [PATCH 057/239] Add type overloads for get_execution_result (#287) --- gql/client.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 2 deletions(-) diff --git a/gql/client.py b/gql/client.py index e10f7509..91cbcde6 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,6 +1,7 @@ import asyncio +import sys import warnings -from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union +from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, overload from graphql import ( DocumentNode, @@ -20,6 +21,16 @@ from .utilities import parse_result as parse_result_fn from .utilities import serialize_variable_values +""" +Load the appropriate instance of the Literal 
type +Note: we cannot use try: except ImportError because of the following mypy issue: +https://github.com/python/mypy/issues/8520 +""" +if sys.version_info[:2] >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # pragma: no cover + class Client: """The Client class is the main entrypoint to execute GraphQL requests @@ -362,6 +373,34 @@ def _execute( return result + @overload + def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + get_execution_result: Literal[False] = ..., + **kwargs, + ) -> Dict[str, Any]: + ... # pragma: no cover + + @overload + def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + get_execution_result: Literal[True], + **kwargs, + ) -> ExecutionResult: + ... # pragma: no cover + def execute( self, document: DocumentNode, @@ -525,6 +564,34 @@ async def _subscribe( finally: await inner_generator.aclose() + @overload + def subscribe( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + get_execution_result: Literal[False] = ..., + **kwargs, + ) -> AsyncGenerator[Dict[str, Any], None]: + ... # pragma: no cover + + @overload + def subscribe( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + get_execution_result: Literal[True], + **kwargs, + ) -> AsyncGenerator[ExecutionResult, None]: + ... 
# pragma: no cover + async def subscribe( self, document: DocumentNode, @@ -535,7 +602,9 @@ async def subscribe( parse_result: Optional[bool] = None, get_execution_result: bool = False, **kwargs, - ) -> AsyncGenerator[Union[Dict[str, Any], ExecutionResult], None]: + ) -> Union[ + AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] + ]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport. @@ -653,6 +722,34 @@ async def _execute( return result + @overload + async def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + get_execution_result: Literal[False] = ..., + **kwargs, + ) -> Dict[str, Any]: + ... # pragma: no cover + + @overload + async def execute( + self, + document: DocumentNode, + *args, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + get_execution_result: Literal[True], + **kwargs, + ) -> ExecutionResult: + ... # pragma: no cover + async def execute( self, document: DocumentNode, From bf16f1e2369d80b7877adccd55d53bfd9c686351 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 16 Jan 2022 14:14:07 +0100 Subject: [PATCH 058/239] Bump graphql-core to 3.2.0 (#283) * Fix error 'graphql.error.graphql_error.GraphQLError: Names must only contain [_a-zA-Z0-9] but 'meta-field' does not.' 
* Don't use format_error * Put the is_finite method in the test file * Fix StarWars schema Fixes 'Support for returning GraphQLObjectType from resolve_type was removed in GraphQL-core 3.2, please return type name instead' * fix print_ast removing last newline * Fix error: AttributeError: 'ParseResultVisitor' object has no attribute 'enter_leave_map' * Rename specifiedByUrl to specifiedByURL See https://github.com/graphql/graphql-js/issues/3156 * Bump graphql-core version to stable 3.2 This new version of GraphQL-core replaces the FrozenLists in AST nodes with tuples, so we need to make the appropriate changes in dsl.py. * Fix Introspection directive typing --- gql/dsl.py | 65 ++++----- gql/utilities/build_client_schema.py | 29 ++-- gql/utilities/get_introspection_query_ast.py | 2 +- gql/utilities/parse_result.py | 2 + setup.py | 2 +- tests/custom_scalars/test_money.py | 10 +- tests/fixtures/vcr_cassettes/queries.yaml | 10 +- tests/starwars/fixtures.py | 58 ++++++-- tests/starwars/schema.py | 139 ++++++++++--------- tests/starwars/test_dsl.py | 24 ++-- tests/starwars/test_query.py | 6 +- tests/test_aiohttp.py | 6 +- tests/test_requests.py | 6 +- 13 files changed, 195 insertions(+), 164 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 0cadef8b..6a2e0718 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -6,7 +6,7 @@ import re from abc import ABC, abstractmethod from math import isfinite -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union, cast from graphql import ( ArgumentNode, @@ -61,7 +61,7 @@ is_wrapping_type, print_ast, ) -from graphql.pyutils import FrozenList, inspect +from graphql.pyutils import inspect from .utils import to_camel_case @@ -90,17 +90,17 @@ def ast_from_serialized_value_untyped(serialized: Any) -> Optional[ValueNode]: (key, ast_from_serialized_value_untyped(value)) for key, value in 
serialized.items() ) - field_nodes = ( + field_nodes = tuple( ObjectFieldNode(name=NameNode(value=field_name), value=field_value) for field_name, field_value in field_items if field_value ) - return ObjectValueNode(fields=FrozenList(field_nodes)) + return ObjectValueNode(fields=field_nodes) if isinstance(serialized, Iterable) and not isinstance(serialized, str): maybe_nodes = (ast_from_serialized_value_untyped(item) for item in serialized) - nodes = filter(None, maybe_nodes) - return ListValueNode(values=FrozenList(nodes)) + nodes = tuple(node for node in maybe_nodes if node) + return ListValueNode(values=nodes) if isinstance(serialized, bool): return BooleanValueNode(value=serialized) @@ -158,8 +158,8 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: item_type = type_.of_type if isinstance(value, Iterable) and not isinstance(value, str): maybe_value_nodes = (ast_from_value(item, item_type) for item in value) - value_nodes = filter(None, maybe_value_nodes) - return ListValueNode(values=FrozenList(value_nodes)) + value_nodes = tuple(node for node in maybe_value_nodes if node) + return ListValueNode(values=value_nodes) return ast_from_value(value, item_type) # Populate the fields of the input object by creating ASTs from each value in the @@ -173,12 +173,12 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: for field_name, field in type_.fields.items() if field_name in value ) - field_nodes = ( + field_nodes = tuple( ObjectFieldNode(name=NameNode(value=field_name), value=field_value) for field_name, field_value in field_items if field_value ) - return ObjectValueNode(fields=FrozenList(field_nodes)) + return ObjectValueNode(fields=field_nodes) if is_leaf_type(type_): # Since value is an internally represented value, it must be serialized to an @@ -314,7 +314,7 @@ def __init__( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", ): """:meta private:""" - self.selection_set = 
SelectionSetNode(selections=FrozenList([])) + self.selection_set = SelectionSetNode(selections=()) if fields or fields_with_alias: self.select(*fields, **fields_with_alias) @@ -355,14 +355,12 @@ def select( raise GraphQLError(f"Invalid field for {self!r}: {field!r}") # Get a list of AST Nodes for each added field - added_selections: List[ - Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] - ] = [field.ast_field for field in added_fields] + added_selections: Tuple[ + Union[FieldNode, InlineFragmentNode, FragmentSpreadNode], ... + ] = tuple(field.ast_field for field in added_fields) # Update the current selection list with new selections - self.selection_set.selections = FrozenList( - self.selection_set.selections + added_selections - ) + self.selection_set.selections = self.selection_set.selections + added_selections log.debug(f"Added fields: {added_fields} in {self!r}") @@ -470,9 +468,7 @@ def executable_ast(self) -> OperationDefinitionNode: return OperationDefinitionNode( operation=OperationType(self.operation_type), selection_set=self.selection_set, - variable_definitions=FrozenList( - self.variable_definitions.get_ast_definitions() - ), + variable_definitions=self.variable_definitions.get_ast_definitions(), **({"name": NameNode(value=self.name)} if self.name else {}), ) @@ -548,19 +544,19 @@ def __getattr__(self, name: str) -> "DSLVariable": self.variables[name] = DSLVariable(name) return self.variables[name] - def get_ast_definitions(self) -> List[VariableDefinitionNode]: + def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: """ :meta private: Return a list of VariableDefinitionNodes for each variable with a type """ - return [ + return tuple( VariableDefinitionNode( type=var.type, variable=var.ast_variable, default_value=None, ) for var in self.variables.values() if var.type is not None # only variables used - ] + ) class DSLType: @@ -770,7 +766,7 @@ def __init__( """ self.parent_type = parent_type self.field = field - self.ast_field = 
FieldNode(name=NameNode(value=name), arguments=FrozenList()) + self.ast_field = FieldNode(name=NameNode(value=name), arguments=()) log.debug(f"Creating {self!r}") @@ -803,15 +799,12 @@ def args(self, **kwargs) -> "DSLField": assert self.ast_field.arguments is not None - self.ast_field.arguments = FrozenList( - self.ast_field.arguments - + [ - ArgumentNode( - name=NameNode(value=name), - value=ast_from_value(value, self._get_argument(name).type), - ) - for name, value in kwargs.items() - ] + self.ast_field.arguments = self.ast_field.arguments + tuple( + ArgumentNode( + name=NameNode(value=name), + value=ast_from_value(value, self._get_argument(name).type), + ) + for name, value in kwargs.items() ) log.debug(f"Added arguments {kwargs} in field {self!r})") @@ -856,7 +849,7 @@ class DSLMetaField(DSLField): """ meta_type = GraphQLObjectType( - "meta-field", + "meta_field", fields={ "__typename": GraphQLField(GraphQLString), "__schema": GraphQLField( @@ -1022,9 +1015,7 @@ def executable_ast(self) -> FragmentDefinitionNode: return FragmentDefinitionNode( type_condition=NamedTypeNode(name=NameNode(value=self._type.name)), selection_set=self.selection_set, - variable_definitions=FrozenList( - self.variable_definitions.get_ast_definitions() - ), + variable_definitions=self.variable_definitions.get_ast_definitions(), name=NameNode(value=self.name), ) diff --git a/gql/utilities/build_client_schema.py b/gql/utilities/build_client_schema.py index 78fb7586..048ed80d 100644 --- a/gql/utilities/build_client_schema.py +++ b/gql/utilities/build_client_schema.py @@ -1,19 +1,25 @@ -from typing import Dict - -from graphql import GraphQLSchema +from graphql import GraphQLSchema, IntrospectionQuery from graphql import build_client_schema as build_client_schema_orig from graphql.pyutils import inspect +from graphql.utilities.get_introspection_query import ( + DirectiveLocation, + IntrospectionDirective, +) __all__ = ["build_client_schema"] -INCLUDE_DIRECTIVE_JSON = { 
+INCLUDE_DIRECTIVE_JSON: IntrospectionDirective = { "name": "include", "description": ( "Directs the executor to include this field or fragment " "only when the `if` argument is true." ), - "locations": ["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"], + "locations": [ + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + ], "args": [ { "name": "if", @@ -28,13 +34,17 @@ ], } -SKIP_DIRECTIVE_JSON = { +SKIP_DIRECTIVE_JSON: IntrospectionDirective = { "name": "skip", "description": ( "Directs the executor to skip this field or fragment " "when the `if` argument is true." ), - "locations": ["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"], + "locations": [ + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + ], "args": [ { "name": "if", @@ -50,7 +60,7 @@ } -def build_client_schema(introspection: Dict) -> GraphQLSchema: +def build_client_schema(introspection: IntrospectionQuery) -> GraphQLSchema: """This is an alternative to the graphql-core function :code:`build_client_schema` but with default include and skip directives added to the schema to fix @@ -77,8 +87,7 @@ def build_client_schema(introspection: Dict) -> GraphQLSchema: directives = schema_introspection.get("directives", None) if directives is None: - directives = [] - schema_introspection["directives"] = directives + schema_introspection["directives"] = directives = [] if not any(directive["name"] == "skip" for directive in directives): directives.append(SKIP_DIRECTIVE_JSON) diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index bbb07771..d053c1c0 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -59,7 +59,7 @@ def get_introspection_query_ast( if descriptions: fragment_FullType.select(ds.__Type.description) if specified_by_url: - fragment_FullType.select(ds.__Type.specifiedByUrl) + 
fragment_FullType.select(ds.__Type.specifiedByURL) fields = ds.__Type.fields(includeDeprecated=True).select(ds.__Field.name) diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index ecb73474..5f9dd2a4 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -102,6 +102,8 @@ def __init__( self.result_stack: List[Any] = [] + super().__init__() + @property def current_result(self): try: diff --git a/setup.py b/setup.py index 7e97f8bc..4b1de3e4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.1.5,<3.2", + "graphql-core>=3.2,<3.3", "yarl>=1.6,<2.0", ] diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 2e30b6b7..23dc281d 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -1,11 +1,12 @@ import asyncio +from math import isfinite from typing import Any, Dict, NamedTuple, Optional import pytest from graphql import graphql_sync from graphql.error import GraphQLError from graphql.language import ValueNode -from graphql.pyutils import inspect, is_finite +from graphql.pyutils import inspect from graphql.type import ( GraphQLArgument, GraphQLField, @@ -34,6 +35,13 @@ class Money(NamedTuple): currency: str +def is_finite(value: Any) -> bool: + """Return true if a value is a finite number.""" + return (isinstance(value, int) and not isinstance(value, bool)) or ( + isinstance(value, float) and isfinite(value) + ) + + def serialize_money(output_value: Any) -> Dict[str, Any]: if not isinstance(output_value, Money): raise GraphQLError("Cannot serialize money value: " + inspect(output_value)) diff --git a/tests/fixtures/vcr_cassettes/queries.yaml b/tests/fixtures/vcr_cassettes/queries.yaml index f3fa1c96..f2ff24ef 100644 --- a/tests/fixtures/vcr_cassettes/queries.yaml +++ b/tests/fixtures/vcr_cassettes/queries.yaml @@ -12,7 +12,7 @@ interactions: {\n ...TypeRef\n }\n 
defaultValue\n}\n\nfragment TypeRef on __Type {\n kind\n name\n ofType {\n kind\n name\n ofType {\n kind\n name\n ofType {\n kind\n name\n ofType {\n kind\n name\n ofType {\n kind\n name\n ofType - {\n kind\n name\n ofType {\n kind\n name\n }\n }\n }\n }\n }\n }\n }\n}\n"}' + {\n kind\n name\n ofType {\n kind\n name\n }\n }\n }\n }\n }\n }\n }\n}"}' headers: Accept: - '*/*' @@ -202,7 +202,7 @@ interactions: message: OK - request: body: '{"query": "{\n myFavoriteFilm: film(id: \"RmlsbToz\") {\n id\n title\n episodeId\n characters(first: - 5) {\n edges {\n node {\n name\n }\n }\n }\n }\n}\n"}' + 5) {\n edges {\n node {\n name\n }\n }\n }\n }\n}"}' headers: Accept: - '*/*' @@ -248,7 +248,7 @@ interactions: code: 200 message: OK - request: - body: '{"query": "query Planet($id: ID!) {\n planet(id: $id) {\n id\n name\n }\n}\n", + body: '{"query": "query Planet($id: ID!) {\n planet(id: $id) {\n id\n name\n }\n}", "variables": {"id": "UGxhbmV0OjEw"}}' headers: Accept: @@ -294,7 +294,7 @@ interactions: message: OK - request: body: '{"query": "query Planet1 {\n planet(id: \"UGxhbmV0OjEw\") {\n id\n name\n }\n}\n\nquery - Planet2 {\n planet(id: \"UGxhbmV0OjEx\") {\n id\n name\n }\n}\n", "operationName": + Planet2 {\n planet(id: \"UGxhbmV0OjEx\") {\n id\n name\n }\n}", "operationName": "Planet2"}' headers: Accept: @@ -339,7 +339,7 @@ interactions: code: 200 message: OK - request: - body: '{"query": "query Planet($id: ID!) {\n planet(id: $id) {\n id\n name\n }\n}\n"}' + body: '{"query": "query Planet($id: ID!) 
{\n planet(id: $id) {\n id\n name\n }\n}"}' headers: Accept: - '*/*' diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 7bc31037..36232147 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -1,7 +1,37 @@ import asyncio -from collections import namedtuple +from typing import Collection + + +class Character: + id: str + name: str + friends: Collection[str] + appearsIn: Collection[str] + + +# noinspection PyPep8Naming +class Human(Character): + type = "Human" + homePlanet: str + + # noinspection PyShadowingBuiltins + def __init__(self, id, name, friends, appearsIn, homePlanet): + self.id, self.name = id, name + self.friends, self.appearsIn = friends, appearsIn + self.homePlanet = homePlanet + + +# noinspection PyPep8Naming +class Droid(Character): + type = "Droid" + primaryFunction: str + + # noinspection PyShadowingBuiltins + def __init__(self, id, name, friends, appearsIn, primaryFunction): + self.id, self.name = id, name + self.friends, self.appearsIn = friends, appearsIn + self.primaryFunction = primaryFunction -Human = namedtuple("Human", "id name friends appearsIn homePlanet") luke = Human( id="1000", @@ -47,8 +77,6 @@ "1004": tarkin, } -Droid = namedtuple("Droid", "id name friends appearsIn primaryFunction") - threepio = Droid( id="2000", name="C-3PO", @@ -77,38 +105,38 @@ } -def getCharacter(id): +def get_character(id): return humanData.get(id) or droidData.get(id) -def getCharacters(ids): - return map(getCharacter, ids) +def get_characters(ids): + return map(get_character, ids) -def getFriends(character): - return map(getCharacter, character.friends) +def get_friends(character): + return map(get_character, character.friends) -def getHero(episode): +def get_hero(episode): if episode == 5: return luke return artoo -async def getHeroAsync(episode): +async def get_hero_async(episode): await asyncio.sleep(0.001) - return getHero(episode) + return get_hero(episode) -def getHuman(id): +def get_human(id): return 
humanData.get(id) -def getDroid(id): +def get_droid(id): return droidData.get(id) -def createReview(episode, review): +def create_review(episode, review): reviews[episode].append(review) review["episode"] = episode return review diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 95320ffe..50e2420f 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -20,99 +20,103 @@ ) from .fixtures import ( - createReview, - getCharacters, - getDroid, - getFriends, - getHeroAsync, - getHuman, + create_review, + get_characters, + get_droid, + get_friends, + get_hero_async, + get_human, reviews, ) -episodeEnum = GraphQLEnumType( +episode_enum = GraphQLEnumType( "Episode", - description="One of the films in the Star Wars Trilogy", - values={ + { "NEWHOPE": GraphQLEnumValue(4, description="Released in 1977.",), "EMPIRE": GraphQLEnumValue(5, description="Released in 1980.",), "JEDI": GraphQLEnumValue(6, description="Released in 1983.",), }, + description="One of the films in the Star Wars Trilogy", ) -characterInterface = GraphQLInterfaceType( + +human_type: GraphQLObjectType +droid_type: GraphQLObjectType + +character_interface = GraphQLInterfaceType( "Character", - description="A character in the Star Wars Trilogy", - fields=lambda: { + lambda: { "id": GraphQLField( GraphQLNonNull(GraphQLString), description="The id of the character." ), "name": GraphQLField(GraphQLString, description="The name of the character."), "friends": GraphQLField( - GraphQLList(characterInterface), # type: ignore + GraphQLList(character_interface), # type: ignore description="The friends of the character," " or an empty list if they have none.", ), "appearsIn": GraphQLField( - GraphQLList(episodeEnum), description="Which movies they appear in." + GraphQLList(episode_enum), description="Which movies they appear in." 
), }, - resolve_type=lambda character, *_: humanType # type: ignore - if getHuman(character.id) - else droidType, # type: ignore + resolve_type=lambda character, _info, _type: { + "Human": human_type.name, + "Droid": droid_type.name, + }[character.type], + description="A character in the Star Wars Trilogy", ) -humanType = GraphQLObjectType( +human_type = GraphQLObjectType( "Human", - description="A humanoid creature in the Star Wars universe.", - fields=lambda: { + lambda: { "id": GraphQLField( GraphQLNonNull(GraphQLString), description="The id of the human.", ), "name": GraphQLField(GraphQLString, description="The name of the human.",), "friends": GraphQLField( - GraphQLList(characterInterface), + GraphQLList(character_interface), description="The friends of the human, or an empty list if they have none.", - resolve=lambda human, info, **args: getFriends(human), + resolve=lambda human, _info: get_friends(human), ), "appearsIn": GraphQLField( - GraphQLList(episodeEnum), description="Which movies they appear in.", + GraphQLList(episode_enum), description="Which movies they appear in.", ), "homePlanet": GraphQLField( GraphQLString, description="The home planet of the human, or null if unknown.", ), }, - interfaces=[characterInterface], + interfaces=[character_interface], + description="A humanoid creature in the Star Wars universe.", ) -droidType = GraphQLObjectType( +droid_type = GraphQLObjectType( "Droid", - description="A mechanical creature in the Star Wars universe.", - fields=lambda: { + lambda: { "id": GraphQLField( GraphQLNonNull(GraphQLString), description="The id of the droid.", ), "name": GraphQLField(GraphQLString, description="The name of the droid.",), "friends": GraphQLField( - GraphQLList(characterInterface), + GraphQLList(character_interface), description="The friends of the droid, or an empty list if they have none.", - resolve=lambda droid, info, **args: getFriends(droid), + resolve=lambda droid, _info: get_friends(droid), ), "appearsIn": 
GraphQLField( - GraphQLList(episodeEnum), description="Which movies they appear in.", + GraphQLList(episode_enum), description="Which movies they appear in.", ), "primaryFunction": GraphQLField( GraphQLString, description="The primary function of the droid.", ), }, - interfaces=[characterInterface], + interfaces=[character_interface], + description="A mechanical creature in the Star Wars universe.", ) -reviewType = GraphQLObjectType( +review_type = GraphQLObjectType( "Review", - description="Represents a review for a movie", - fields=lambda: { - "episode": GraphQLField(episodeEnum, description="The movie"), + lambda: { + "episode": GraphQLField(episode_enum, description="The movie"), "stars": GraphQLField( GraphQLNonNull(GraphQLInt), description="The number of stars this review gave, 1-5", @@ -121,84 +125,83 @@ GraphQLString, description="Comment about the movie" ), }, + description="Represents a review for a movie", ) -reviewInputType = GraphQLInputObjectType( +review_input_type = GraphQLInputObjectType( "ReviewInput", - description="The input object sent when someone is creating a new review", - fields={ + lambda: { "stars": GraphQLInputField(GraphQLInt, description="0-5 stars"), "commentary": GraphQLInputField( GraphQLString, description="Comment about the movie, optional" ), }, + description="The input object sent when someone is creating a new review", ) -queryType = GraphQLObjectType( +query_type = GraphQLObjectType( "Query", - fields=lambda: { + lambda: { "hero": GraphQLField( - characterInterface, + character_interface, args={ "episode": GraphQLArgument( + episode_enum, description="If omitted, returns the hero of the whole saga. 
If " "provided, returns the hero of that particular episode.", - type_=episodeEnum, # type: ignore ) }, - resolve=lambda root, info, **args: getHeroAsync(args.get("episode")), + resolve=lambda _souce, _info, episode=None: get_hero_async(episode), ), "human": GraphQLField( - humanType, + human_type, args={ "id": GraphQLArgument( description="id of the human", type_=GraphQLNonNull(GraphQLString), ) }, - resolve=lambda root, info, **args: getHuman(args["id"]), + resolve=lambda _souce, _info, id: get_human(id), ), "droid": GraphQLField( - droidType, + droid_type, args={ "id": GraphQLArgument( description="id of the droid", type_=GraphQLNonNull(GraphQLString), ) }, - resolve=lambda root, info, **args: getDroid(args["id"]), + resolve=lambda _source, _info, id: get_droid(id), ), "characters": GraphQLField( - GraphQLList(characterInterface), + GraphQLList(character_interface), args={ "ids": GraphQLArgument( - description="list of character ids", - type_=GraphQLList(GraphQLString), + GraphQLList(GraphQLString), description="list of character ids", ) }, - resolve=lambda root, info, **args: getCharacters(args["ids"]), + resolve=lambda _source, _info, ids=None: get_characters(ids), ), }, ) -mutationType = GraphQLObjectType( +mutation_type = GraphQLObjectType( "Mutation", - description="The mutation type, represents all updates we can make to our data", - fields=lambda: { + lambda: { "createReview": GraphQLField( - reviewType, + review_type, args={ "episode": GraphQLArgument( - description="Episode to create review", - type_=episodeEnum, # type: ignore + episode_enum, description="Episode to create review", ), "review": GraphQLArgument( - description="set alive status", type_=reviewInputType, + description="set alive status", type_=review_input_type, ), }, - resolve=lambda root, info, **args: createReview( - args.get("episode"), args.get("review") + resolve=lambda _source, _info, episode=None, review=None: create_review( + episode, review ), ), }, + description="The mutation 
type, represents all updates we can make to our data", ) @@ -212,14 +215,14 @@ async def resolve_review(review, _info, **_args): return review -subscriptionType = GraphQLObjectType( +subscription_type = GraphQLObjectType( "Subscription", - fields=lambda: { + lambda: { "reviewAdded": GraphQLField( - reviewType, + review_type, args={ "episode": GraphQLArgument( - description="Episode to review", type_=episodeEnum, + episode_enum, description="Episode to review", ) }, subscribe=subscribe_reviews, @@ -230,10 +233,10 @@ async def resolve_review(review, _info, **_args): StarWarsSchema = GraphQLSchema( - query=queryType, - mutation=mutationType, - subscription=subscriptionType, - types=[humanType, droidType, reviewType, reviewInputType], + query=query_type, + mutation=mutation_type, + subscription=subscription_type, + types=[human_type, droid_type, review_type, review_input_type], ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 0335d721..6adc84a9 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -129,8 +129,7 @@ def test_use_variable_definition_multiple_times(ds): stars commentary } -} -""" +}""" ) @@ -151,8 +150,7 @@ def test_add_variable_definitions(ds): stars commentary } -} -""" +}""" ) @@ -177,8 +175,7 @@ def test_add_variable_definitions_in_input_object(ds): stars commentary } -} -""" +}""" ) @@ -376,8 +373,7 @@ def test_subscription(ds): stars commentary } -} -""" +}""" ) @@ -445,8 +441,7 @@ def test_operation_name(ds): hero { name } -} -""" +}""" ) @@ -476,8 +471,7 @@ def test_multiple_operations(ds): stars commentary } -} -""" +}""" ) @@ -535,8 +529,7 @@ def test_fragments(ds): hero { ...NameAndAppearances } -} -""" +}""" name_and_appearances = ( DSLFragment("NameAndAppearances") @@ -624,8 +617,7 @@ def test_dsl_nested_query_with_fragment(ds): } } } -} -""" +}""" name_and_appearances = ( DSLFragment("NameAndAppearances") diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 
520018c1..430aa18e 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -1,5 +1,5 @@ import pytest -from graphql import GraphQLError, format_error +from graphql import GraphQLError from gql import Client, gql from tests.starwars.schema import StarWarsSchema @@ -302,9 +302,7 @@ def test_parse_error(client): ) error = exc_info.value assert isinstance(error, GraphQLError) - formatted_error = format_error(error) - assert formatted_error["locations"] == [{"column": 13, "line": 2}] - assert formatted_error["message"] == "Syntax Error: Unexpected Name 'qeury'." + assert "Syntax Error: Unexpected Name 'qeury'." in str(error) def test_mutation_result(client): diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 682cea0d..f66dc1a9 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -496,7 +496,7 @@ def test_code(): file_upload_mutation_1_operations = ( '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' - '$other_var, file: $file}) {\\n success\\n }\\n}\\n", "variables": ' + '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -763,7 +763,7 @@ async def file_sender(file_name): file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}\\n", ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -859,7 +859,7 @@ async def handler(request): file_upload_mutation_3_operations = ( '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles(input: {files: $files})' - ' {\\n success\\n }\\n}\\n", "variables": {"files": [null, null]}}' + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) file_upload_mutation_3_map = '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' diff --git a/tests/test_requests.py b/tests/test_requests.py index c3123d72..1ed4ca56 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -347,7 +347,7 @@ def test_code(): file_upload_mutation_1_operations = ( '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' - '$other_var, file: $file}) {\\n success\\n }\\n}\\n", "variables": ' + '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -551,7 +551,7 @@ def test_code(): file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}\\n", ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -651,7 +651,7 @@ def test_code(): file_upload_mutation_3_operations = ( '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles(input: {files: $files})' - ' {\\n success\\n }\\n}\\n", "variables": {"files": [null, null]}}' + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) From 4ed23ac4b9478b53e86fb4e09ee0e24a55a91f97 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 16 Jan 2022 14:48:34 +0100 Subject: [PATCH 059/239] Bump version number to 3.0.0rc1 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index ce971bbf..451cd1fa 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.0.0rc0" +__version__ = "3.0.0rc1" From e20534ce4e2324abd22a992b1e6d6a15125f1c39 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 22 Jan 2022 17:03:04 +0100 Subject: [PATCH 060/239] Update docs in preparation of the stable 3.0.0 release --- README.md | 9 +++------ docs/index.rst | 7 +------ docs/intro.rst | 9 ++------- docs/usage/file_upload.rst | 2 +- 4 files changed, 7 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 08484c2f..ea5e3074 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,6 @@ Plays nicely with `graphene`, `graphql-core`, `graphql-js` and any other GraphQL GQL architecture is inspired by `React-Relay` and `Apollo-Client`. 
-> **WARNING**: Please note that the following documentation describes the current version which is currently only available as a pre-release -> The documentation for the 2.x version compatible with python<3.6 is available in the [2.x branch](https://github.com/graphql-python/gql/tree/v2.x) - [![GitHub-Actions][gh-image]][gh-url] [![pyversion][pyversion-image]][pyversion-url] [![pypi][pypi-image]][pypi-url] @@ -48,11 +45,11 @@ The complete documentation for GQL can be found at ## Installation -> **WARNING**: Please note that the following documentation describes the current version which is currently only available as a pre-release and needs to be installed with +You can install GQL with all the optional dependencies using pip: - $ pip install --pre gql[all] + pip install gql[all] -> **NOTE**: See also [the documentation](https://gql.readthedocs.io/en/latest/intro.html#less-dependencies) to install GQL with less extra dependencies +> **NOTE**: See also [the documentation](https://gql.readthedocs.io/en/latest/intro.html#less-dependencies) to install GQL with less extra dependencies depending on the transports you would like to use ## Usage diff --git a/docs/index.rst b/docs/index.rst index ff63ed3a..ecb2f0e1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,10 +1,5 @@ Welcome to GQL 3 documentation! -================================= - -.. warning:: - - Please note that the following documentation describes the current version which is currently only available - as a pre-release and needs to be installed with "`--pre`" +=============================== Contents -------- diff --git a/docs/intro.rst b/docs/intro.rst index 1cd3f5c8..a6e8ee21 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -12,12 +12,7 @@ Installation You can install GQL 3 and all the extra dependencies using pip_:: - pip install --pre gql[all] - -.. 
warning:: - - Please note that the following documentation describes the current version which is currently only available - as a pre-release and needs to be installed with "`--pre`" + pip install gql[all] After installation, you can start using GQL by importing from the top-level :mod:`gql` package. @@ -33,7 +28,7 @@ instead of using the "`all`" extra dependency as described above, which installs If for example you only need the :ref:`AIOHTTPTransport `, which needs the :code:`aiohttp` dependency, then you can install GQL with:: - pip install --pre gql[aiohttp] + pip install gql[aiohttp] The corresponding between extra dependencies required and the GQL classes is: diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index cfc85df9..8062f317 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -60,7 +60,7 @@ It is also possible to upload multiple files using a list. ''') f1 = open("YOUR_FILE_PATH_1", "rb") - f2 = open("YOUR_FILE_PATH_1", "rb") + f2 = open("YOUR_FILE_PATH_2", "rb") params = {"files": [f1, f2]} From 514769b4d0f120cbca4e6f7ff77b3aba72892c98 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 22 Jan 2022 17:11:10 +0100 Subject: [PATCH 061/239] setup.py upgrade development status to production/stable --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4b1de3e4..07bab00e 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ author_email="me@syrusakbary.com", license="MIT", classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3", From 12fc895bd45c297ce7664358fcdebb216fe11cac Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 22 Jan 2022 17:12:49 +0100 Subject: [PATCH 062/239] Bump version number to 3.0.0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/gql/__version__.py b/gql/__version__.py index 451cd1fa..528787cf 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.0.0rc1" +__version__ = "3.0.0" From 0084b95d546e13569305c32f0fdf0ba9ebc169b6 Mon Sep 17 00:00:00 2001 From: joricht <53896809+joricht@users.noreply.github.com> Date: Tue, 22 Feb 2022 08:13:26 +0100 Subject: [PATCH 063/239] Close transport when fetching the schema failed (#297) --- gql/client.py | 22 +++++++++++++++++---- tests/test_client.py | 47 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/gql/client.py b/gql/client.py index 91cbcde6..5203d17d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -271,8 +271,15 @@ async def __aenter__(self): self.session = AsyncClientSession(client=self) # Get schema from transport if needed - if self.fetch_schema_from_transport and not self.schema: - await self.session.fetch_schema() + try: + if self.fetch_schema_from_transport and not self.schema: + await self.session.fetch_schema() + except Exception: + # we don't know what type of exception is thrown here because it + # depends on the underlying transport; we just make sure that the + # transport is closed and re-raise the exception + await self.transport.close() + raise return self.session @@ -293,8 +300,15 @@ def __enter__(self): self.session = SyncClientSession(client=self) # Get schema from transport if needed - if self.fetch_schema_from_transport and not self.schema: - self.session.fetch_schema() + try: + if self.fetch_schema_from_transport and not self.schema: + self.session.fetch_schema() + except Exception: + # we don't know what type of exception is thrown here because it + # depends on the underlying transport; we just make sure that the + # transport is closed and re-raise the exception + self.transport.close() + raise return self.session diff --git a/tests/test_client.py b/tests/test_client.py index c8df40ee..fecdf43d 100644 --- a/tests/test_client.py +++ 
b/tests/test_client.py @@ -200,3 +200,50 @@ def test_gql(): client = Client(schema=schema) result = client.execute(query) assert result["user"] is None + + +@pytest.mark.requests +def test_sync_transport_close_on_schema_retrieval_failure(): + """ + Ensure that the transport session is closed if an error occurs when + entering the context manager (e.g., because schema retrieval fails) + """ + + from gql.transport.requests import RequestsHTTPTransport + + transport = RequestsHTTPTransport(url="http://localhost/") + client = Client(transport=transport, fetch_schema_from_transport=True) + + try: + with client: + pass + except Exception: + # we don't care what exception is thrown, we just want to check if the + # transport is closed afterwards + pass + + assert client.transport.session is None + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_async_transport_close_on_schema_retrieval_failure(): + """ + Ensure that the transport session is closed if an error occurs when + entering the context manager (e.g., because schema retrieval fails) + """ + + from gql.transport.aiohttp import AIOHTTPTransport + + transport = AIOHTTPTransport(url="http://localhost/") + client = Client(transport=transport, fetch_schema_from_transport=True) + + try: + async with client: + pass + except Exception: + # we don't care what exception is thrown, we just want to check if the + # transport is closed afterwards + pass + + assert client.transport.session is None From d3be91616721cfabe84755b476ec1c7f46d27ca4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 22 Feb 2022 08:20:35 +0100 Subject: [PATCH 064/239] Fix errors raising TransportProtocolError with the graphql-ws protocol (#299) --- gql/transport/websockets.py | 12 +++++--- tests/test_graphqlws_exceptions.py | 47 ++++++++++-------------------- 2 files changed, 23 insertions(+), 36 deletions(-) diff --git a/gql/transport/websockets.py 
b/gql/transport/websockets.py index 41478daf..04983ef8 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -272,6 +272,7 @@ def _parse_answer_graphqlws( - instead of a unidirectional keep-alive (ka) message from server to client, there is now the possibility to send bidirectional ping/pong messages - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error """ answer_type: str = "" @@ -288,11 +289,11 @@ def _parse_answer_graphqlws( payload = json_answer.get("payload") - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - if answer_type == "next": + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + if "errors" not in payload and "data" not in payload: raise ValueError( "payload does not contain 'data' or 'errors' fields" @@ -309,8 +310,11 @@ def _parse_answer_graphqlws( elif answer_type == "error": + if not isinstance(payload, list): + raise ValueError("payload is not a list") + raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] + str(payload[0]), query_id=answer_id, errors=payload ) elif answer_type in ["ping", "pong", "connection_ack"]: diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index 8a2e7495..ca689c47 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -1,6 +1,4 @@ import asyncio -import json -import types from typing import List import pytest @@ -125,49 +123,29 @@ async def test_graphqlws_server_does_not_send_ack( pass -invalid_payload_server_answer = ( - '{"type":"error","id":"1","payload":{"message":"Must provide document"}}' +invalid_query_server_answer = ( + '{"id":"1","type":"error","payload":[{"message":"Cannot query field ' + '\\"helo\\" on type \\"Query\\". 
Did you mean \\"hello\\"?",' + '"locations":[{"line":2,"column":3}]}]}' ) -async def server_invalid_payload(ws, path): +async def server_invalid_query(ws, path): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") - await ws.send(invalid_payload_server_answer) + await ws.send(invalid_query_server_answer) await WebSocketServerHelper.wait_connection_terminate(ws) await ws.wait_closed() @pytest.mark.asyncio -@pytest.mark.parametrize("graphqlws_server", [server_invalid_payload], indirect=True) -@pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_graphqlws_sending_invalid_payload( - event_loop, client_and_graphqlws_server, query_str -): +@pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) +async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_server): session, server = client_and_graphqlws_server - # Monkey patching the _send_query method to send an invalid payload - - async def monkey_patch_send_query( - self, document, variable_values=None, operation_name=None, - ) -> int: - query_id = self.next_query_id - self.next_query_id += 1 - - query_str = json.dumps( - {"id": str(query_id), "type": "subscribe", "payload": "BLAHBLAH"} - ) - - await self._send(query_str) - return query_id - - session.transport._send_query = types.MethodType( - monkey_patch_send_query, session.transport - ) - - query = gql(query_str) + query = gql("{helo}") with pytest.raises(TransportQueryError) as exc_info: await session.execute(query) @@ -178,7 +156,10 @@ async def monkey_patch_send_query( error = exception.errors[0] - assert error["message"] == "Must provide document" + assert ( + error["message"] + == 'Cannot query field "helo" on type "Query". Did you mean "hello"?' 
+ ) not_json_answer = ["BLAHBLAH"] @@ -188,6 +169,7 @@ async def monkey_patch_send_query( missing_id_answer_3 = ['{"type": "complete"}'] data_without_payload = ['{"type": "next", "id":"1"}'] error_without_payload = ['{"type": "error", "id":"1"}'] +error_with_payload_not_a_list = ['{"type": "error", "id":"1", "payload": "NOT A LIST"}'] payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}'] empty_payload = ['{"type": "next", "id":"1", "payload": {}}'] sending_bytes = [b"\x01\x02\x03"] @@ -205,6 +187,7 @@ async def monkey_patch_send_query( data_without_payload, error_without_payload, payload_is_not_a_dict, + error_with_payload_not_a_list, empty_payload, sending_bytes, ], From 5127e8cdaea86ebdae21b5d3e2f60776764af432 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 22 Feb 2022 08:39:24 +0100 Subject: [PATCH 065/239] Allow to specify subprotocols in the websockets transport (#300) --- docs/transports/websockets.rst | 10 +++++++++- gql/transport/websockets.py | 17 ++++++++++++----- tests/conftest.py | 4 +++- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/docs/transports/websockets.rst b/docs/transports/websockets.rst index 689cc136..23e4735a 100644 --- a/docs/transports/websockets.rst +++ b/docs/transports/websockets.rst @@ -8,7 +8,14 @@ The websockets transport supports both: - the `Apollo websockets transport protocol`_. - the `GraphQL-ws websockets transport protocol`_ -It will detect the backend supported protocol from the response http headers returned. +It will propose both subprotocols to the backend and detect the supported protocol +from the response http headers returned by the backend. + +.. note:: + For some backends (graphql-ws before `version 5.6.1`_ without backwards compatibility), it may be necessary to specify + only one subprotocol to the backend. 
It can be done by using + :code:`subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL]` + or :code:`subprotocols=[WebsocketsTransport.APOLLO_SUBPROTOCOL]` in the transport arguments. This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection. @@ -118,5 +125,6 @@ Here is an example with a ping sent every 60 seconds, expecting a pong within 10 pong_timeout=10, ) +.. _version 5.6.1: https://github.com/enisdenjo/graphql-ws/releases/tag/v5.6.1 .. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md .. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 04983ef8..1650624e 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -3,7 +3,7 @@ import logging from contextlib import suppress from ssl import SSLContext -from typing import Any, Dict, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from graphql import DocumentNode, ExecutionResult, print_ast from websockets.datastructures import HeadersLike @@ -46,6 +46,7 @@ def __init__( pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, connect_args: Dict[str, Any] = {}, + subprotocols: Optional[List[Subprotocol]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -71,6 +72,9 @@ def __init__( (for the graphql-ws protocol). By default: True :param connect_args: Other parameters forwarded to websockets.connect + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. 
""" super().__init__( @@ -105,10 +109,13 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - self.supported_subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] + if subprotocols is None: + self.supported_subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + else: + self.supported_subprotocols = subprotocols async def _wait_ack(self) -> None: """Wait for the connection_ack message. Keep alive messages are ignored""" diff --git a/tests/conftest.py b/tests/conftest.py index d433c1ca..fbb881ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -465,7 +465,9 @@ async def client_and_graphqlws_server(graphqlws_server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" - sample_transport = WebsocketsTransport(url=url) + sample_transport = WebsocketsTransport( + url=url, subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + ) async with Client(transport=sample_transport) as session: From 23c6f8561f8e89695f4795724a0f4d08dab31060 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 10 Mar 2022 12:30:05 +0100 Subject: [PATCH 066/239] Client: Add explicit overloads and remove *args arguments (#306) --- gql/client.py | 412 ++++++++++++++++++++++++++++++++++++++---- tests/test_aiohttp.py | 34 ++++ 2 files changed, 412 insertions(+), 34 deletions(-) diff --git a/gql/client.py b/gql/client.py index 5203d17d..c0972133 100644 --- a/gql/client.py +++ b/gql/client.py @@ -131,17 +131,186 @@ def validate(self, document: DocumentNode): if validation_errors: raise validation_errors[0] - def execute_sync(self, document: DocumentNode, *args, **kwargs) -> Dict: + @overload + def execute_sync( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = 
..., + parse_result: Optional[bool] = ..., + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 + get_execution_result: Literal[False] = ..., + **kwargs, + ) -> Dict[str, Any]: + ... # pragma: no cover + + @overload + def execute_sync( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: Literal[True], + **kwargs, + ) -> ExecutionResult: + ... # pragma: no cover + + @overload + def execute_sync( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover + + def execute_sync( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" with self as session: - return session.execute(document, *args, **kwargs) + return session.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) - async def execute_async(self, document: DocumentNode, *args, **kwargs) -> Dict: + @overload + async def execute_async( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, # 
https://github.com/python/mypy/issues/7333#issuecomment-788255229 + get_execution_result: Literal[False] = ..., + **kwargs, + ) -> Dict[str, Any]: + ... # pragma: no cover + + @overload + async def execute_async( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: Literal[True], + **kwargs, + ) -> ExecutionResult: + ... # pragma: no cover + + @overload + async def execute_async( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover + + async def execute_async( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" async with self as session: - return await session.execute(document, *args, **kwargs) + return await session.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + + @overload + def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 + get_execution_result: 
Literal[False] = ..., + **kwargs, + ) -> Dict[str, Any]: + ... # pragma: no cover + + @overload + def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: Literal[True], + **kwargs, + ) -> ExecutionResult: + ... # pragma: no cover - def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + @overload + def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: + ... # pragma: no cover + + def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: """Execute the provided document AST against the remote server using the transport provided during init. @@ -183,31 +352,160 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: " Use 'await client.execute_async(query)' instead." 
) - data: Dict[Any, Any] = loop.run_until_complete( - self.execute_async(document, *args, **kwargs) + data = loop.run_until_complete( + self.execute_async( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) ) return data else: # Sync transports - return self.execute_sync(document, *args, **kwargs) + return self.execute_sync( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + + @overload + def subscribe_async( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: Literal[False] = ..., + **kwargs, + ) -> AsyncGenerator[Dict[str, Any], None]: + ... # pragma: no cover + + @overload + def subscribe_async( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: Literal[True], + **kwargs, + ) -> AsyncGenerator[ExecutionResult, None]: + ... # pragma: no cover + + @overload + def subscribe_async( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[ + AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] + ]: + ... 
# pragma: no cover async def subscribe_async( - self, document: DocumentNode, *args, **kwargs - ) -> AsyncGenerator[Dict, None]: + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs, + ) -> Union[ + AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] + ]: """:meta private:""" async with self as session: - generator: AsyncGenerator[Dict, None] = session.subscribe( - document, *args, **kwargs + generator = session.subscribe( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, ) async for result in generator: yield result + @overload + def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: Literal[False] = ..., + **kwargs, + ) -> Generator[Dict[str, Any], None, None]: + ... # pragma: no cover + + @overload def subscribe( - self, document: DocumentNode, *args, **kwargs - ) -> Generator[Dict, None, None]: + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: Literal[True], + **kwargs, + ) -> Generator[ExecutionResult, None, None]: + ... 
# pragma: no cover + + @overload + def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[ + Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] + ]: + ... # pragma: no cover + + def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + *, + get_execution_result: bool = False, + **kwargs, + ) -> Union[ + Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] + ]: """Execute a GraphQL subscription with a python generator. We need an async transport for this functionality. @@ -225,7 +523,17 @@ def subscribe( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - async_generator = self.subscribe_async(document, *args, **kwargs) + async_generator: Union[ + AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] + ] = self.subscribe_async( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) assert not loop.is_running(), ( "Cannot run client.subscribe(query) if an asyncio loop is running." 
@@ -240,7 +548,11 @@ def subscribe( generator_task = asyncio.ensure_future( async_generator.__anext__(), loop=loop ) - result = loop.run_until_complete(generator_task) + result: Union[ + Dict[str, Any], ExecutionResult + ] = loop.run_until_complete( + generator_task + ) # type: ignore yield result except StopAsyncIteration: @@ -330,7 +642,6 @@ def __init__(self, client: Client): def _execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -369,7 +680,6 @@ def _execute( result = self.transport.execute( document, - *args, variable_values=variable_values, operation_name=operation_name, **kwargs, @@ -391,11 +701,11 @@ def _execute( def execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., + *, get_execution_result: Literal[False] = ..., **kwargs, ) -> Dict[str, Any]: @@ -405,20 +715,33 @@ def execute( def execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., + *, get_execution_result: Literal[True], **kwargs, ) -> ExecutionResult: ... # pragma: no cover + @overload + def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: + ... 
# pragma: no cover + def execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -448,7 +771,6 @@ def execute( # Validate and execute on the transport result = self._execute( document, - *args, variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, @@ -503,7 +825,6 @@ def __init__(self, client: Client): async def _subscribe( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -549,7 +870,6 @@ async def _subscribe( ExecutionResult, None ] = self.transport.subscribe( document, - *args, variable_values=variable_values, operation_name=operation_name, **kwargs, @@ -582,11 +902,11 @@ async def _subscribe( def subscribe( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., + *, get_execution_result: Literal[False] = ..., **kwargs, ) -> AsyncGenerator[Dict[str, Any], None]: @@ -596,20 +916,35 @@ def subscribe( def subscribe( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., + *, get_execution_result: Literal[True], **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover + @overload + def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[ + AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] + ]: + ... 
# pragma: no cover + async def subscribe( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -640,7 +975,6 @@ async def subscribe( inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( document, - *args, variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, @@ -672,7 +1006,6 @@ async def subscribe( async def _execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -718,7 +1051,6 @@ async def _execute( document, variable_values=variable_values, operation_name=operation_name, - *args, **kwargs, ), self.client.execute_timeout, @@ -740,11 +1072,11 @@ async def _execute( async def execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., + *, get_execution_result: Literal[False] = ..., **kwargs, ) -> Dict[str, Any]: @@ -754,20 +1086,33 @@ async def execute( async def execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., + *, get_execution_result: Literal[True], **kwargs, ) -> ExecutionResult: ... # pragma: no cover + @overload + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = ..., + operation_name: Optional[str] = ..., + serialize_variables: Optional[bool] = ..., + parse_result: Optional[bool] = ..., + *, + get_execution_result: bool, + **kwargs, + ) -> Union[Dict[str, Any], ExecutionResult]: + ... 
# pragma: no cover + async def execute( self, document: DocumentNode, - *args, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -797,7 +1142,6 @@ async def execute( # Validate and execute on the transport result = await self._execute( document, - *args, variable_values=variable_values, operation_name=operation_name, serialize_variables=serialize_variables, diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index f66dc1a9..030c9134 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -419,6 +419,40 @@ async def handler(request): assert continent["name"] == "Europe" +@pytest.mark.asyncio +async def test_aiohttp_query_variable_values_fix_issue_292(event_loop, aiohttp_server): + """Allow to specify variable_values without keyword. + + See https://github.com/graphql-python/gql/issues/292""" + + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query2_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=sample_transport,) as session: + + params = {"code": "EU"} + + query = gql(query2_str) + + # Execute query asynchronously + result = await session.execute(query, params, operation_name="getEurope") + + continent = result["continent"] + + assert continent["name"] == "Europe" + + @pytest.mark.asyncio async def test_aiohttp_execute_running_in_thread( event_loop, aiohttp_server, run_sync_test From 074ca9e31bcef18eb1d84314955d78b86af9b5b3 Mon Sep 17 00:00:00 2001 From: sondale-git <61547150+sondale-git@users.noreply.github.com> Date: Thu, 10 Mar 2022 18:40:42 +0100 Subject: [PATCH 067/239] Saving http response headers 
reference in transports (#293) --- docs/usage/headers.rst | 2 + gql/transport/aiohttp.py | 5 ++ gql/transport/requests.py | 3 + gql/transport/websockets_base.py | 6 +- tests/conftest.py | 3 + tests/test_aiohttp.py | 108 +++++++++++++++++-------------- tests/test_requests.py | 69 +++++++++++--------- tests/test_websocket_query.py | 61 +++++++++-------- 8 files changed, 150 insertions(+), 107 deletions(-) diff --git a/docs/usage/headers.rst b/docs/usage/headers.rst index 23af64a7..b41c8b43 100644 --- a/docs/usage/headers.rst +++ b/docs/usage/headers.rst @@ -6,3 +6,5 @@ If you want to add additional http headers for your connection, you can specify .. code-block:: python transport = AIOHTTPTransport(url='YOUR_URL', headers={'Authorization': 'token'}) + +After the connection, the latest response headers can be found in :code:`transport.response_headers` diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 12c57068..7df60417 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -12,6 +12,7 @@ from aiohttp.helpers import BasicAuth from aiohttp.typedefs import LooseCookies, LooseHeaders from graphql import DocumentNode, ExecutionResult, print_ast +from multidict import CIMultiDictProxy from ..utils import extract_files from .appsync_auth import AppSyncAuthentication @@ -75,6 +76,7 @@ def __init__( self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args self.session: Optional[aiohttp.ClientSession] = None + self.response_headers: Optional[CIMultiDictProxy[str]] async def connect(self) -> None: """Coroutine which will create an aiohttp ClientSession() as self.session. 
@@ -311,6 +313,9 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): if "errors" not in result and "data" not in result: await raise_response_error(resp, 'No "data" or "errors" keys in answer') + # Saving latest response headers in the transport + self.response_headers = resp.headers + return ExecutionResult( errors=result.get("errors"), data=result.get("data"), diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 32e57478..a34e7542 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -81,6 +81,8 @@ def __init__( self.session = None + self.response_headers = None + def connect(self): if self.session is None: @@ -217,6 +219,7 @@ def execute( # type: ignore response = self.session.request( self.method, self.url, **post_args # type: ignore ) + self.response_headers = response.headers def raise_response_error(resp: requests.Response, reason: str): # We raise a TransportServerError if the status code is 400 or higher diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index 151e444e..7a83c47f 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -9,7 +9,7 @@ import websockets from graphql import DocumentNode, ExecutionResult from websockets.client import WebSocketClientProtocol -from websockets.datastructures import HeadersLike +from websockets.datastructures import Headers, HeadersLike from websockets.exceptions import ConnectionClosed from websockets.typing import Data, Subprotocol @@ -169,6 +169,8 @@ def __init__( # The list of supported subprotocols should be defined in the subclass self.supported_subprotocols: List[Subprotocol] = [] + self.response_headers: Optional[Headers] = None + async def _initialize(self): """Hook to send the initialization messages after the connection and potentially wait for the backend ack. 
@@ -495,6 +497,8 @@ async def connect(self) -> None: self.websocket = cast(WebSocketClientProtocol, self.websocket) + self.response_headers = self.websocket.response_headers + # Run the after_connect hook of the subclass await self._after_connect() diff --git a/tests/conftest.py b/tests/conftest.py index fbb881ec..8d0b95ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -174,6 +174,9 @@ async def start(self, handler, extra_serve_args=None): self.testcert, ssl_context = get_localhost_ssl_context() extra_serve_args["ssl"] = ssl_context + # Adding dummy response headers + extra_serve_args["extra_headers"] = {"dummy": "test1234"} + # Start a server with a random open port self.start_server = websockets.server.serve( handler, "127.0.0.1", 0, **extra_serve_args diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 030c9134..ab02e8f5 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,5 +1,6 @@ import io import json +from typing import Mapping import pytest @@ -45,7 +46,11 @@ async def test_aiohttp_query(event_loop, aiohttp_server): from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): - return web.Response(text=query1_server_answer, content_type="application/json") + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) app = web.Application() app.router.add_route("POST", "/", handler) @@ -53,9 +58,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -68,6 +73,11 @@ async def handler(request): assert africa["code"] == "AF" + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + 
assert transport.response_headers["dummy"] == "test1234" + @pytest.mark.asyncio async def test_aiohttp_ignore_backend_content_type(event_loop, aiohttp_server): @@ -83,9 +93,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -115,9 +125,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, cookies={"cookie1": "val1"}) + transport = AIOHTTPTransport(url=url, cookies={"cookie1": "val1"}) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -150,9 +160,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -177,9 +187,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -208,9 +218,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -259,9 +269,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) - async with Client(transport=sample_transport,) as session: + async with 
Client(transport=transport) as session: query = gql(query1_str) @@ -285,9 +295,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -310,9 +320,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: with pytest.raises(TransportAlreadyConnected): await session.transport.connect() @@ -332,12 +342,12 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) query = gql(query1_str) with pytest.raises(TransportClosed): - await sample_transport.execute(query) + await transport.execute(query) @pytest.mark.asyncio @@ -358,11 +368,11 @@ async def handler(request): from aiohttp import DummyCookieJar jar = DummyCookieJar() - sample_transport = AIOHTTPTransport( + transport = AIOHTTPTransport( url=url, timeout=10, client_session_args={"version": "1.1", "cookie_jar": jar} ) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -401,9 +411,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: params = {"code": "EU"} @@ -437,9 +447,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) - 
async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: params = {"code": "EU"} @@ -470,9 +480,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) query = gql(query1_str) @@ -498,9 +508,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) query = gql(query1_str) @@ -580,11 +590,11 @@ async def test_aiohttp_file_upload(event_loop, aiohttp_server): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) with TemporaryFile(file_1_content) as test_file: - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) @@ -618,11 +628,11 @@ async def test_aiohttp_file_upload_without_session( url = server.make_url("/") def test_code(): - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) with TemporaryFile(file_1_content) as test_file: - client = Client(transport=sample_transport,) + client = Client(transport=transport) query = gql(file_upload_mutation_1) @@ -685,11 +695,11 @@ async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) with TemporaryFile(binary_file_content) as test_file: - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) @@ -728,9 +738,9 @@ async def 
binary_data_handler(request): url = server.make_url("/") binary_data_url = server.make_url("/binary_data") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) async with ClientSession() as client: async with client.get(binary_data_url) as resp: @@ -758,11 +768,11 @@ async def test_aiohttp_async_generator_upload(event_loop, aiohttp_server): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) with TemporaryFile(binary_file_content) as test_file: - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) @@ -851,12 +861,12 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) with TemporaryFile(file_1_content) as test_file_1: with TemporaryFile(file_2_content) as test_file_2: - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(file_upload_mutation_2) @@ -941,12 +951,12 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) with TemporaryFile(file_1_content) as test_file_1: with TemporaryFile(file_2_content) as test_file_2: - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(file_upload_mutation_3) @@ -1098,9 +1108,9 @@ async def handler(request): url = server.make_url("/") - sample_transport = AIOHTTPTransport(url=url, timeout=10) + transport = AIOHTTPTransport(url=url, timeout=10) - async with 
Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) @@ -1126,11 +1136,11 @@ async def handler(request): assert str(url).startswith("https://") - sample_transport = AIOHTTPTransport( + transport = AIOHTTPTransport( url=url, timeout=10, ssl_close_timeout=ssl_close_timeout ) - async with Client(transport=sample_transport,) as session: + async with Client(transport=transport) as session: query = gql(query1_str) diff --git a/tests/test_requests.py b/tests/test_requests.py index 1ed4ca56..7cd7f712 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,3 +1,5 @@ +from typing import Mapping + import pytest from gql import Client, gql @@ -38,7 +40,11 @@ async def test_requests_query(event_loop, aiohttp_server, run_sync_test): from gql.transport.requests import RequestsHTTPTransport async def handler(request): - return web.Response(text=query1_server_answer, content_type="application/json") + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) app = web.Application() app.router.add_route("POST", "/", handler) @@ -47,9 +53,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(query1_str) @@ -62,6 +68,11 @@ def test_code(): assert africa["code"] == "AF" + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + await run_sync_test(event_loop, server, test_code) @@ -84,9 +95,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = 
RequestsHTTPTransport(url=url, cookies={"cookie1": "val1"}) + transport = RequestsHTTPTransport(url=url, cookies={"cookie1": "val1"}) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(query1_str) @@ -123,9 +134,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(query1_str) @@ -154,9 +165,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(query1_str) @@ -187,9 +198,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(query1_str) @@ -225,9 +236,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(query1_str) @@ -253,9 +264,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: with pytest.raises(TransportAlreadyConnected): session.transport.connect() @@ -281,12 +292,12 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport 
= RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) query = gql(query1_str) with pytest.raises(TransportClosed): - sample_transport.execute(query) + transport.execute(query) await run_sync_test(event_loop, server, test_code) @@ -322,9 +333,9 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(query1_str) @@ -399,10 +410,10 @@ async def single_upload_handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) with TemporaryFile(file_1_content) as test_file: - with Client(transport=sample_transport) as session: + with Client(transport=transport) as session: query = gql(file_upload_mutation_1) file_path = test_file.filename @@ -463,10 +474,10 @@ async def single_upload_handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url, headers={"X-Auth": "foobar"}) + transport = RequestsHTTPTransport(url=url, headers={"X-Auth": "foobar"}) with TemporaryFile(file_1_content) as test_file: - with Client(transport=sample_transport) as session: + with Client(transport=transport) as session: query = gql(file_upload_mutation_1) file_path = test_file.filename @@ -526,11 +537,11 @@ async def binary_upload_handler(request): url = server.make_url("/") - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) def test_code(): with TemporaryFile(binary_file_content) as test_file: - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(file_upload_mutation_1) @@ -617,12 +628,12 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = 
RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) with TemporaryFile(file_1_content) as test_file_1: with TemporaryFile(file_2_content) as test_file_2: - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(file_upload_mutation_2) @@ -718,11 +729,11 @@ async def handler(request): url = server.make_url("/") def test_code(): - sample_transport = RequestsHTTPTransport(url=url) + transport = RequestsHTTPTransport(url=url) with TemporaryFile(file_1_content) as test_file_1: with TemporaryFile(file_2_content) as test_file_2: - with Client(transport=sample_transport,) as session: + with Client(transport=transport) as session: query = gql(file_upload_mutation_3) diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 4e51f161..2382f157 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -2,7 +2,7 @@ import json import ssl import sys -from typing import Dict +from typing import Dict, Mapping import pytest @@ -58,12 +58,12 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: assert isinstance( - sample_transport.websocket, websockets.client.WebSocketClientProtocol + transport.websocket, websockets.client.WebSocketClientProtocol ) query1 = gql(query1_str) @@ -80,8 +80,13 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert africa["code"] == "AF" + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + # Check client is disconnect 
here - assert sample_transport.websocket is None + assert transport.websocket is None @pytest.mark.asyncio @@ -98,12 +103,12 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.load_verify_locations(ws_ssl_server.testcert) - sample_transport = WebsocketsTransport(url=url, ssl=ssl_context) + transport = WebsocketsTransport(url=url, ssl=ssl_context) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: assert isinstance( - sample_transport.websocket, websockets.client.WebSocketClientProtocol + transport.websocket, websockets.client.WebSocketClientProtocol ) query1 = gql(query1_str) @@ -121,7 +126,7 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): assert africa["code"] == "AF" # Check client is disconnect here - assert sample_transport.websocket is None + assert transport.websocket is None @pytest.mark.asyncio @@ -301,19 +306,19 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert sample_transport.websocket is None + assert transport.websocket is None - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert sample_transport.websocket is None + assert transport.websocket is None @pytest.mark.asyncio @@ -325,8 +330,8 @@ async def test_websocket_multiple_connections_in_parallel(event_loop, server): print(f"url = {url}") async def task_coro(): - sample_transport = 
WebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + transport = WebsocketsTransport(url=url) + async with Client(transport=transport) as session: await assert_client_is_working(session) task1 = asyncio.ensure_future(task_coro()) @@ -345,12 +350,12 @@ async def test_websocket_trying_to_connect_to_already_connected_transport( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + transport = WebsocketsTransport(url=url) + async with Client(transport=transport) as session: await assert_client_is_working(session) with pytest.raises(TransportAlreadyConnected): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -395,9 +400,9 @@ async def test_websocket_connect_success_with_authentication_in_connection_init( init_payload = {"Authorization": 12345} - sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) + transport = WebsocketsTransport(url=url, init_payload=init_payload) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql(query_str) @@ -428,10 +433,10 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) + transport = WebsocketsTransport(url=url, init_payload=init_payload) with pytest.raises(TransportServerError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql(query_str) await session.execute(query1) @@ -444,9 +449,9 @@ def test_websocket_execute_sync(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = 
WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) query1 = gql(query1_str) @@ -476,7 +481,7 @@ def test_websocket_execute_sync(server): assert africa["code"] == "AF" # Check client is disconnect here - assert sample_transport.websocket is None + assert transport.websocket is None @pytest.mark.asyncio @@ -487,11 +492,11 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server): url = f"ws://{server.hostname}:{server.port}/graphql" # Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions - sample_transport = WebsocketsTransport(url=url, connect_args={"max_size": 2 ** 21}) + transport = WebsocketsTransport(url=url, connect_args={"max_size": 2 ** 21}) query = gql(query1_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) From fe213c42f07ae14f1311fd5cdd453413a35156df Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 11 Mar 2022 12:15:06 +0100 Subject: [PATCH 068/239] Bump version number to 3.1.0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 528787cf..f5f41e56 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.0.0" +__version__ = "3.1.0" From f0790aed0e407d0d3fe7ebd28c888432324bccf3 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 29 Mar 2022 17:43:19 +0200 Subject: [PATCH 069/239] Update black dev dependency to 22.3.0 (#313) --- docs/code_examples/aiohttp_async.py | 3 +- .../code_examples/appsync/mutation_api_key.py | 3 +- docs/code_examples/appsync/mutation_iam.py | 3 +- docs/code_examples/requests_sync.py | 4 +- docs/code_examples/requests_sync_dsl.py | 4 +- docs/code_examples/websockets_async.py | 3 +- gql/cli.py | 8 ++- gql/dsl.py | 23 ++++++-- gql/transport/aiohttp.py | 6 +- gql/transport/appsync_websockets.py | 7 ++- 
gql/transport/async_transport.py | 10 ++-- gql/transport/local_schema.py | 22 ++++--- gql/transport/requests.py | 3 +- gql/transport/transport.py | 3 +- gql/transport/websockets.py | 6 +- gql/transport/websockets_base.py | 3 +- gql/utilities/get_introspection_query_ast.py | 15 +++-- setup.py | 2 +- tests/conftest.py | 3 +- tests/custom_scalars/test_enum_colors.py | 8 ++- tests/custom_scalars/test_json.py | 3 +- tests/custom_scalars/test_money.py | 37 +++++++++--- tests/nested_input/schema.py | 5 +- tests/starwars/fixtures.py | 6 +- tests/starwars/schema.py | 58 ++++++++++++++----- tests/starwars/test_dsl.py | 47 ++++++++++++--- tests/test_appsync_auth.py | 8 ++- tests/test_appsync_http.py | 4 +- tests/test_appsync_websockets.py | 9 ++- tests/test_async_client_validation.py | 3 +- tests/test_cli.py | 27 ++++++--- tests/test_client.py | 12 ++-- tests/test_graphqlws_subscription.py | 6 +- tests/test_phoenix_channel_exceptions.py | 4 +- tests/test_websocket_exceptions.py | 5 +- tests/test_websocket_query.py | 2 +- tests/test_websocket_subscription.py | 2 +- 37 files changed, 263 insertions(+), 114 deletions(-) diff --git a/docs/code_examples/aiohttp_async.py b/docs/code_examples/aiohttp_async.py index b8cc05a1..0c1d10dd 100644 --- a/docs/code_examples/aiohttp_async.py +++ b/docs/code_examples/aiohttp_async.py @@ -11,7 +11,8 @@ async def main(): # Using `async with` on the client will start a connection on the transport # and provide a `session` variable to execute queries on this connection async with Client( - transport=transport, fetch_schema_from_transport=True, + transport=transport, + fetch_schema_from_transport=True, ) as session: # Execute single query diff --git a/docs/code_examples/appsync/mutation_api_key.py b/docs/code_examples/appsync/mutation_api_key.py index 052da850..634e2439 100644 --- a/docs/code_examples/appsync/mutation_api_key.py +++ b/docs/code_examples/appsync/mutation_api_key.py @@ -31,7 +31,8 @@ async def main(): transport = 
AIOHTTPTransport(url=url, auth=auth) async with Client( - transport=transport, fetch_schema_from_transport=False, + transport=transport, + fetch_schema_from_transport=False, ) as session: query = gql( diff --git a/docs/code_examples/appsync/mutation_iam.py b/docs/code_examples/appsync/mutation_iam.py index 327e0d91..3cc04a5a 100644 --- a/docs/code_examples/appsync/mutation_iam.py +++ b/docs/code_examples/appsync/mutation_iam.py @@ -30,7 +30,8 @@ async def main(): transport = AIOHTTPTransport(url=url, auth=auth) async with Client( - transport=transport, fetch_schema_from_transport=False, + transport=transport, + fetch_schema_from_transport=False, ) as session: query = gql( diff --git a/docs/code_examples/requests_sync.py b/docs/code_examples/requests_sync.py index 53b1e2a7..2184f286 100644 --- a/docs/code_examples/requests_sync.py +++ b/docs/code_examples/requests_sync.py @@ -2,7 +2,9 @@ from gql.transport.requests import RequestsHTTPTransport transport = RequestsHTTPTransport( - url="https://countries.trevorblades.com/", verify=True, retries=3, + url="https://countries.trevorblades.com/", + verify=True, + retries=3, ) client = Client(transport=transport, fetch_schema_from_transport=True) diff --git a/docs/code_examples/requests_sync_dsl.py b/docs/code_examples/requests_sync_dsl.py index 925b9aa2..e16ded92 100644 --- a/docs/code_examples/requests_sync_dsl.py +++ b/docs/code_examples/requests_sync_dsl.py @@ -3,7 +3,9 @@ from gql.transport.requests import RequestsHTTPTransport transport = RequestsHTTPTransport( - url="https://countries.trevorblades.com/", verify=True, retries=3, + url="https://countries.trevorblades.com/", + verify=True, + retries=3, ) client = Client(transport=transport, fetch_schema_from_transport=True) diff --git a/docs/code_examples/websockets_async.py 
b/docs/code_examples/websockets_async.py index e5e83021..e645a7ef 100644 --- a/docs/code_examples/websockets_async.py +++ b/docs/code_examples/websockets_async.py @@ -14,7 +14,8 @@ async def main(): # Using `async with` on the client will start a connection on the transport # and provide a `session` variable to execute queries on this connection async with Client( - transport=transport, fetch_schema_from_transport=True, + transport=transport, + fetch_schema_from_transport=True, ) as session: # Execute single query diff --git a/gql/cli.py b/gql/cli.py index 1e248081..78d82551 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -132,11 +132,15 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: appsync_auth_group = appsync_group.add_mutually_exclusive_group() appsync_auth_group.add_argument( - "--api-key", help="Provide an API key for authentication", dest="api_key", + "--api-key", + help="Provide an API key for authentication", + dest="api_key", ) appsync_auth_group.add_argument( - "--jwt", help="Provide an JSON Web token for authentication", dest="jwt", + "--jwt", + help="Provide an JSON Web token for authentication", + dest="jwt", ) return parser diff --git a/gql/dsl.py b/gql/dsl.py index 6a2e0718..634c10cb 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -311,7 +311,9 @@ class DSLSelector(ABC): selection_set: SelectionSetNode def __init__( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", + self, + *fields: "DSLSelectable", + **fields_with_alias: "DSLSelectableWithAlias", ): """:meta private:""" self.selection_set = SelectionSetNode(selections=()) @@ -326,7 +328,9 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: ) # pragma: no cover def select( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", + self, + *fields: "DSLSelectable", + **fields_with_alias: "DSLSelectableWithAlias", ): r"""Select the fields which should be added. 
@@ -387,7 +391,9 @@ def executable_ast(self): ) # pragma: no cover def __init__( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", + self, + *fields: "DSLSelectable", + **fields_with_alias: "DSLSelectableWithAlias", ): r"""Given arguments of type :class:`DSLSelectable` containing GraphQL requests, generate an operation which can be converted to a Document @@ -552,7 +558,9 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: """ return tuple( VariableDefinitionNode( - type=var.type, variable=var.ast_variable, default_value=None, + type=var.type, + variable=var.ast_variable, + default_value=None, ) for var in self.variables.values() if var.type is not None # only variables used @@ -889,7 +897,9 @@ class DSLInlineFragment(DSLSelectable, DSLFragmentSelector): ast_field: InlineFragmentNode def __init__( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", + self, + *fields: "DSLSelectable", + **fields_with_alias: "DSLSelectableWithAlias", ): r"""Initialize the DSLInlineFragment. @@ -944,7 +954,8 @@ class DSLFragment(DSLSelectable, DSLFragmentSelector, DSLExecutable): name: str def __init__( - self, name: str, + self, + name: str, ): r"""Initialize the DSLFragment. 
diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 7df60417..de9ab953 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -222,7 +222,8 @@ async def execute( # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( - variables=variable_values, file_classes=self.file_classes, + variables=variable_values, + file_classes=self.file_classes, ) # Save the nulled variable values in the payload @@ -275,7 +276,8 @@ async def execute( # Add headers for AppSync if requested if isinstance(self.auth, AppSyncAuthentication): post_args["headers"] = self.auth.get_headers( - json.dumps(payload), {"content-type": "application/json"}, + json.dumps(payload), + {"content-type": "application/json"}, ) if self.session is None: diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index c7e05a09..66091747 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -172,7 +172,12 @@ async def _send_query( "authorization": self.auth.get_headers(serialized_data) } - await self._send(json.dumps(message, separators=(",", ":"),)) + await self._send( + json.dumps( + message, + separators=(",", ":"), + ) + ) return query_id diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 7de24015..18f6df79 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -7,16 +7,14 @@ class AsyncTransport: @abc.abstractmethod async def connect(self): - """Coroutine used to create a connection to the specified address - """ + """Coroutine used to create a connection to the specified address""" raise NotImplementedError( "Any AsyncTransport subclass must implement connect method" ) # pragma: no cover @abc.abstractmethod async def close(self): - """Coroutine used to Close an established connection - """ + """Coroutine used to Close an 
established connection""" raise NotImplementedError( "Any AsyncTransport subclass must implement close method" ) # pragma: no cover @@ -28,8 +26,8 @@ async def execute( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> ExecutionResult: - """Execute the provided document AST for either a remote or local GraphQL Schema. - """ + """Execute the provided document AST for either a remote or local GraphQL + Schema.""" raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" ) # pragma: no cover diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 18cd2982..87395b19 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -10,7 +10,8 @@ class LocalSchemaTransport(AsyncTransport): """A transport for executing GraphQL queries against a local schema.""" def __init__( - self, schema: GraphQLSchema, + self, + schema: GraphQLSchema, ): """Initialize the transport with the given local schema. @@ -19,20 +20,20 @@ def __init__( self.schema = schema async def connect(self): - """No connection needed on local transport - """ + """No connection needed on local transport""" pass async def close(self): - """No close needed on local transport - """ + """No close needed on local transport""" pass async def execute( - self, document: DocumentNode, *args, **kwargs, + self, + document: DocumentNode, + *args, + **kwargs, ) -> ExecutionResult: - """Execute the provided document AST for on a local GraphQL Schema. 
- """ + """Execute the provided document AST for on a local GraphQL Schema.""" result_or_awaitable = execute(self.schema, document, *args, **kwargs) @@ -48,7 +49,10 @@ async def execute( return execution_result async def subscribe( - self, document: DocumentNode, *args, **kwargs, + self, + document: DocumentNode, + *args, + **kwargs, ) -> AsyncGenerator[ExecutionResult, None]: """Send a subscription and receive the results using an async generator diff --git a/gql/transport/requests.py b/gql/transport/requests.py index a34e7542..690615b4 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -155,7 +155,8 @@ def execute( # type: ignore # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( - variables=variable_values, file_classes=self.file_classes, + variables=variable_values, + file_classes=self.file_classes, ) # Save the nulled variable values in the payload diff --git a/gql/transport/transport.py b/gql/transport/transport.py index 56d882f4..a21502f0 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -18,8 +18,7 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: ) # pragma: no cover def connect(self): - """Establish a session with the transport. 
- """ + """Establish a session with the transport.""" pass # pragma: no cover def close(self): diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 1650624e..9e111551 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -152,8 +152,7 @@ async def _initialize(self): await self._send_init_message_and_wait_ack() async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol - """ + """Send a ping message for the graphql-ws protocol""" ping_message = {"type": "ping"} @@ -163,8 +162,7 @@ async def send_ping(self, payload: Optional[Any] = None) -> None: await self._send(json.dumps(ping_message)) async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol - """ + """Send a pong message for the graphql-ws protocol""" pong_message = {"type": "pong"} diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index 7a83c47f..45c96d3e 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -196,8 +196,7 @@ async def _after_initialize(self): pass # pragma: no cover async def _close_hook(self): - """Hook to add custom code for subclasses for the connection close - """ + """Hook to add custom code for subclasses for the connection close""" pass # pragma: no cover async def _connection_terminate(self): diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index d053c1c0..d35a2a75 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -48,13 +48,15 @@ def get_introspection_query_ast( if directive_is_repeatable: directives.select(ds.__Directive.isRepeatable) directives.select( - ds.__Directive.locations, ds.__Directive.args.select(fragment_InputValue), + ds.__Directive.locations, + ds.__Directive.args.select(fragment_InputValue), ) schema.select(directives) 
fragment_FullType.select( - ds.__Type.kind, ds.__Type.name, + ds.__Type.kind, + ds.__Type.name, ) if descriptions: fragment_FullType.select(ds.__Type.description) @@ -81,7 +83,8 @@ def get_introspection_query_ast( enum_values.select(ds.__EnumValue.description) enum_values.select( - ds.__EnumValue.isDeprecated, ds.__EnumValue.deprecationReason, + ds.__EnumValue.isDeprecated, + ds.__EnumValue.deprecationReason, ) fragment_FullType.select( @@ -98,11 +101,13 @@ def get_introspection_query_ast( fragment_InputValue.select(ds.__InputValue.description) fragment_InputValue.select( - ds.__InputValue.type.select(fragment_TypeRef), ds.__InputValue.defaultValue, + ds.__InputValue.type.select(fragment_TypeRef), + ds.__InputValue.defaultValue, ) fragment_TypeRef.select( - ds.__Type.kind, ds.__Type.name, + ds.__Type.kind, + ds.__Type.name, ) if type_recursion_level >= 1: diff --git a/setup.py b/setup.py index 07bab00e..1a46c4db 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ ] dev_requires = [ - "black==19.10b0", + "black==22.3.0", "check-manifest>=0.42,<1", "flake8==3.8.1", "isort==4.3.21", diff --git a/tests/conftest.py b/tests/conftest.py index 8d0b95ba..518d0d3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -469,7 +469,8 @@ async def client_and_graphqlws_server(graphqlws_server): path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" sample_transport = WebsocketsTransport( - url=url, subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + url=url, + subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], ) async with Client(transport=sample_transport) as session: diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 2c7b887c..2f15a8ca 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -71,7 +71,10 @@ def resolve_list(_root, _info): queryType = GraphQLObjectType( name="RootQueryType", fields={ - "all": 
GraphQLField(GraphQLList(ColorType), resolve=resolve_all,), + "all": GraphQLField( + GraphQLList(ColorType), + resolve=resolve_all, + ), "opposite": GraphQLField( ColorType, args={"color": GraphQLArgument(ColorType)}, @@ -90,7 +93,8 @@ def resolve_list(_root, _info): resolve=resolve_list_of_list, ), "list": GraphQLField( - GraphQLNonNull(GraphQLList(ColorType)), resolve=resolve_list, + GraphQLNonNull(GraphQLList(ColorType)), + resolve=resolve_list, ), }, ) diff --git a/tests/custom_scalars/test_json.py b/tests/custom_scalars/test_json.py index 9659d0a5..4c4da588 100644 --- a/tests/custom_scalars/test_json.py +++ b/tests/custom_scalars/test_json.py @@ -68,7 +68,8 @@ def resolve_players(root, _info): queryType = GraphQLObjectType( - name="Query", fields={"players": GraphQLField(JsonScalar, resolve=resolve_players)}, + name="Query", + fields={"players": GraphQLField(JsonScalar, resolve=resolve_players)}, ) diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 23dc281d..e67a0bcd 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -155,7 +155,8 @@ def resolve_to_euros(_root, _info, money): GraphQLList(MoneyScalar), resolve=resolve_friends_balance ), "countries_balance": GraphQLField( - GraphQLNonNull(countriesBalance), resolve=resolve_countries_balance, + GraphQLNonNull(countriesBalance), + resolve=resolve_countries_balance, ), }, ) @@ -184,7 +185,10 @@ async def subscribe_spend_all(_root, _info, money): }, ) -schema = GraphQLSchema(query=queryType, subscription=subscriptionType,) +schema = GraphQLSchema( + query=queryType, + subscription=subscriptionType, +) def test_custom_scalar_in_output(): @@ -470,7 +474,9 @@ async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server transport = await make_money_transport(aiohttp_server) - async with Client(transport=transport,) as session: + async with Client( + transport=transport, + ) as session: query = gql("{balance}") @@ -486,7 
+492,9 @@ async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_s transport = await make_money_transport(aiohttp_server) - async with Client(transport=transport,) as session: + async with Client( + transport=transport, + ) as session: query = gql('{toEuros(money: {amount: 10, currency: "DM"})}') @@ -508,7 +516,9 @@ async def test_custom_scalar_in_input_variable_values_with_transport( transport = await make_money_transport(aiohttp_server) - async with Client(transport=transport,) as session: + async with Client( + transport=transport, + ) as session: query = gql("query myquery($money: Money) {toEuros(money: $money)}") @@ -530,7 +540,9 @@ async def test_custom_scalar_in_input_variable_values_split_with_transport( transport = await make_money_transport(aiohttp_server) - async with Client(transport=transport,) as session: + async with Client( + transport=transport, + ) as session: query = gql( """ @@ -552,7 +564,10 @@ async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): transport = await make_money_transport(aiohttp_server) - async with Client(schema=schema, transport=transport,) as session: + async with Client( + schema=schema, + transport=transport, + ) as session: query = gql("query myquery($money: Money) {toEuros(money: $money)}") @@ -571,7 +586,9 @@ async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_s transport = await make_money_transport(aiohttp_server) - async with Client(transport=transport,) as session: + async with Client( + transport=transport, + ) as session: query = gql("query myquery($money: Money) {toEuros(money: $money)}") @@ -670,7 +687,9 @@ def test_update_schema_scalars_invalid_scalar_argument(): def test_update_schema_scalars_scalar_not_found_in_schema(): - NotFoundScalar = GraphQLScalarType(name="abcd",) + NotFoundScalar = GraphQLScalarType( + name="abcd", + ) with pytest.raises(KeyError) as exc_info: update_schema_scalars(schema, [MoneyScalar, NotFoundScalar]) diff 
--git a/tests/nested_input/schema.py b/tests/nested_input/schema.py index bd5a0507..ccdebb4a 100644 --- a/tests/nested_input/schema.py +++ b/tests/nested_input/schema.py @@ -30,4 +30,7 @@ }, ) -NestedInputSchema = GraphQLSchema(query=queryType, types=[nestedInput],) +NestedInputSchema = GraphQLSchema( + query=queryType, + types=[nestedInput], +) diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 36232147..efbb1b0e 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -66,7 +66,11 @@ def __init__(self, id, name, friends, appearsIn, primaryFunction): ) tarkin = Human( - id="1004", name="Wilhuff Tarkin", friends=["1001"], appearsIn=[4], homePlanet=None, + id="1004", + name="Wilhuff Tarkin", + friends=["1001"], + appearsIn=[4], + homePlanet=None, ) humanData = { diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 50e2420f..c3db0a3d 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -32,9 +32,18 @@ episode_enum = GraphQLEnumType( "Episode", { - "NEWHOPE": GraphQLEnumValue(4, description="Released in 1977.",), - "EMPIRE": GraphQLEnumValue(5, description="Released in 1980.",), - "JEDI": GraphQLEnumValue(6, description="Released in 1983.",), + "NEWHOPE": GraphQLEnumValue( + 4, + description="Released in 1977.", + ), + "EMPIRE": GraphQLEnumValue( + 5, + description="Released in 1980.", + ), + "JEDI": GraphQLEnumValue( + 6, + description="Released in 1983.", + ), }, description="One of the films in the Star Wars Trilogy", ) @@ -70,16 +79,21 @@ "Human", lambda: { "id": GraphQLField( - GraphQLNonNull(GraphQLString), description="The id of the human.", + GraphQLNonNull(GraphQLString), + description="The id of the human.", + ), + "name": GraphQLField( + GraphQLString, + description="The name of the human.", ), - "name": GraphQLField(GraphQLString, description="The name of the human.",), "friends": GraphQLField( GraphQLList(character_interface), description="The friends of the human, or an empty 
list if they have none.", resolve=lambda human, _info: get_friends(human), ), "appearsIn": GraphQLField( - GraphQLList(episode_enum), description="Which movies they appear in.", + GraphQLList(episode_enum), + description="Which movies they appear in.", ), "homePlanet": GraphQLField( GraphQLString, @@ -94,19 +108,25 @@ "Droid", lambda: { "id": GraphQLField( - GraphQLNonNull(GraphQLString), description="The id of the droid.", + GraphQLNonNull(GraphQLString), + description="The id of the droid.", + ), + "name": GraphQLField( + GraphQLString, + description="The name of the droid.", ), - "name": GraphQLField(GraphQLString, description="The name of the droid.",), "friends": GraphQLField( GraphQLList(character_interface), description="The friends of the droid, or an empty list if they have none.", resolve=lambda droid, _info: get_friends(droid), ), "appearsIn": GraphQLField( - GraphQLList(episode_enum), description="Which movies they appear in.", + GraphQLList(episode_enum), + description="Which movies they appear in.", ), "primaryFunction": GraphQLField( - GraphQLString, description="The primary function of the droid.", + GraphQLString, + description="The primary function of the droid.", ), }, interfaces=[character_interface], @@ -157,7 +177,8 @@ human_type, args={ "id": GraphQLArgument( - description="id of the human", type_=GraphQLNonNull(GraphQLString), + description="id of the human", + type_=GraphQLNonNull(GraphQLString), ) }, resolve=lambda _souce, _info, id: get_human(id), @@ -166,7 +187,8 @@ droid_type, args={ "id": GraphQLArgument( - description="id of the droid", type_=GraphQLNonNull(GraphQLString), + description="id of the droid", + type_=GraphQLNonNull(GraphQLString), ) }, resolve=lambda _source, _info, id: get_droid(id), @@ -175,7 +197,8 @@ GraphQLList(character_interface), args={ "ids": GraphQLArgument( - GraphQLList(GraphQLString), description="list of character ids", + GraphQLList(GraphQLString), + description="list of character ids", ) }, resolve=lambda 
_source, _info, ids=None: get_characters(ids), @@ -190,10 +213,12 @@ review_type, args={ "episode": GraphQLArgument( - episode_enum, description="Episode to create review", + episode_enum, + description="Episode to create review", ), "review": GraphQLArgument( - description="set alive status", type_=review_input_type, + description="set alive status", + type_=review_input_type, ), }, resolve=lambda _source, _info, episode=None, review=None: create_review( @@ -222,7 +247,8 @@ async def resolve_review(review, _info, **_args): review_type, args={ "episode": GraphQLArgument( - episode_enum, description="Episode to review", + episode_enum, + description="Episode to review", ) }, subscribe=subscribe_reviews, diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 6adc84a9..50f5449c 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -217,7 +217,9 @@ def test_hero_name_and_friends_query(ds): query_dsl = ds.Query.hero.select( ds.Character.id, ds.Character.name, - ds.Character.friends.select(ds.Character.name,), + ds.Character.friends.select( + ds.Character.name, + ), ) assert query == str(query_dsl) @@ -225,7 +227,11 @@ def test_hero_name_and_friends_query(ds): query_dsl = ( ds.Query.hero.select(ds.Character.id) .select(ds.Character.name) - .select(ds.Character.friends.select(ds.Character.name,),) + .select( + ds.Character.friends.select( + ds.Character.name, + ), + ) ) assert query == str(query_dsl) @@ -272,7 +278,9 @@ def test_fetch_luke_query(ds): name } """.strip() - query_dsl = ds.Query.human(id="1000").select(ds.Human.name,) + query_dsl = ds.Query.human(id="1000").select( + ds.Human.name, + ) assert query == str(query_dsl) @@ -283,11 +291,23 @@ def test_fetch_luke_aliased(ds): name } """.strip() - query_dsl = ds.Query.human.args(id=1000).alias("luke").select(ds.Character.name,) + query_dsl = ( + ds.Query.human.args(id=1000) + .alias("luke") + .select( + ds.Character.name, + ) + ) assert query == str(query_dsl) # Should also 
work with select before alias - query_dsl = ds.Query.human.args(id=1000).select(ds.Character.name,).alias("luke") + query_dsl = ( + ds.Query.human.args(id=1000) + .select( + ds.Character.name, + ) + .alias("luke") + ) assert query == str(query_dsl) @@ -308,7 +328,9 @@ def test_fetch_name_aliased_as_kwargs(ds: DSLSchema): my_name: name } """.strip() - query_dsl = ds.Query.human.args(id=1000).select(my_name=ds.Character.name,) + query_dsl = ds.Query.human.args(id=1000).select( + my_name=ds.Character.name, + ) assert query == str(query_dsl) @@ -322,7 +344,9 @@ def test_hero_name_query_result(ds, client): def test_arg_serializer_list(ds, client): query = dsl_gql( DSLQuery( - ds.Query.characters.args(ids=[1000, 1001, 1003]).select(ds.Character.name,) + ds.Query.characters.args(ids=[1000, 1001, 1003]).select( + ds.Character.name, + ) ) ) result = client.execute(query) @@ -433,7 +457,11 @@ def test_root_fields_aliased(ds, client): def test_operation_name(ds): - query = dsl_gql(GetHeroName=DSLQuery(ds.Query.hero.select(ds.Character.name),)) + query = dsl_gql( + GetHeroName=DSLQuery( + ds.Query.hero.select(ds.Character.name), + ) + ) assert ( print_ast(query) @@ -574,7 +602,8 @@ def test_inline_fragment_in_dsl_gql(ds): query = DSLQuery() with pytest.raises( - GraphQLError, match=r"Invalid field for : ", + GraphQLError, + match=r"Invalid field for : ", ): query.select(inline_fragment) diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 546e0e6f..cb279ae5 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -27,7 +27,8 @@ def test_appsync_init_with_no_credentials(caplog, fake_session_factory): with pytest.raises(botocore.exceptions.NoCredentialsError): sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, session=fake_session_factory(credentials=None), + url=mock_transport_url, + session=fake_session_factory(credentials=None), ) assert sample_transport.auth is None @@ -75,7 +76,8 @@ def 
test_appsync_init_with_iam_auth_without_creds(fake_session_factory): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncIAMAuthentication( - host=mock_transport_host, session=fake_session_factory(credentials=None), + host=mock_transport_host, + session=fake_session_factory(credentials=None), ) with pytest.raises(botocore.exceptions.NoCredentialsError): AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) @@ -105,7 +107,7 @@ def test_appsync_init_with_iam_auth_and_no_region( - you have a default region set in ~/.aws/config - you have the AWS_DEFAULT_REGION environment variable set - """ + """ from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from botocore.exceptions import NoRegionError import logging diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index 1f787a68..ca3a3fcb 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -44,7 +44,9 @@ async def handler(request): host = str(urlparse(url).netloc) auth = AppSyncIAMAuthentication( - host=host, credentials=fake_credentials_factory(), region_name="us-east-1", + host=host, + credentials=fake_credentials_factory(), + region_name="us-east-1", ) sample_transport = AIOHTTPTransport(url=url, auth=auth) diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index f510d4a7..62816cc9 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -450,7 +450,8 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): url = f"ws://{server.hostname}:{server.port}{path}" dummy_credentials = Credentials( - access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + access_key=DUMMY_ACCESS_KEY_ID, + secret_key=DUMMY_SECRET_ACCESS_KEY, ) auth = AppSyncIAMAuthentication( @@ -475,7 +476,8 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): url = f"ws://{server.hostname}:{server.port}{path}" dummy_credentials = Credentials( - 
access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + access_key=DUMMY_ACCESS_KEY_ID, + secret_key=DUMMY_SECRET_ACCESS_KEY, ) auth = AppSyncIAMAuthentication( @@ -518,7 +520,8 @@ async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): from botocore.credentials import Credentials dummy_credentials = Credentials( - access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + access_key=DUMMY_ACCESS_KEY_ID, + secret_key=DUMMY_SECRET_ACCESS_KEY, ) auth = AppSyncIAMAuthentication( diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index b588e6ba..d39019e8 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -266,7 +266,8 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu sample_transport = WebsocketsTransport(url=url) async with Client( - transport=sample_transport, fetch_schema_from_transport=True, + transport=sample_transport, + fetch_schema_from_transport=True, ) as session: query = gql( diff --git a/tests/test_cli.py b/tests/test_cli.py index 8df47a63..ec268422 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -155,7 +155,8 @@ def test_cli_parse_variable_value_invalid_param(parser, param): @pytest.mark.aiohttp @pytest.mark.parametrize( - "url", ["http://your_server.com", "https://your_server.com"], + "url", + ["http://your_server.com", "https://your_server.com"], ) def test_cli_get_transport_aiohttp(parser, url): @@ -170,7 +171,8 @@ def test_cli_get_transport_aiohttp(parser, url): @pytest.mark.websockets @pytest.mark.parametrize( - "url", ["ws://your_server.com", "wss://your_server.com"], + "url", + ["ws://your_server.com", "wss://your_server.com"], ) def test_cli_get_transport_websockets(parser, url): @@ -185,7 +187,8 @@ def 
test_cli_get_transport_websockets(parser, url): @pytest.mark.websockets @pytest.mark.parametrize( - "url", ["ws://your_server.com", "wss://your_server.com"], + "url", + ["ws://your_server.com", "wss://your_server.com"], ) def test_cli_get_transport_phoenix(parser, url): @@ -224,7 +227,8 @@ def test_cli_get_transport_appsync_websockets_iam(parser, url): @pytest.mark.websockets @pytest.mark.botocore @pytest.mark.parametrize( - "url", ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], + "url", + ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], ) async def test_cli_main_appsync_websockets_iam(parser, url): @@ -240,7 +244,8 @@ async def test_cli_main_appsync_websockets_iam(parser, url): @pytest.mark.websockets @pytest.mark.parametrize( - "url", ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], + "url", + ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], ) def test_cli_get_transport_appsync_websockets_api_key(parser, url): @@ -260,7 +265,8 @@ def test_cli_get_transport_appsync_websockets_api_key(parser, url): @pytest.mark.websockets @pytest.mark.parametrize( - "url", ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], + "url", + ["wss://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], ) def test_cli_get_transport_appsync_websockets_jwt(parser, url): @@ -281,7 +287,8 @@ def test_cli_get_transport_appsync_websockets_jwt(parser, url): @pytest.mark.aiohttp @pytest.mark.botocore @pytest.mark.parametrize( - "url", ["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], + "url", + ["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], ) def test_cli_get_transport_appsync_http_iam(parser, url): @@ -296,7 +303,8 @@ def test_cli_get_transport_appsync_http_iam(parser, url): @pytest.mark.aiohttp @pytest.mark.parametrize( - "url", 
["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], + "url", + ["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], ) def test_cli_get_transport_appsync_http_api_key(parser, url): @@ -316,7 +324,8 @@ def test_cli_get_transport_appsync_http_api_key(parser, url): @pytest.mark.aiohttp @pytest.mark.parametrize( - "url", ["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], + "url", + ["https://XXXXXX.appsync-api.eu-west-3.amazonaws.com/graphql"], ) def test_cli_get_transport_appsync_http_jwt(parser, url): diff --git a/tests/test_client.py b/tests/test_client.py index fecdf43d..8b6575d7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -52,7 +52,8 @@ def test_retries_on_transport(execute_mock): "Should be HTTPConnection", "Fake connection error" ) transport = RequestsHTTPTransport( - url="http://127.0.0.1:8000/graphql", retries=expected_retries, + url="http://127.0.0.1:8000/graphql", + retries=expected_retries, ) client = Client(transport=transport) @@ -136,7 +137,8 @@ def test_http_transport_verify_error(http_transport_query): with Client( transport=RequestsHTTPTransport( - url="https://countries.trevorblades.com/", verify=False, + url="https://countries.trevorblades.com/", + verify=False, ) ) as client: with pytest.warns(Warning) as record: @@ -152,7 +154,8 @@ def test_http_transport_specify_method_valid(http_transport_query): with Client( transport=RequestsHTTPTransport( - url="https://countries.trevorblades.com/", method="POST", + url="https://countries.trevorblades.com/", + method="POST", ) ) as 
client: result = client.execute(http_transport_query) @@ -166,7 +169,8 @@ def test_http_transport_specify_method_invalid(http_transport_query): with Client( transport=RequestsHTTPTransport( - url="https://countries.trevorblades.com/", method="GET", + url="https://countries.trevorblades.com/", + method="GET", ) ) as client: with pytest.raises(Exception) as exc_info: diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 7826aca1..ade21911 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -533,7 +533,9 @@ async def test_graphqlws_subscription_with_ping_interval_ok( path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" transport = WebsocketsTransport( - url=url, ping_interval=(5 * COUNTING_DELAY), pong_timeout=(4 * COUNTING_DELAY), + url=url, + ping_interval=(5 * COUNTING_DELAY), + pong_timeout=(4 * COUNTING_DELAY), ) client = Client(transport=transport) @@ -709,7 +711,7 @@ def test_graphqlws_subscription_sync(graphqlws_server, subscription_str): def test_graphqlws_subscription_sync_graceful_shutdown( graphqlws_server, subscription_str ): - """ Note: this test will simulate a control-C happening while a sync subscription + """Note: this test will simulate a control-C happening while a sync subscription is in progress. To do that we will throw a KeyboardInterrupt exception inside the subscription async generator. 
diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 1711d25a..e2bf0091 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -434,7 +434,9 @@ async def test_phoenix_channel_subscription_protocol_error( @pytest.mark.asyncio @pytest.mark.parametrize( - "server", [query_server(server_error_server_answer)], indirect=True, + "server", + [query_server(server_error_server_answer)], + indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_phoenix_channel_server_error(event_loop, server, query_str): diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 8cccf33b..72db8a87 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -172,7 +172,10 @@ async def test_websocket_sending_invalid_payload( # Monkey patching the _send_query method to send an invalid payload async def monkey_patch_send_query( - self, document, variable_values=None, operation_name=None, + self, + document, + variable_values=None, + operation_name=None, ) -> int: query_id = self.next_query_id self.next_query_id += 1 diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 2382f157..f39409f5 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -492,7 +492,7 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server): url = f"ws://{server.hostname}:{server.port}/graphql" # Increase max payload size to avoid websockets.exceptions.PayloadTooBig exceptions - transport = WebsocketsTransport(url=url, connect_args={"max_size": 2 ** 21}) + transport = WebsocketsTransport(url=url, connect_args={"max_size": 2**21}) query = gql(query1_str) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 14ffe0a2..f1d72dc8 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ 
-498,7 +498,7 @@ def test_websocket_subscription_sync(server, subscription_str): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str): - """ Note: this test will simulate a control-C happening while a sync subscription + """Note: this test will simulate a control-C happening while a sync subscription is in progress. To do that we will throw a KeyboardInterrupt exception inside the subscription async generator. From c202ccaaa1d71b51f72b1939e71fc42a20851675 Mon Sep 17 00:00:00 2001 From: Nicholas Bollweg Date: Sat, 9 Apr 2022 14:04:39 -0500 Subject: [PATCH 070/239] replace use of `scripts` with `entry_points` (#311) Increase test coverage to the script itself. Fix windows gql-cli signal issue found by the new tests. --- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 10 ++++----- MANIFEST.in | 2 -- Makefile | 2 +- gql/cli.py | 45 +++++++++++++++++++++++++++++++++++++ scripts/gql-cli | 34 ---------------------------- setup.py | 7 +++--- tests/test_aiohttp.py | 38 +++++++++++++++++++++++++++++++ tests/test_cli.py | 10 +++++++++ 9 files changed, 104 insertions(+), 46 deletions(-) delete mode 100755 scripts/gql-cli diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dffc5c4b..6ed6d6ea 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,7 +14,7 @@ jobs: python-version: 3.8 - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip wheel pip install tox - name: Run lint and static type checks run: tox diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 870493aa..a0631101 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,7 +30,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m 
pip install --upgrade pip + python -m pip install --upgrade pip wheel pip install tox tox-gh-actions - name: Test with tox run: tox @@ -52,7 +52,7 @@ jobs: python-version: 3.8 - name: Install dependencies with only ${{ matrix.dependency }} extra dependency run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip wheel pip install .[${{ matrix.dependency }},test_no_transport] - name: Test with --${{ matrix.dependency }}-only run: pytest tests --${{ matrix.dependency }}-only @@ -68,9 +68,9 @@ jobs: python-version: 3.8 - name: Install test dependencies run: | - python -m pip install --upgrade pip - pip install .[test] + python -m pip install --upgrade pip wheel + pip install -e.[test] - name: Test with coverage - run: pytest --cov=gql --cov-report=xml tests + run: pytest --cov=gql --cov-report=xml --cov-report=term-missing tests - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 diff --git a/MANIFEST.in b/MANIFEST.in index 73d59a18..c0f653ab 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -11,8 +11,6 @@ include Makefile include tox.ini -include scripts/gql-cli - include gql/py.typed recursive-include tests *.py *.graphql *.cnf *.yaml *.pem diff --git a/Makefile b/Makefile index 6baff50f..2275092c 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: clean tests docs -SRC_PYTHON := gql tests scripts/gql-cli docs/code_examples +SRC_PYTHON := gql tests docs/code_examples dev-setup: python pip install -e ".[test]" diff --git a/gql/cli.py b/gql/cli.py index 78d82551..27a562b2 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -1,5 +1,7 @@ +import asyncio import json import logging +import signal as signal_module import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from typing import Any, Dict, Optional @@ -407,3 +409,46 @@ async def main(args: Namespace) -> int: exit_code = 1 return exit_code + + +def gql_cli() -> None: + """Synchronously invoke ``main`` with the parsed command line arguments. 
+ + Formerly ``scripts/gql-cli``, now registered as an ``entry_point`` + """ + # Get arguments from command line + parser = get_parser(with_examples=True) + args = parser.parse_args() + + try: + # Create a new asyncio event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a gql-cli task with the supplied arguments + main_task = asyncio.ensure_future(main(args), loop=loop) + + # Add signal handlers to close gql-cli cleanly on Control-C + for signal_name in ["SIGINT", "SIGTERM", "CTRL_C_EVENT", "CTRL_BREAK_EVENT"]: + signal = getattr(signal_module, signal_name, None) + + if signal is None: + continue + + try: + loop.add_signal_handler(signal, main_task.cancel) + except NotImplementedError: # pragma: no cover + # not all signals supported on all platforms + pass + + # Run the asyncio loop to execute the task + exit_code = 0 + try: + exit_code = loop.run_until_complete(main_task) + finally: + loop.close() + + # Return with the correct exit code + sys.exit(exit_code) + except KeyboardInterrupt: # pragma: no cover + pass diff --git a/scripts/gql-cli b/scripts/gql-cli deleted file mode 100755 index b2a079a3..00000000 --- a/scripts/gql-cli +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python3 -import asyncio -import sys -from signal import SIGINT, SIGTERM - -from gql.cli import get_parser, main - -# Get arguments from command line -parser = get_parser(with_examples=True) -args = parser.parse_args() - -try: - # Create a new asyncio event loop - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Create a gql-cli task with the supplied arguments - main_task = asyncio.ensure_future(main(args), loop=loop) - - # Add signal handlers to close gql-cli cleanly on Control-C - for signal in [SIGINT, SIGTERM]: - loop.add_signal_handler(signal, main_task.cancel) - - # Run the asyncio loop to execute the task - exit_code = 0 - try: - exit_code = loop.run_until_complete(main_task) - finally: - loop.close() - - # Return with the correct 
exit code - sys.exit(exit_code) -except KeyboardInterrupt: - pass diff --git a/setup.py b/setup.py index 1a46c4db..a8b58737 100644 --- a/setup.py +++ b/setup.py @@ -7,14 +7,15 @@ "yarl>=1.6,<2.0", ] -scripts = [ - "scripts/gql-cli", +console_scripts = [ + "gql-cli=gql.cli:gql_cli", ] tests_requires = [ "parse==1.15.0", "pytest==6.2.5", "pytest-asyncio==0.16.0", + "pytest-console-scripts==1.3.1", "pytest-cov==3.0.0", "mock==4.0.2", "vcrpy==4.0.2", @@ -106,5 +107,5 @@ include_package_data=True, zip_safe=False, platforms="any", - scripts=scripts, + entry_points={"console_scripts": console_scripts}, ) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index ab02e8f5..2535ddb3 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1016,6 +1016,44 @@ async def handler(request): assert received_answer == expected_answer +@pytest.mark.asyncio +@pytest.mark.script_launch_mode("subprocess") +async def test_aiohttp_using_cli_ep( + event_loop, aiohttp_server, monkeypatch, script_runner, run_sync_test +): + from aiohttp import web + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + ret = script_runner.run( + "gql-cli", url, "--verbose", stdin=io.StringIO(query1_str) + ) + + assert ret.success + + # Check that the result has been printed on stdout + captured_out = str(ret.stdout).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.asyncio async def test_aiohttp_using_cli_invalid_param( event_loop, aiohttp_server, monkeypatch, capsys diff --git 
a/tests/test_cli.py b/tests/test_cli.py index ec268422..9066544b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ import pytest +from gql import __version__ from gql.cli import ( get_execute_args, get_parser, @@ -347,3 +348,12 @@ def test_cli_get_transport_no_protocol(parser): with pytest.raises(ValueError): get_transport(args) + + +def test_cli_ep_version(script_runner): + ret = script_runner.run("gql-cli", "--version") + + assert ret.success + + assert ret.stdout == f"v{__version__}\n" + assert ret.stderr == "" From efc00f82d70e8186e5c8beb181334b8528244ad6 Mon Sep 17 00:00:00 2001 From: Nicholas Bollweg Date: Sat, 9 Apr 2022 14:17:59 -0500 Subject: [PATCH 071/239] build wheel (#312) --- .github/workflows/deploy.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index a5800732..2a6cdc6b 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -18,9 +18,9 @@ jobs: - name: Build wheel and source tarball run: | pip install wheel - python setup.py sdist + python setup.py sdist bdist_wheel - name: Publish a Python distribution to PyPI uses: pypa/gh-action-pypi-publish@v1.1.0 with: user: __token__ - password: ${{ secrets.pypi_password }} \ No newline at end of file + password: ${{ secrets.pypi_password }} From 1a1a2ee5e9b823a615a420285fdfb3fb4dbcfabf Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 9 Apr 2022 22:07:01 +0200 Subject: [PATCH 072/239] DOC DSL add note for arguments with Python keywords (#317) --- docs/advanced/dsl_module.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index f4046f27..fd485274 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -125,6 +125,12 @@ It can also be done using the :meth:`args ` method:: ds.Query.human.args(id="1000").select(ds.Human.name) +.. 
note:: + If your argument name is a Python keyword (for, in, from, ...), you will receive a + SyntaxError (See `issue #308`_). To fix this, you can provide the arguments by unpacking a dictionary. + + For example, instead of using :code:`from=5`, you can use :code:`**{"from":5}` + Aliases ^^^^^^^ @@ -364,3 +370,4 @@ Sync example .. _Fragment: https://graphql.org/learn/queries/#fragments .. _Inline Fragment: https://graphql.org/learn/queries/#inline-fragments +.. _issue #308: https://github.com/graphql-python/gql/issues/308 From ea96294270854ecfcc70a3bc83ecbc4e5a0d911d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 11 Apr 2022 17:59:43 +0200 Subject: [PATCH 073/239] Represent serialized floats to approximately python float precision (#318) --- gql/dsl.py | 7 +++++-- tests/starwars/test_dsl.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 634c10cb..26b9f426 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -106,10 +106,13 @@ def ast_from_serialized_value_untyped(serialized: Any) -> Optional[ValueNode]: return BooleanValueNode(value=serialized) if isinstance(serialized, int): - return IntValueNode(value=f"{serialized:d}") + return IntValueNode(value=str(serialized)) if isinstance(serialized, float) and isfinite(serialized): - return FloatValueNode(value=f"{serialized:g}") + value = str(serialized) + if value.endswith(".0"): + value = value[:-2] + return FloatValueNode(value=value) if isinstance(serialized, str): return StringValueNode(value=serialized) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 50f5449c..c0f2b441 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,6 +1,8 @@ import pytest from graphql import ( + FloatValueNode, GraphQLError, + GraphQLFloat, GraphQLID, GraphQLInt, GraphQLList, @@ -87,6 +89,20 @@ 
def test_ast_from_value_with_non_null_type_and_none(): assert "Received Null value for a Non-Null type Int." in str(exc_info.value) +def test_ast_from_value_float_precision(): + + # Checking precision of float serialization + # See https://github.com/graphql-python/graphql-core/pull/164 + + assert ast_from_value(123456789.01234567, GraphQLFloat) == FloatValueNode( + value="123456789.01234567" + ) + + assert ast_from_value(1.1, GraphQLFloat) == FloatValueNode(value="1.1") + + assert ast_from_value(123.0, GraphQLFloat) == FloatValueNode(value="123") + + def test_ast_from_serialized_value_untyped_typeerror(): with pytest.raises(TypeError) as exc_info: ast_from_serialized_value_untyped(GraphQLInt) From 0926ed67220de9f7e2399724ab0f5a2ac9248132 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 11 Apr 2022 18:37:09 +0200 Subject: [PATCH 074/239] Fix dsl root operation types custom names (#320) --- gql/dsl.py | 37 +++++++++++++++++++++++++++++++++---- tests/starwars/test_dsl.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 26b9f426..63b71a07 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -297,7 +297,7 @@ def __getattr__(self, name: str) -> "DSLType": assert isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)) - return DSLType(type_def) + return DSLType(type_def, self) class DSLSelector(ABC): @@ -454,7 +454,27 @@ def is_valid_field(self, field: "DSLSelectable") -> bool: return operation_name != "SUBSCRIPTION" elif isinstance(field, DSLField): - return field.parent_type.name.upper() == operation_name + + assert field.dsl_type is not None + + schema = field.dsl_type._dsl_schema._schema + + root_type = None + + if operation_name == "QUERY": + root_type = schema.query_type + elif operation_name == "MUTATION": + root_type = schema.mutation_type + elif operation_name == "SUBSCRIPTION": + root_type = schema.subscription_type + + if 
root_type is None: + log.error( + f"Root type of type {operation_name} not found in the schema!" + ) + return False + + return field.parent_type.name == root_type.name return False @@ -585,7 +605,11 @@ class DSLType: instances of :class:`DSLField` """ - def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]): + def __init__( + self, + graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], + dsl_schema: DSLSchema, + ): """Initialize the DSLType with the GraphQL type. .. warning:: @@ -593,8 +617,10 @@ def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]) Use attributes of the :class:`DSLSchema` instead. :param graphql_type: the GraphQL type definition from the schema + :param dsl_schema: reference to the DSLSchema which created this type """ self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type + self._dsl_schema = dsl_schema log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": @@ -611,7 +637,7 @@ def __getattr__(self, name: str) -> "DSLField": f"Field {name} does not exist in type {self._type.name}." ) - return DSLField(formatted_name, self._type, field) + return DSLField(formatted_name, self._type, field, self) def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" @@ -763,6 +789,7 @@ def __init__( name: str, parent_type: Union[GraphQLObjectType, GraphQLInterfaceType], field: GraphQLField, + dsl_type: Optional[DSLType] = None, ): """Initialize the DSLField. 
@@ -774,10 +801,12 @@ def __init__( :param parent_type: the GraphQL type definition from the schema of the parent type of the field :param field: the GraphQL field definition from the schema + :param dsl_type: reference of the DSLType instance which created this field """ self.parent_type = parent_type self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=()) + self.dsl_type = dsl_type log.debug(f"Creating {self!r}") diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index c0f2b441..0b881806 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -728,6 +728,42 @@ def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): ) +def test_dsl_root_type_not_default(): + + from graphql import parse, build_ast_schema + + schema_str = """ +schema { + query: QueryNotDefault +} + +type QueryNotDefault { + version: String +} +""" + + type_def_ast = parse(schema_str) + schema = build_ast_schema(type_def_ast) + + ds = DSLSchema(schema) + + query = dsl_gql(DSLQuery(ds.QueryNotDefault.version)) + + expected_query = """ +{ + version +} +""" + assert print_ast(query) == expected_query.strip() + + with pytest.raises(GraphQLError) as excinfo: + DSLSubscription(ds.QueryNotDefault.version) + + assert ( + "Invalid field for : " + ) in str(excinfo.value) + + def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( TypeError, match="Operations should be instances of DSLExecutable " From b0729695437cb049bd8d30c9d8d5a5819f960829 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 12 Apr 2022 11:27:32 +0200 Subject: [PATCH 075/239] Bump version number to 3.2.0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index f5f41e56..11731085 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.1.0" +__version__ = "3.2.0" From a4641d08e29276a22e2aa55b3d37bd5558aa781d Mon Sep 17 
00:00:00 2001 From: Hanusz Leszek Date: Sat, 23 Apr 2022 12:16:43 +0200 Subject: [PATCH 076/239] Add doc to install gql with conda (#321) --- README.md | 2 +- docs/intro.rst | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ea5e3074..780eaf10 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ You can install GQL with all the optional dependencies using pip: pip install gql[all] -> **NOTE**: See also [the documentation](https://gql.readthedocs.io/en/latest/intro.html#less-dependencies) to install GQL with less extra dependencies depending on the transports you would like to use +> **NOTE**: See also [the documentation](https://gql.readthedocs.io/en/latest/intro.html#less-dependencies) to install GQL with less extra dependencies depending on the transports you would like to use or for alternative installation methods. ## Usage diff --git a/docs/intro.rst b/docs/intro.rst index a6e8ee21..9685a980 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -53,6 +53,19 @@ The corresponding between extra dependencies required and the GQL classes is: It is also possible to install multiple extra dependencies if needed using commas: :code:`gql[aiohttp,websockets]` +Installation with conda +^^^^^^^^^^^^^^^^^^^^^^^ + +It is also possible to install gql using `conda`_. + +To install gql with all extra dependencies:: + + conda install gql-with-all + +To install gql with fewer dependencies, you might want to instead install a combination of the +following packages: :code:`gql-with-aiohttp`, :code:`gql-with-websockets`, :code:`gql-with-requests`, +:code:`gql-with-botocore` + Reporting Issues and Contributing --------------------------------- @@ -69,3 +82,4 @@ Please check the `Contributing`_ file to learn how to make a good pull request. .. _pip: https://pip.pypa.io/ ..
_GitHub repository for gql: https://github.com/graphql-python/gql .. _Contributing: https://github.com/graphql-python/gql/blob/master/CONTRIBUTING.md +.. _conda: https://docs.conda.io From 64c9b5b3327c01a4ffdb66b2ad5133dd782355cb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 24 Apr 2022 07:55:20 +0200 Subject: [PATCH 077/239] DSL Add default method for variables (#322) --- docs/advanced/dsl_module.rst | 29 +++++++++++++++++++++ gql/dsl.py | 41 +++++++++++++++++------------ tests/starwars/test_dsl.py | 50 +++++++++++++++++++++++++++++++++--- 3 files changed, 101 insertions(+), 19 deletions(-) diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index fd485274..1c2c1c82 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -206,6 +206,35 @@ will generate a query equivalent to:: } } +Variable arguments with a default value +""""""""""""""""""""""""""""""""""""""" + +If you want to provide a **default value** for your variable, you can use +the :code:`default` method on a variable. + +The following code: + +.. 
code-block:: python + + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args( + review=var.review.default({"stars": 5, "commentary": "Wow!"}), + episode=var.episode, + ).select(ds.Review.stars, ds.Review.commentary) + ) + op.variable_definitions = var + query = dsl_gql(op) + +will generate a query equivalent to:: + + mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { + createReview(review: $review, episode: $episode) { + stars + commentary + } + } + Subscriptions ^^^^^^^^^^^^^ diff --git a/gql/dsl.py b/gql/dsl.py index 63b71a07..7f09b928 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -18,6 +18,7 @@ FragmentDefinitionNode, FragmentSpreadNode, GraphQLArgument, + GraphQLEnumType, GraphQLError, GraphQLField, GraphQLID, @@ -28,9 +29,9 @@ GraphQLNamedType, GraphQLNonNull, GraphQLObjectType, + GraphQLScalarType, GraphQLSchema, GraphQLString, - GraphQLWrappingType, InlineFragmentNode, IntValueNode, ListTypeNode, @@ -50,7 +51,6 @@ ValueNode, VariableDefinitionNode, VariableNode, - assert_named_type, get_named_type, introspection_types, is_enum_type, @@ -134,7 +134,7 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: of if we receive a Null value for a Non-Null type. 
""" if isinstance(value, DSLVariable): - return value.set_type(type_).ast_variable + return value.set_type(type_).ast_variable_name if is_non_null_type(type_): type_ = cast(GraphQLNonNull, type_) @@ -529,26 +529,33 @@ class DSLVariable: def __init__(self, name: str): """:meta private:""" - self.type: Optional[TypeNode] = None self.name = name - self.ast_variable = VariableNode(name=NameNode(value=self.name)) + self.ast_variable_type: Optional[TypeNode] = None + self.ast_variable_name = VariableNode(name=NameNode(value=self.name)) + self.default_value = None + self.type: Optional[GraphQLInputType] = None - def to_ast_type( - self, type_: Union[GraphQLWrappingType, GraphQLNamedType] - ) -> TypeNode: + def to_ast_type(self, type_: GraphQLInputType) -> TypeNode: if is_wrapping_type(type_): if isinstance(type_, GraphQLList): return ListTypeNode(type=self.to_ast_type(type_.of_type)) + elif isinstance(type_, GraphQLNonNull): return NonNullTypeNode(type=self.to_ast_type(type_.of_type)) - type_ = assert_named_type(type_) + assert isinstance( + type_, (GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType) + ) + return NamedTypeNode(name=NameNode(value=type_.name)) - def set_type( - self, type_: Union[GraphQLWrappingType, GraphQLNamedType] - ) -> "DSLVariable": - self.type = self.to_ast_type(type_) + def set_type(self, type_: GraphQLInputType) -> "DSLVariable": + self.type = type_ + self.ast_variable_type = self.to_ast_type(type_) + return self + + def default(self, default_value: Any) -> "DSLVariable": + self.default_value = default_value return self @@ -581,9 +588,11 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: """ return tuple( VariableDefinitionNode( - type=var.type, - variable=var.ast_variable, - default_value=None, + type=var.ast_variable_type, + variable=var.ast_variable_name, + default_value=None + if var.default_value is None + else ast_from_value(var.default_value, var.type), ) for var in self.variables.values() if var.type is not 
None # only variables used diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 0b881806..d021e122 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -111,11 +111,11 @@ def test_ast_from_serialized_value_untyped_typeerror(): def test_variable_to_ast_type_passing_wrapping_type(): - wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("Droid"))) - variable = DSLVariable("droids") + wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("ReviewInput"))) + variable = DSLVariable("review_input") ast = variable.to_ast_type(wrapping_type) assert ast == NonNullTypeNode( - type=ListTypeNode(type=NamedTypeNode(name=NameNode(value="Droid"))) + type=ListTypeNode(type=NamedTypeNode(name=NameNode(value="ReviewInput"))) ) @@ -170,6 +170,50 @@ def test_add_variable_definitions(ds): ) +def test_add_variable_definitions_with_default_value_enum(ds): + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args( + review=var.review, episode=var.episode.default(4) + ).select(ds.Review.stars, ds.Review.commentary) + ) + op.variable_definitions = var + query = dsl_gql(op) + + assert ( + print_ast(query) + == """mutation ($review: ReviewInput, $episode: Episode = NEWHOPE) { + createReview(review: $review, episode: $episode) { + stars + commentary + } +}""" + ) + + +def test_add_variable_definitions_with_default_value_input_object(ds): + var = DSLVariableDefinitions() + op = DSLMutation( + ds.Mutation.createReview.args( + review=var.review.default({"stars": 5, "commentary": "Wow!"}), + episode=var.episode, + ).select(ds.Review.stars, ds.Review.commentary) + ) + op.variable_definitions = var + query = dsl_gql(op) + + assert ( + print_ast(query) + == """ +mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { + createReview(review: $review, episode: $episode) { + stars + commentary + } +}""".strip() + ) + + def test_add_variable_definitions_in_input_object(ds): var = 
DSLVariableDefinitions() op = DSLMutation( From 321c606cb5e5c7b47eec3ffdfc83f4696b65635c Mon Sep 17 00:00:00 2001 From: Paul van der Linden Date: Wed, 4 May 2022 11:42:28 +0200 Subject: [PATCH 078/239] Fix parsing of None with parse_results=True (#326) --- gql/utilities/parse_result.py | 13 +-- tests/custom_scalars/test_parse_results.py | 98 ++++++++++++++++++++++ 2 files changed, 102 insertions(+), 9 deletions(-) create mode 100644 tests/custom_scalars/test_parse_results.py diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 5f9dd2a4..ede627ae 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -293,8 +293,7 @@ def leave_field( if self.current_result is None: - log.debug(f"Leave field {name}: returning None") - return {name: None} + return_value = None elif node.selection_set is None: @@ -308,23 +307,19 @@ def leave_field( assert is_leaf_type(result_type) # Finally parsing a single scalar using the parse_value method - parsed_value = result_type.parse_value(self.current_result) - - return_value = {name: parsed_value} + return_value = result_type.parse_value(self.current_result) else: partial_results = cast(List[Dict[str, Any]], node.selection_set) - return_value = { - name: {k: v for pr in partial_results for k, v in pr.items()} - } + return_value = {k: v for pr in partial_results for k, v in pr.items()} # Go up a level in the result stack self.result_stack.pop() log.debug(f"Leave field {name}: returning {return_value}") - return return_value + return {name: return_value} # Fragments diff --git a/tests/custom_scalars/test_parse_results.py b/tests/custom_scalars/test_parse_results.py new file mode 100644 index 00000000..e3c6d6f6 --- /dev/null +++ b/tests/custom_scalars/test_parse_results.py @@ -0,0 +1,98 @@ +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLInt, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, +) + +from gql import Client, gql + 
+static_result = { + "edges": [ + { + "node": { + "from": {"address": "0x45b9ad45995577fe"}, + "to": {"address": "0x6394e988297f5ed2"}, + } + }, + {"node": {"from": None, "to": {"address": "0x6394e988297f5ed2"}}}, + ] +} + + +def resolve_test(root, _info, count): + return static_result + + +Account = GraphQLObjectType( + name="Account", + fields={"address": GraphQLField(GraphQLNonNull(GraphQLString))}, +) + + +queryType = GraphQLObjectType( + name="RootQueryType", + fields={ + "test": GraphQLField( + GraphQLObjectType( + name="test", + fields={ + "edges": GraphQLField( + GraphQLList( + GraphQLObjectType( + "example", + fields={ + "node": GraphQLField( + GraphQLObjectType( + name="node", + fields={ + "from": GraphQLField(Account), + "to": GraphQLField(Account), + }, + ) + ) + }, + ) + ) + ) + }, + ), + args={"count": GraphQLArgument(GraphQLInt)}, + resolve=resolve_test, + ), + }, +) + +schema = GraphQLSchema(query=queryType) + + +def test_parse_results_null_mapping(): + """This is a regression test for the issue: + https://github.com/graphql-python/gql/issues/325 + + Most of the parse_results tests are in tests/starwars/test_parse_results.py + """ + + client = Client(schema=schema, parse_results=True) + query = gql( + """query testQ($count: Int) {test(count: $count){ + edges { + node { + from { + address + } + to { + address + } + } + } + } }""" + ) + + assert client.execute(query, variable_values={"count": 2}) == { + "test": static_result + } From 9f2139b8de76ecd2bb6728989c6a36b329b4500d Mon Sep 17 00:00:00 2001 From: Luke Taverne Date: Thu, 19 May 2022 22:04:09 +0200 Subject: [PATCH 079/239] Check for errors during fetch_schema() (#328) --- docs/advanced/error_handling.rst | 5 ++++ gql/client.py | 31 +++++++++++++++----- tests/test_aiohttp.py | 43 ++++++++++++++++++++++++++++ tests/test_requests.py | 49 ++++++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 7 deletions(-) diff --git 
a/docs/advanced/error_handling.rst b/docs/advanced/error_handling.rst index 2fd1e39b..4e6618c9 100644 --- a/docs/advanced/error_handling.rst +++ b/docs/advanced/error_handling.rst @@ -41,6 +41,11 @@ Here are the possible Transport Errors: The message of the exception contains the first error returned by the backend. All the errors messages are available in the exception :code:`errors` attribute. + If the error message begins with :code:`Error while fetching schema:`, it means + that gql was not able to get the schema from the backend. + If you don't need the schema, you can try to create the client with + :code:`fetch_schema_from_transport=False` + - :class:`TransportClosed `: This exception is generated when the client is trying to use the transport while the transport was previously closed. diff --git a/gql/client.py b/gql/client.py index c0972133..fdac4a36 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,12 +1,13 @@ import asyncio import sys import warnings -from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, overload +from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast, overload from graphql import ( DocumentNode, ExecutionResult, GraphQLSchema, + IntrospectionQuery, build_ast_schema, get_introspection_query, parse, @@ -55,7 +56,7 @@ class Client: def __init__( self, schema: Optional[Union[str, GraphQLSchema]] = None, - introspection=None, + introspection: Optional[IntrospectionQuery] = None, transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, execute_timeout: Optional[Union[int, float]] = 10, @@ -106,7 +107,7 @@ def __init__( self.schema: Optional[GraphQLSchema] = schema # Answer of the introspection query - self.introspection = introspection + self.introspection: Optional[IntrospectionQuery] = introspection # GraphQL transport chosen self.transport: Optional[Union[Transport, AsyncTransport]] = transport @@ -131,6 +132,22 @@ def validate(self, document: 
DocumentNode): if validation_errors: raise validation_errors[0] + def _build_schema_from_introspection(self, execution_result: ExecutionResult): + if execution_result.errors: + raise TransportQueryError( + ( + f"Error while fetching schema: {execution_result.errors[0]!s}\n" + "If you don't need the schema, you can try with: " + '"fetch_schema_from_transport=False"' + ), + errors=execution_result.errors, + data=execution_result.data, + extensions=execution_result.extensions, + ) + + self.introspection = cast(IntrospectionQuery, execution_result.data) + self.schema = build_client_schema(self.introspection) + @overload def execute_sync( self, @@ -802,8 +819,8 @@ def fetch_schema(self) -> None: Don't use this function and instead set the fetch_schema_from_transport attribute to True""" execution_result = self.transport.execute(parse(get_introspection_query())) - self.client.introspection = execution_result.data - self.client.schema = build_client_schema(self.client.introspection) + + self.client._build_schema_from_introspection(execution_result) @property def transport(self): @@ -1175,8 +1192,8 @@ async def fetch_schema(self) -> None: execution_result = await self.transport.execute( parse(get_introspection_query()) ) - self.client.introspection = execution_result.data - self.client.schema = build_client_schema(self.client.introspection) + + self.client._build_schema_from_introspection(execution_result) @property def transport(self): diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 2535ddb3..a5a3127d 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1190,3 +1190,46 @@ async def handler(request): africa = continents[0] assert africa["code"] == "AF" + + +@pytest.mark.asyncio +async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + error_answer = """ +{ + "errors": [ + { + "errorType": "UnauthorizedException", + "message": "Permission 
denied" + } + ] +} +""" + + async def handler(request): + return web.Response( + text=error_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + with pytest.raises(TransportQueryError) as exc_info: + async with Client(transport=transport, fetch_schema_from_transport=True): + pass + + expected_error = ( + "Error while fetching schema: " + "{'errorType': 'UnauthorizedException', 'message': 'Permission denied'}" + ) + + assert expected_error in str(exc_info.value) + assert transport.session is None diff --git a/tests/test_requests.py b/tests/test_requests.py index 7cd7f712..70fc337e 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -755,3 +755,52 @@ def test_code(): f2.close() await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_fetching_schema( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + error_answer = """ +{ + "errors": [ + { + "errorType": "UnauthorizedException", + "message": "Permission denied" + } + ] +} +""" + + async def handler(request): + return web.Response( + text=error_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with pytest.raises(TransportQueryError) as exc_info: + with Client(transport=transport, fetch_schema_from_transport=True): + pass + + expected_error = ( + "Error while fetching schema: " + "{'errorType': 'UnauthorizedException', 'message': 'Permission denied'}" + ) + + assert expected_error in str(exc_info.value) + assert transport.session is None + + await 
run_sync_test(event_loop, server, test_code) From 6c91bb5e4f03eb5fc12d09864a1e0c30e17b423f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 20 May 2022 10:25:17 +0200 Subject: [PATCH 080/239] Bump version number to 3.3.0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 11731085..88c513ea 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.2.0" +__version__ = "3.3.0" From 3970f1ce3c5dc65a014326d8d46e9a8070f2cc16 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 3 Jul 2022 15:54:29 +0200 Subject: [PATCH 081/239] Permanent reconnecting async session (#324) --- docs/advanced/async_permanent_session.rst | 115 ++++++++ docs/advanced/index.rst | 1 + docs/code_examples/console_async.py | 73 +++++ docs/code_examples/fastapi_async.py | 101 +++++++ .../reconnecting_mutation_http.py | 47 +++ .../code_examples/reconnecting_mutation_ws.py | 47 +++ .../reconnecting_subscription.py | 32 +++ gql/client.py | 268 +++++++++++++++++- gql/transport/aiohttp.py | 5 + setup.py | 1 + tests/test_aiohttp.py | 106 +++++++ tests/test_graphqlws_subscription.py | 86 +++++- 12 files changed, 872 insertions(+), 10 deletions(-) create mode 100644 docs/advanced/async_permanent_session.rst create mode 100644 docs/code_examples/console_async.py create mode 100644 docs/code_examples/fastapi_async.py create mode 100644 docs/code_examples/reconnecting_mutation_http.py create mode 100644 docs/code_examples/reconnecting_mutation_ws.py create mode 100644 docs/code_examples/reconnecting_subscription.py diff --git a/docs/advanced/async_permanent_session.rst b/docs/advanced/async_permanent_session.rst new file mode 100644 index 00000000..240d8b4f --- /dev/null +++ b/docs/advanced/async_permanent_session.rst @@ -0,0 +1,115 @@ +.. 
_async_permanent_session: + +Async permanent session +======================= + +Sometimes you want to have a single permanent reconnecting async session to a GraphQL backend, +and that can be `difficult to manage`_ manually with the :code:`async with client as session` syntax. + +It is now possible to have a single reconnecting session using the +:meth:`connect_async ` method of Client +with a :code:`reconnecting=True` argument. + +.. code-block:: python + + # Create a session from the client which will reconnect automatically. + # This session can be kept in a class for example to provide a way + # to execute GraphQL queries from many different places + session = await client.connect_async(reconnecting=True) + + # You can run execute or subscribe method on this session + result = await session.execute(query) + + # When you want the connection to close (for cleanup), + # you call close_async + await client.close_async() + + +When you use :code:`reconnecting=True`, gql will watch the exceptions generated +during the execute and subscribe calls and, if it detects a TransportClosed exception +(indicating that the link to the underlying transport is broken), +it will try to reconnect to the backend again. + +Retries +------- + +Connection retries +^^^^^^^^^^^^^^^^^^ + +With :code:`reconnecting=True`, gql will use the `backoff`_ module to repeatedly try to connect with +exponential backoff and jitter with a maximum delay of 60 seconds by default. + +You can change the default reconnecting profile by providing your own +backoff decorator to the :code:`retry_connect` argument. + +.. 
code-block:: python + + # Here wait maximum 5 minutes between connection retries + retry_connect = backoff.on_exception( + backoff.expo, # wait generator (here: exponential backoff) + Exception, # which exceptions should cause a retry (here: everything) + max_value=300, # max wait time in seconds + ) + session = await client.connect_async( + reconnecting=True, + retry_connect=retry_connect, + ) + +Execution retries +^^^^^^^^^^^^^^^^^ + +With :code:`reconnecting=True`, by default we will also retry up to 5 times +when an exception happens during an execute call (to manage a possible loss in the connection +to the transport). + +There is no retry in case of a :code:`TransportQueryError` exception as it indicates that +the connection to the backend is working correctly. + +You can change the default execute retry profile by providing your own +backoff decorator to the :code:`retry_execute` argument. + +.. code-block:: python + + # Here Only 3 tries for execute calls + retry_execute = backoff.on_exception( + backoff.expo, + Exception, + max_tries=3, + giveup=lambda e: isinstance(e, TransportQueryError), + ) + session = await client.connect_async( + reconnecting=True, + retry_execute=retry_execute, + ) + +If you don't want any retry on the execute calls, you can disable the retries with :code:`retry_execute=False` + +Subscription retries +^^^^^^^^^^^^^^^^^^^^ + +There is no :code:`retry_subscribe` as it is not feasible with async generators. +If you want retries for your subscriptions, then you can do it yourself +with backoff decorators on your methods. + +.. code-block:: python + + @backoff.on_exception(backoff.expo, + Exception, + max_tries=3, + giveup=lambda e: isinstance(e, TransportQueryError)) + async def execute_subscription1(session): + async for result in session.subscribe(subscription1): + print(result) + +FastAPI example +--------------- + +.. literalinclude:: ../code_examples/fastapi_async.py + +Console example +--------------- + +.. 
literalinclude:: ../code_examples/console_async.py + +.. _difficult to manage: https://github.com/graphql-python/gql/issues/179 +.. _backoff: https://github.com/litl/backoff diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst index 8005b381..baae9276 100644 --- a/docs/advanced/index.rst +++ b/docs/advanced/index.rst @@ -5,6 +5,7 @@ Advanced :maxdepth: 2 async_advanced_usage + async_permanent_session logging error_handling local_schema diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py new file mode 100644 index 00000000..5391f7bf --- /dev/null +++ b/docs/code_examples/console_async.py @@ -0,0 +1,73 @@ +import asyncio +import logging + +from aioconsole import ainput + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport + +logging.basicConfig(level=logging.INFO) + +GET_CONTINENT_NAME = """ + query getContinentName ($code: ID!) { + continent (code: $code) { + name + } + } +""" + + +class GraphQLContinentClient: + def __init__(self): + self._client = Client( + transport=AIOHTTPTransport(url="https://countries.trevorblades.com/") + ) + self._session = None + + self.get_continent_name_query = gql(GET_CONTINENT_NAME) + + async def connect(self): + self._session = await self._client.connect_async(reconnecting=True) + + async def close(self): + await self._client.close_async() + + async def get_continent_name(self, code): + params = {"code": code} + + answer = await self._session.execute( + self.get_continent_name_query, variable_values=params + ) + + return answer.get("continent").get("name") + + +async def main(): + continent_client = GraphQLContinentClient() + + continent_codes = ["AF", "AN", "AS", "EU", "NA", "OC", "SA"] + + await continent_client.connect() + + while True: + + answer = await ainput("\nPlease enter a continent code or 'exit':") + answer = answer.strip() 
+ + if answer == "exit": + break + elif answer in continent_codes: + + try: + continent_name = await continent_client.get_continent_name(answer) + print(f"The continent name is {continent_name}\n") + except Exception as exc: + print(f"Received exception {exc} while trying to get continent name") + + else: + print(f"Please enter a valid continent code from {continent_codes}") + + await continent_client.close() + + +asyncio.run(main()) diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py new file mode 100644 index 00000000..3bedd187 --- /dev/null +++ b/docs/code_examples/fastapi_async.py @@ -0,0 +1,101 @@ +# First install fastapi and uvicorn: +# +# pip install fastapi uvicorn +# +# then run: +# +# uvicorn fastapi_async:app --reload + +import logging + +from fastapi import FastAPI, HTTPException +from fastapi.responses import HTMLResponse + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport + +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger(__name__) + +transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") + +client = Client(transport=transport) + +query = gql( + """ +query getContinentInfo($code: ID!) { + continent(code:$code) { + name + code + countries { + name + capital + } + } +} +""" +) + +app = FastAPI() + + +@app.on_event("startup") +async def startup_event(): + print("Connecting to GraphQL backend") + + await client.connect_async(reconnecting=True) + print("End of startup") + + +@app.on_event("shutdown") +async def shutdown_event(): + print("Shutting down GraphQL permanent connection...") + await client.close_async() + print("Shutting down GraphQL permanent connection... 
done") + + +continent_codes = [ + "AF", + "AN", + "AS", + "EU", + "NA", + "OC", + "SA", +] + + +@app.get("/", response_class=HTMLResponse) +def get_root(): + + continent_links = ", ".join( + [f'{code}' for code in continent_codes] + ) + + return f""" + + + Continents + + + Continents: {continent_links} + + +""" + + +@app.get("/continent/{continent_code}") +async def get_continent(continent_code): + + if continent_code not in continent_codes: + raise HTTPException(status_code=404, detail="Continent not found") + + try: + result = await client.session.execute( + query, variable_values={"code": continent_code} + ) + except Exception as e: + log.debug(f"get_continent Error: {e}") + raise HTTPException(status_code=503, detail="GraphQL backend unavailable") + + return result diff --git a/docs/code_examples/reconnecting_mutation_http.py b/docs/code_examples/reconnecting_mutation_http.py new file mode 100644 index 00000000..f4329c8b --- /dev/null +++ b/docs/code_examples/reconnecting_mutation_http.py @@ -0,0 +1,47 @@ +import asyncio +import logging + +import backoff + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport + +logging.basicConfig(level=logging.INFO) + + +async def main(): + + # Note: this example used the test backend from + # https://github.com/slothmanxyz/typegraphql-ws-apollo + transport = AIOHTTPTransport(url="ws://localhost:5000/graphql") + + client = Client(transport=transport) + + retry_connect = backoff.on_exception( + backoff.expo, + Exception, + max_value=10, + jitter=None, + ) + session = await client.connect_async(reconnecting=True, retry_connect=retry_connect) + + num = 0 + + while True: + num += 1 + + # Execute single query + query = gql("mutation ($message: String!) 
{sendMessage(message: $message)}") + + params = {"message": f"test {num}"} + + try: + result = await session.execute(query, variable_values=params) + print(result) + except Exception as e: + print(f"Received exception {e}") + + await asyncio.sleep(1) + + +asyncio.run(main()) diff --git a/docs/code_examples/reconnecting_mutation_ws.py b/docs/code_examples/reconnecting_mutation_ws.py new file mode 100644 index 00000000..7d7c8f8a --- /dev/null +++ b/docs/code_examples/reconnecting_mutation_ws.py @@ -0,0 +1,47 @@ +import asyncio +import logging + +import backoff + +from gql import Client, gql +from gql.transport.websockets import WebsocketsTransport + +logging.basicConfig(level=logging.INFO) + + +async def main(): + + # Note: this example used the test backend from + # https://github.com/slothmanxyz/typegraphql-ws-apollo + transport = WebsocketsTransport(url="ws://localhost:5000/graphql") + + client = Client(transport=transport) + + retry_connect = backoff.on_exception( + backoff.expo, + Exception, + max_value=10, + jitter=None, + ) + session = await client.connect_async(reconnecting=True, retry_connect=retry_connect) + + num = 0 + + while True: + num += 1 + + # Execute single query + query = gql("mutation ($message: String!) 
{sendMessage(message: $message)}") + + params = {"message": f"test {num}"} + + try: + result = await session.execute(query, variable_values=params) + print(result) + except Exception as e: + print(f"Received exception {e}") + + await asyncio.sleep(1) + + +asyncio.run(main()) diff --git a/docs/code_examples/reconnecting_subscription.py b/docs/code_examples/reconnecting_subscription.py new file mode 100644 index 00000000..7ff33950 --- /dev/null +++ b/docs/code_examples/reconnecting_subscription.py @@ -0,0 +1,32 @@ +import asyncio +import logging + +from gql import Client, gql +from gql.transport.websockets import WebsocketsTransport + +logging.basicConfig(level=logging.INFO) + + +async def main(): + + # Note: this example used the test backend from + # https://github.com/slothmanxyz/typegraphql-ws-apollo + transport = WebsocketsTransport(url="ws://localhost:5000/graphql") + + client = Client(transport=transport) + + session = await client.connect_async(reconnecting=True) + + query = gql("subscription {receiveMessage {message}}") + + while True: + try: + async for result in session.subscribe(query): + print(result) + except Exception as e: + print(f"Received exception {e}") + + await asyncio.sleep(1) + + +asyncio.run(main()) diff --git a/gql/client.py b/gql/client.py index fdac4a36..d4a9dfef 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,8 +1,21 @@ import asyncio +import logging import sys import warnings -from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast, overload +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + Generator, + Optional, + TypeVar, + Union, + cast, + overload, +) +import backoff from graphql import ( DocumentNode, ExecutionResult, @@ -15,7 +28,7 @@ ) from .transport.async_transport import AsyncTransport -from .transport.exceptions import TransportQueryError +from .transport.exceptions import TransportClosed, TransportQueryError from .transport.local_schema 
import LocalSchemaTransport from .transport.transport import Transport from .utilities import build_client_schema @@ -33,6 +46,9 @@ from typing_extensions import Literal # pragma: no cover +log = logging.getLogger(__name__) + + class Client: """The Client class is the main entrypoint to execute GraphQL requests on a GQL transport. @@ -588,15 +604,32 @@ def subscribe( # Then reraise the exception raise - async def __aenter__(self): + async def connect_async(self, reconnecting=False, **kwargs): + r"""Connect asynchronously with the underlying async transport to + produce a session. + + That session will be a permanent auto-reconnecting session + if :code:`reconnecting=True`. + + If you call this method, you should call the + :meth:`close_async ` method + for cleanup. + + :param reconnecting: if True, create a permanent reconnecting session + :param \**kwargs: additional arguments for the + :meth:`ReconnectingAsyncClientSession init method + `. + """ assert isinstance( self.transport, AsyncTransport ), "Only a transport of type AsyncTransport can be used asynchronously" - await self.transport.connect() - - if not hasattr(self, "session"): + if reconnecting: + self.session = ReconnectingAsyncClientSession(client=self, **kwargs) + await self.session.start_connecting_task() + else: + await self.transport.connect() self.session = AsyncClientSession(client=self) # Get schema from transport if needed @@ -612,11 +645,30 @@ async def __aenter__(self): return self.session - async def __aexit__(self, exc_type, exc, tb): + async def close_async(self): + """Close the async transport and stop the optional reconnecting task.""" + + if isinstance(self.session, ReconnectingAsyncClientSession): + await self.session.stop_connecting_task() await self.transport.close() - def __enter__(self): + async def __aenter__(self): + + return await self.connect_async() + + async def __aexit__(self, exc_type, exc, tb): + + await self.close_async() + + def connect_sync(self): + r"""Connect 
synchronously with the underlying sync transport to + produce a session. + + If you call this method, you should call the + :meth:`close_sync ` method + for cleanup. + """ assert not isinstance(self.transport, AsyncTransport), ( "Only a sync transport can be used." @@ -641,9 +693,17 @@ def __enter__(self): return self.session - def __exit__(self, *args): + def close_sync(self): + """Close the sync transport.""" self.transport.close() + def __enter__(self): + + return self.connect_sync() + + def __exit__(self, *args): + self.close_sync() + class SyncClientSession: """An instance of this class is created when using :code:`with` on the client. @@ -1198,3 +1258,193 @@ async def fetch_schema(self) -> None: @property def transport(self): return self.client.transport + + +_CallableT = TypeVar("_CallableT", bound=Callable[..., Any]) +_Decorator = Callable[[_CallableT], _CallableT] + + +class ReconnectingAsyncClientSession(AsyncClientSession): + """An instance of this class is created when using the + :meth:`connect_async ` method of the + :class:`Client ` class with :code:`reconnecting=True`. + + It is used to provide a single session which will reconnect automatically if + the connection fails. + """ + + def __init__( + self, + client: Client, + retry_connect: Union[bool, _Decorator] = True, + retry_execute: Union[bool, _Decorator] = True, + ): + """ + :param client: the :class:`client ` used. + :param retry_connect: Either a Boolean to activate/deactivate the retries + for the connection to the transport OR a backoff decorator to + provide specific retries parameters for the connections. + :param retry_execute: Either a Boolean to activate/deactivate the retries + for the execute method OR a backoff decorator to + provide specific retries parameters for this method. 
+ """ + self.client = client + self._connect_task = None + + self._reconnect_request_event = asyncio.Event() + self._connected_event = asyncio.Event() + + if retry_connect is True: + # By default, retry again and again, with maximum 60 seconds + # between retries + self.retry_connect = backoff.on_exception( + backoff.expo, + Exception, + max_value=60, + ) + elif retry_connect is False: + self.retry_connect = lambda e: e + else: + assert callable(retry_connect) + self.retry_connect = retry_connect + + if retry_execute is True: + # By default, retry 5 times, except if we receive a TransportQueryError + self.retry_execute = backoff.on_exception( + backoff.expo, + Exception, + max_tries=5, + giveup=lambda e: isinstance(e, TransportQueryError), + ) + elif retry_execute is False: + self.retry_execute = lambda e: e + else: + assert callable(retry_execute) + self.retry_execute = retry_execute + + # Creating the _execute_with_retries and _connect_with_retries methods + # using the provided backoff decorators + self._execute_with_retries = self.retry_execute(self._execute_once) + self._connect_with_retries = self.retry_connect(self.transport.connect) + + async def _connection_loop(self): + """Coroutine used for the connection task. + + - try to connect to the transport with retries + - send a connected event when the connection has been made + - then wait for a reconnect request to try to connect again + """ + + while True: + + # Connect to the transport with the retry decorator + # By default it should keep retrying until it connect + await self._connect_with_retries() + + # Once connected, set the connected event + self._connected_event.set() + self._connected_event.clear() + + # Then wait for the reconnect event + self._reconnect_request_event.clear() + await self._reconnect_request_event.wait() + + async def start_connecting_task(self): + """Start the task responsible to restart the connection + of the transport when requested by an event. 
+ """ + if self._connect_task: + log.warning("connect task already started!") + else: + self._connect_task = asyncio.ensure_future(self._connection_loop()) + + await self._connected_event.wait() + + async def stop_connecting_task(self): + """Stop the connecting task.""" + if self._connect_task is not None: + self._connect_task.cancel() + self._connect_task = None + + async def _execute_once( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + **kwargs, + ) -> ExecutionResult: + """Same Coroutine as parent method _execute but requesting a + reconnection if we receive a TransportClosed exception. + """ + + try: + answer = await super()._execute( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + **kwargs, + ) + except TransportClosed: + self._reconnect_request_event.set() + raise + + return answer + + async def _execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + **kwargs, + ) -> ExecutionResult: + """Same Coroutine as parent, but with optional retries + and requesting a reconnection if we receive a TransportClosed exception. 
+ """ + + return await self._execute_with_retries( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + **kwargs, + ) + + async def _subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + **kwargs, + ) -> AsyncGenerator[ExecutionResult, None]: + """Same Async generator as parent method _subscribe but requesting a + reconnection if we receive a TransportClosed exception. + """ + + inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe( + document, + variable_values=variable_values, + operation_name=operation_name, + serialize_variables=serialize_variables, + parse_result=parse_result, + **kwargs, + ) + + try: + async for result in inner_generator: + yield result + + except TransportClosed: + self._reconnect_request_event.set() + raise + + finally: + await inner_generator.aclose() diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index de9ab953..6d51f4f3 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -107,6 +107,8 @@ async def connect(self) -> None: if self.client_session_args: client_session_args.update(self.client_session_args) # type: ignore + log.debug("Connecting transport") + self.session = aiohttp.ClientSession(**client_session_args) else: @@ -173,6 +175,9 @@ async def close(self) -> None: when you exit the async context manager. 
""" if self.session is not None: + + log.debug("Closing transport") + closed_event = self.create_aiohttp_closed_event(self.session) await self.session.close() try: diff --git a/setup.py b/setup.py index a8b58737..835f8abc 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ install_requires = [ "graphql-core>=3.2,<3.3", "yarl>=1.6,<2.0", + "backoff>=1.11.1,<3.0", ] console_scripts = [ diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index a5a3127d..4a70956c 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1233,3 +1233,109 @@ async def handler(request): assert expected_error in str(exc_info.value) assert transport.session is None + + +@pytest.mark.asyncio +async def test_aiohttp_reconnecting_session(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + client = Client(transport=transport) + + session = await client.connect_async(reconnecting=True) + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + await client.close_async() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("retries", [False, lambda e: e]) +async def test_aiohttp_reconnecting_session_retries( + event_loop, aiohttp_server, retries +): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = 
server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + client = Client(transport=transport) + + session = await client.connect_async( + reconnecting=True, retry_execute=retries, retry_connect=retries + ) + + assert session._execute_with_retries == session._execute_once + assert session._connect_with_retries == session.transport.connect + + await client.close_async() + + +@pytest.mark.asyncio +async def test_aiohttp_reconnecting_session_start_connecting_task_twice( + event_loop, aiohttp_server, caplog +): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + client = Client(transport=transport) + + session = await client.connect_async(reconnecting=True) + + await session.start_connecting_task() + + print(f"Captured log: {caplog.text}") + + expected_warning = "connect task already started!" 
+ assert expected_warning in caplog.text + + await client.close_async() diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index ade21911..cb705368 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -27,7 +27,9 @@ logged_messages: List[str] = [] -def server_countdown_factory(keepalive=False, answer_pings=True): +def server_countdown_factory( + keepalive=False, answer_pings=True, simulate_disconnect=False +): async def server_countdown_template(ws, path): import websockets @@ -51,6 +53,9 @@ async def server_countdown_template(ws, path): count = count_found[0] print(f" Server: Countdown started from: {count}") + if simulate_disconnect and count == 8: + await ws.close() + pong_received: asyncio.Event = asyncio.Event() async def counting_coro(): @@ -205,6 +210,12 @@ async def server_countdown_dont_answer_pings(ws, path): await server(ws, path) +async def server_countdown_disconnect(ws, path): + + server = server_countdown_factory(simulate_disconnect=True) + await server(ws, path) + + countdown_subscription_str = """ subscription {{ countdown (count: {count}) {{ @@ -792,3 +803,76 @@ def test_code(): assert count == -1 await run_sync_test(event_loop, graphqlws_server, test_code) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_disconnect], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +@pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) +async def test_graphqlws_subscription_reconnecting_session( + event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe +): + + import websockets + from gql.transport.websockets import WebsocketsTransport + from gql.transport.exceptions import TransportClosed + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = WebsocketsTransport(url=url) + + client = 
Client(transport=transport) + + count = 8 + subscription_with_disconnect = gql(subscription_str.format(count=count)) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + session = await client.connect_async( + reconnecting=True, retry_connect=False, retry_execute=False + ) + + # First we make a subscription which will cause a disconnect in the backend + # (count=8) + try: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): + pass + except websockets.exceptions.ConnectionClosedOK: + pass + + await asyncio.sleep(50 * MS) + + # Then with the same session handle, we make a subscription or an execute + # which will detect that the transport is closed so that the client could + # try to reconnect + try: + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + await session.execute(subscription) + else: + print("\nSUBSCRIPTION_2\n") + async for result in session.subscribe(subscription): + pass + except TransportClosed: + pass + + await asyncio.sleep(50 * MS) + + # And finally with the same session handle, we make a subscription + # which works correctly + print("\nSUBSCRIPTION_3\n") + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await client.close_async() From ffdae9612ea975d6285343c0fb7051edf7a12ae0 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 4 Jul 2022 01:09:01 +0200 Subject: [PATCH 082/239] Adding explicit json_serialize argument in AIOHTTPTransport (#337) --- gql/transport/aiohttp.py | 15 +++++++++----- tests/test_aiohttp.py | 45 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 6d51f4f3..f4f38b69 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -4,7 +4,7 @@ import json import logging 
from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Type, Union +from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -49,6 +49,7 @@ def __init__( ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, + json_serialize: Callable = json.dumps, client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -61,6 +62,8 @@ def __init__( :param ssl: ssl_context of the connection. Use ssl=False to disable encryption :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection to close properly + :param json_serialize: Json serializer callable. + By default json.dumps() function :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ @@ -77,6 +80,7 @@ def __init__( self.client_session_args = client_session_args self.session: Optional[aiohttp.ClientSession] = None self.response_headers: Optional[CIMultiDictProxy[str]] + self.json_serialize: Callable = json_serialize async def connect(self) -> None: """Coroutine which will create an aiohttp ClientSession() as self.session. 
@@ -96,6 +100,7 @@ async def connect(self) -> None: "auth": None if isinstance(self.auth, AppSyncAuthentication) else self.auth, + "json_serialize": self.json_serialize, } if self.timeout is not None: @@ -248,14 +253,14 @@ async def execute( file_streams = {str(i): files[path] for i, path in enumerate(files)} # Add the payload to the operations field - operations_str = json.dumps(payload) + operations_str = self.json_serialize(payload) log.debug("operations %s", operations_str) data.add_field( "operations", operations_str, content_type="application/json" ) # Add the file map field - file_map_str = json.dumps(file_map) + file_map_str = self.json_serialize(file_map) log.debug("file_map %s", file_map_str) data.add_field("map", file_map_str, content_type="application/json") @@ -270,7 +275,7 @@ async def execute( payload["variables"] = variable_values if log.isEnabledFor(logging.INFO): - log.info(">>> %s", json.dumps(payload)) + log.info(">>> %s", self.json_serialize(payload)) post_args = {"json": payload} @@ -281,7 +286,7 @@ async def execute( # Add headers for AppSync if requested if isinstance(self.auth, AppSyncAuthentication): post_args["headers"] = self.auth.get_headers( - json.dumps(payload), + self.json_serialize(payload), {"content-type": "application/json"}, ) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 4a70956c..3a84d21e 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1339,3 +1339,48 @@ async def handler(request): assert expected_warning in caplog.text await client.close_async() + + +@pytest.mark.asyncio +async def test_aiohttp_json_serializer(event_loop, aiohttp_server, caplog): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + + request_text = await request.text() + print("Received on backend: " + request_text) + + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + 
app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport( + url=url, + timeout=10, + json_serialize=lambda e: json.dumps(e, separators=(",", ":")), + ) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking that there is no space after the colon in the log + expected_log = '"query":"query getContinents' + assert expected_log in caplog.text From 5912f8fe1526f803d71e76d8450b00c23909da17 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 4 Jul 2022 01:14:44 +0200 Subject: [PATCH 083/239] DOC Add basic example result in README (#336) --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 780eaf10..8e0ac68b 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,13 @@ result = client.execute(query) print(result) ``` +Executing the above code should output the following result: + +``` +$ python basic_example.py +{'continents': [{'code': 'AF', 'name': 'Africa'}, {'code': 'AN', 'name': 'Antarctica'}, {'code': 'AS', 'name': 'Asia'}, {'code': 'EU', 'name': 'Europe'}, {'code': 'NA', 'name': 'North America'}, {'code': 'OC', 'name': 'Oceania'}, {'code': 'SA', 'name': 'South America'}]} +``` + > **WARNING**: Please note that this basic example won't work if you have an asyncio event loop running. In some > python environments (as with Jupyter which uses IPython) an asyncio event loop is created for you. In that case you > should use instead the [async usage example](https://gql.readthedocs.io/en/latest/async/async_usage.html#async-usage). 
From ddabb226d2ac55ef9dee65c826a021f49edd191b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 4 Jul 2022 01:34:30 +0200 Subject: [PATCH 084/239] Allow omitting optional arguments with serialize_variables=True (#338) --- gql/utilities/serialize_variable_values.py | 1 + tests/custom_scalars/test_datetime.py | 28 +++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/gql/utilities/serialize_variable_values.py b/gql/utilities/serialize_variable_values.py index 833df8bd..38ad1995 100644 --- a/gql/utilities/serialize_variable_values.py +++ b/gql/utilities/serialize_variable_values.py @@ -85,6 +85,7 @@ def serialize_value(type_: GraphQLType, value: Any) -> Any: return { field_name: serialize_value(field.type, value[field_name]) for field_name, field in type_.fields.items() + if field_name in value } raise GraphQLError(f"Impossible to serialize value with type: {inspect(type_)}.") diff --git a/tests/custom_scalars/test_datetime.py b/tests/custom_scalars/test_datetime.py index 169ce076..b3e717c5 100644 --- a/tests/custom_scalars/test_datetime.py +++ b/tests/custom_scalars/test_datetime.py @@ -75,7 +75,10 @@ def resolve_seconds(root, _info, interval): IntervalInputType = GraphQLInputObjectType( "IntervalInput", fields={ - "start": GraphQLInputField(DatetimeScalar), + "start": GraphQLInputField( + DatetimeScalar, + default_value=datetime(2021, 11, 12, 11, 58, 13, 461161), + ), "end": GraphQLInputField(DatetimeScalar), }, ) @@ -216,3 +219,26 @@ def test_seconds(): print(result) assert result["seconds"] == 432000 + + +@pytest.mark.skipif( + not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" +) +def test_seconds_omit_optional_start_argument(): + client = Client(schema=schema) + + in_five_days = datetime.fromisoformat("2021-11-17T11:58:13.461161") + + query = gql( + "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" + ) + + variable_values = {"interval": {"end": in_five_days}} + + result 
= client.execute( + query, variable_values=variable_values, serialize_variables=True + ) + + print(result) + + assert result["seconds"] == 432000 From 05c05a2272351d3eaa9f8bf607c683d208d100bd Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 14 Jul 2022 17:23:44 +0200 Subject: [PATCH 085/239] Bump version number to 3.4.0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 88c513ea..903a158a 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.3.0" +__version__ = "3.4.0" From 5713ac7432e24bff68f9a625eef2394098e4d8be Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 21 Jul 2022 13:14:56 +0200 Subject: [PATCH 086/239] DOC Add documentation on websockets level ping frames (#345) --- docs/transports/websockets.rst | 24 ++++++++++++++++++++++++ gql/transport/websockets.py | 9 +++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/docs/transports/websockets.rst b/docs/transports/websockets.rst index 23e4735a..a8f6cac6 100644 --- a/docs/transports/websockets.rst +++ b/docs/transports/websockets.rst @@ -82,6 +82,8 @@ There are two ways to send authentication tokens with websockets depending on th init_payload={'Authorization': 'token'} ) +.. _websockets_transport_keepalives: + Keep-Alives ----------- @@ -125,6 +127,28 @@ Here is an example with a ping sent every 60 seconds, expecting a pong within 10 pong_timeout=10, ) +Underlying websockets protocol +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In addition to the keep-alives described above for the apollo and graphql-ws protocols, +there are also `ping frames`_ sent by the underlying websocket connection itself for both of them. + +These pings are enabled by default (every 20 seconds) and could be modified or disabled +by passing extra arguments to the :code:`connect` call of the websockets client using the +:code:`connect_args` argument of the transport. + +.. 
code-block:: python + + # Disabling websocket protocol level pings + transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + connect_args={"ping_interval": None}, + ) + +See the `websockets keepalive documentation`_ for details. + .. _version 5.6.1: https://github.com/enisdenjo/graphql-ws/releases/tag/v5.6.1 .. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md .. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md +.. _ping frames: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 +.. _websockets keepalive documentation: https://websockets.readthedocs.io/en/stable/topics/timeouts.html#keepalive-in-websockets diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 9e111551..c385d3d7 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -64,14 +64,19 @@ def __init__( a sign of liveness from the server. :param ping_interval: Delay in seconds between pings sent by the client to the backend for the graphql-ws protocol. None (by default) means that - we don't send pings. + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. :param pong_timeout: Delay in seconds to receive a pong from the backend after we sent a ping (only for the graphql-ws protocol). By default equal to half of the ping_interval. :param answer_pings: Whether the client answers the pings from the backend (for the graphql-ws protocol). 
By default: True - :param connect_args: Other parameters forwarded to websockets.connect + :param connect_args: Other parameters forwarded to + `websockets.connect `_ :param subprotocols: list of subprotocols sent to the backend in the 'subprotocols' http header. By default: both apollo and graphql-ws subprotocols. From a7f7649f364c52b7dc2c3ba10e66166e84bc9213 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 28 Jul 2022 21:59:29 +0200 Subject: [PATCH 087/239] Add execute-timeout argument for gql-cli (#349) --- gql/cli.py | 30 +++++++++++++++++++++++++++++- tests/test_cli.py | 19 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/gql/cli.py b/gql/cli.py index 27a562b2..2a6ff3f5 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -46,6 +46,25 @@ """ +def positive_int_or_none(value_str: str) -> Optional[int]: + """Convert a string argument value into either an int or None. + + Raise a ValueError if the argument is negative or a string which is not "none" + """ + try: + value_int = int(value_str) + except ValueError: + if value_str.lower() == "none": + return None + else: + raise + + if value_int < 0: + raise ValueError + + return value_int + + def get_parser(with_examples: bool = False) -> ArgumentParser: """Provides an ArgumentParser for the gql-cli script. 
@@ -103,6 +122,13 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: action="store_true", dest="print_schema", ) + parser.add_argument( + "--execute-timeout", + help="set the execute_timeout argument of the Client (default: 10)", + type=positive_int_or_none, + default=10, + dest="execute_timeout", + ) parser.add_argument( "--transport", default="auto", @@ -367,7 +393,9 @@ async def main(args: Namespace) -> int: # Connect to the backend and provide a session async with Client( - transport=transport, fetch_schema_from_transport=args.print_schema + transport=transport, + fetch_schema_from_transport=args.print_schema, + execute_timeout=args.execute_timeout, ) as session: if args.print_schema: diff --git a/tests/test_cli.py b/tests/test_cli.py index 9066544b..359e94fb 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -73,6 +73,25 @@ def test_cli_parser(parser): ) assert args.operation_name == "my_operation" + # Check execute_timeout + # gql-cli https://your_server.com --execute-timeout 1 + args = parser.parse_args(["https://your_server.com", "--execute-timeout", "1"]) + assert args.execute_timeout == 1 + + # gql-cli https://your_server.com --execute-timeout=none + args = parser.parse_args(["https://your_server.com", "--execute-timeout", "none"]) + assert args.execute_timeout is None + + # gql-cli https://your_server.com --execute-timeout=-1 + with pytest.raises(SystemExit): + args = parser.parse_args(["https://your_server.com", "--execute-timeout", "-1"]) + + # gql-cli https://your_server.com --execute-timeout=invalid + with pytest.raises(SystemExit): + args = parser.parse_args( + ["https://your_server.com", 
"--execute-timeout", "invalid"] + ) + def test_cli_parse_headers(parser): From e73096f62daafc6774058992c3836de45b186879 Mon Sep 17 00:00:00 2001 From: Will Frey Date: Fri, 5 Aug 2022 04:19:49 -0400 Subject: [PATCH 088/239] Make `AsyncTransport` and `Transport` proper abstract base classes (#350) --- gql/transport/async_transport.py | 2 +- gql/transport/transport.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 18f6df79..4cecc9f9 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -4,7 +4,7 @@ from graphql import DocumentNode, ExecutionResult -class AsyncTransport: +class AsyncTransport(abc.ABC): @abc.abstractmethod async def connect(self): """Coroutine used to create a connection to the specified address""" diff --git a/gql/transport/transport.py b/gql/transport/transport.py index a21502f0..cf5e94da 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -3,7 +3,7 @@ from graphql import DocumentNode, ExecutionResult -class Transport: +class Transport(abc.ABC): @abc.abstractmethod def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: """Execute GraphQL query. From 5719d8fe6a8b518693d633681d00b491ec2328df Mon Sep 17 00:00:00 2001 From: Gabriel Chiong Date: Thu, 11 Aug 2022 18:35:13 +1000 Subject: [PATCH 089/239] DOC Minor Spelling Mistake (#351) --- gql/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gql/client.py b/gql/client.py index d4a9dfef..6a6f5a67 100644 --- a/gql/client.py +++ b/gql/client.py @@ -874,7 +874,7 @@ def execute( return result.data def fetch_schema(self) -> None: - """Fetch the GraphQL schema explicitely using introspection. + """Fetch the GraphQL schema explicitly using introspection. 
Don't use this function and instead set the fetch_schema_from_transport attribute to True""" @@ -1245,7 +1245,7 @@ async def execute( return result.data async def fetch_schema(self) -> None: - """Fetch the GraphQL schema explicitely using introspection. + """Fetch the GraphQL schema explicitly using introspection. Don't use this function and instead set the fetch_schema_from_transport attribute to True""" From 5a9c0cca271b23915a621339a28004dc457a4bdc Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 19 Aug 2022 00:27:15 +0200 Subject: [PATCH 090/239] DOC add phoenix channel transport example (#354) --- docs/code_examples/phoenix_channel_async.py | 30 +++++++++++++++++++++ docs/transports/phoenix.rst | 2 ++ 2 files changed, 32 insertions(+) create mode 100644 docs/code_examples/phoenix_channel_async.py diff --git a/docs/code_examples/phoenix_channel_async.py b/docs/code_examples/phoenix_channel_async.py new file mode 100644 index 00000000..1fdc2566 --- /dev/null +++ b/docs/code_examples/phoenix_channel_async.py @@ -0,0 +1,30 @@ +import asyncio + +from gql import Client, gql +from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport + + +async def main(): + + transport = PhoenixChannelWebsocketsTransport( + channel_name="YOUR_CHANNEL", url="wss://YOUR_URL/graphql" + ) + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with Client(transport=transport) as session: + + # Execute single query + query = gql( + """ + query yourQuery { + ... + } + """ + ) + + result = await session.execute(query) + print(result) + + +asyncio.run(main()) diff --git a/docs/transports/phoenix.rst b/docs/transports/phoenix.rst index 7fb4a90c..b03c2b93 100644 --- a/docs/transports/phoenix.rst +++ b/docs/transports/phoenix.rst @@ -10,6 +10,8 @@ framework `channels`_. 
Reference: :class:`gql.transport.phoenix_channel_websockets.PhoenixChannelWebsocketsTransport` +.. literalinclude:: ../code_examples/phoenix_channel_async.py + .. _Absinthe: http://absinthe-graphql.org .. _Phoenix: https://www.phoenixframework.org .. _channels: https://hexdocs.pm/phoenix/Phoenix.Channel.html#content From e53b168f463b62bc7f577d785f6ed0c72acc03bd Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 11 Sep 2022 01:08:25 +0200 Subject: [PATCH 091/239] Fix KeyError when errors is not iterable (#359) --- gql/client.py | 10 ++++++---- gql/utils.py | 11 ++++++++++- tests/test_aiohttp.py | 13 ++++++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/gql/client.py b/gql/client.py index 6a6f5a67..69804faa 100644 --- a/gql/client.py +++ b/gql/client.py @@ -34,6 +34,7 @@ from .utilities import build_client_schema from .utilities import parse_result as parse_result_fn from .utilities import serialize_variable_values +from .utils import str_first_element """ Load the appropriate instance of the Literal type @@ -152,7 +153,8 @@ def _build_schema_from_introspection(self, execution_result: ExecutionResult): if execution_result.errors: raise TransportQueryError( ( - f"Error while fetching schema: {execution_result.errors[0]!s}\n" + "Error while fetching schema: " + f"{str_first_element(execution_result.errors)}\n" "If you don't need the schema, you can try with: " '"fetch_schema_from_transport=False"' ), @@ -858,7 +860,7 @@ def execute( # Raise an error if an error is returned in the ExecutionResult object if result.errors: raise TransportQueryError( - str(result.errors[0]), + str_first_element(result.errors), errors=result.errors, data=result.data, extensions=result.extensions, @@ -1066,7 +1068,7 @@ async def subscribe( # Raise an error if an error is returned in the ExecutionResult object if result.errors: raise 
TransportQueryError( - str(result.errors[0]), + str_first_element(result.errors), errors=result.errors, data=result.data, extensions=result.extensions, @@ -1229,7 +1231,7 @@ async def execute( # Raise an error if an error is returned in the ExecutionResult object if result.errors: raise TransportQueryError( - str(result.errors[0]), + str_first_element(result.errors), errors=result.errors, data=result.data, extensions=result.extensions, diff --git a/gql/utils.py b/gql/utils.py index 3edb086c..b4265ce1 100644 --- a/gql/utils.py +++ b/gql/utils.py @@ -1,6 +1,6 @@ """Utilities to manipulate several python objects.""" -from typing import Any, Dict, Tuple, Type +from typing import Any, Dict, List, Tuple, Type # From this response in Stackoverflow @@ -47,3 +47,12 @@ def recurse_extract(path, obj): nulled_variables = recurse_extract("variables", variables) return nulled_variables, files + + +def str_first_element(errors: List) -> str: + try: + first_error = errors[0] + except (KeyError, TypeError): + first_error = errors + + return str(first_error) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 3a84d21e..f1a3cdf5 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -199,18 +199,21 @@ async def handler(request): assert "500, message='Internal Server Error'" in str(exc_info.value) -query1_server_error_answer = '{"errors": ["Error 1", "Error 2"]}' +transport_query_error_responses = [ + '{"errors": ["Error 1", "Error 2"]}', + '{"errors": {"error_1": "Something"}}', + '{"errors": 5}', +] @pytest.mark.asyncio -async def test_aiohttp_error_code(event_loop, aiohttp_server): +@pytest.mark.parametrize("query_error", transport_query_error_responses) +async def test_aiohttp_error_code(event_loop, aiohttp_server, query_error): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): - return web.Response( - text=query1_server_error_answer, content_type="application/json" - ) + return 
web.Response(text=query_error, content_type="application/json") app = web.Application() app.router.add_route("POST", "/", handler) From 5e47f5ffad064b2119ec59474392364cd5c94321 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 7 Nov 2022 22:24:15 +0100 Subject: [PATCH 092/239] Get response headers even with 4xx return code (#367) --- gql/transport/aiohttp.py | 6 ++--- tests/test_aiohttp.py | 47 ++++++++++++++++++++++++++++++++++++++ tests/test_requests.py | 49 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 3 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index f4f38b69..e6e3a782 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -295,6 +295,9 @@ async def execute( async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + # Saving latest response headers in the transport + self.response_headers = resp.headers + async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases @@ -325,9 +328,6 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): if "errors" not in result and "data" not in result: await raise_response_error(resp, 'No "data" or "errors" keys in answer') - # Saving latest response headers in the transport - self.response_headers = resp.headers - return ExecutionResult( errors=result.get("errors"), data=result.get("data"), diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index f1a3cdf5..f4899c82 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -172,6 +172,53 @@ async def handler(request): assert "401, message='Unauthorized'" in str(exc_info.value) +@pytest.mark.asyncio +async def test_aiohttp_error_code_429(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + # Will generate http 
error code 429 + return web.Response( + text=""" + + + Too Many Requests + + +

Too Many Requests

+

I only allow 50 requests per hour to this Web site per + logged in user. Try again soon.

+ +""", + content_type="text/html", + status=429, + headers={"Retry-After": "3600"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "429, message='Too Many Requests'" in str(exc_info.value) + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["Retry-After"] == "3600" + + @pytest.mark.asyncio async def test_aiohttp_error_code_500(event_loop, aiohttp_server): from aiohttp import web diff --git a/tests/test_requests.py b/tests/test_requests.py index 70fc337e..4a193cc9 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -148,6 +148,55 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + # Will generate http error code 429 + return web.Response( + text=""" + + + Too Many Requests + + +

Too Many Requests

+

I only allow 50 requests per hour to this Web site per + logged in user. Try again soon.

+ +""", + content_type="text/html", + status=429, + headers={"Retry-After": "3600"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + session.execute(query) + + assert "429, message='Too Many Requests'" in str(exc_info.value) + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["Retry-After"] == "3600" + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): From 3d85d6469b55257e80fd22470047aae7ad00d85e Mon Sep 17 00:00:00 2001 From: Jonathan Leitschuh Date: Mon, 7 Nov 2022 17:29:13 -0500 Subject: [PATCH 093/239] Handle JSON response being `None` (#365) --- gql/transport/aiohttp.py | 3 +++ tests/test_aiohttp.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index e6e3a782..2b155870 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -325,6 +325,9 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): except Exception: await raise_response_error(resp, "Not a JSON answer") + if result is None: + await raise_response_error(resp, "Not a JSON answer") + if "errors" not in result and "data" not in result: await raise_response_error(resp, 'No "data" or "errors" keys in answer') diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index f4899c82..d78e4333 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -299,6 +299,12 @@ async def handler(request): 'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}' ), }, + { + "response": 
"", + "expected_exception": ( + "Server did not return a GraphQL result: Not a JSON answer: " + ), + }, ] From ed1f48061e3a9cfc708b6c30adb1a472140ab8ac Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 7 Nov 2022 23:40:28 +0100 Subject: [PATCH 094/239] Chore update graphql-core to 3.3.0 and dropping support for Python 3.6 (#363) --- .github/workflows/deploy.yml | 4 ++-- .github/workflows/lint.yml | 4 ++-- .github/workflows/tests.yml | 18 ++++++++---------- README.md | 2 +- docs/gql-cli/intro.rst | 2 +- docs/intro.rst | 2 +- setup.py | 6 ++---- tests/custom_scalars/test_json.py | 4 ++-- tests/starwars/test_dsl.py | 6 +++--- tests/test_aiohttp.py | 13 +++++++------ tests/test_requests.py | 11 ++++++----- tox.ini | 7 +++---- 12 files changed, 38 insertions(+), 41 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 2a6cdc6b..73778df5 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -10,9 +10,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.8 - name: Build wheel and source tarball diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6ed6d6ea..0f9f0a07 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,9 +7,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.8 - name: Install dependencies diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a0631101..3716767d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,11 +8,9 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "pypy3"] + python-version: ["3.7", "3.8", "3.9", "3.10", 
"pypy3.8"] os: [ubuntu-latest, windows-latest] exclude: - - os: windows-latest - python-version: "3.6" - os: windows-latest python-version: "3.7" - os: windows-latest @@ -20,12 +18,12 @@ jobs: - os: windows-latest python-version: "3.10" - os: windows-latest - python-version: "pypy3" + python-version: "pypy3.8" steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -45,9 +43,9 @@ jobs: dependency: ["aiohttp", "requests", "websockets"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.8 - name: Install dependencies with only ${{ matrix.dependency }} extra dependency @@ -61,9 +59,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python 3.8 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.8 - name: Install test dependencies diff --git a/README.md b/README.md index 8e0ac68b..12e34b01 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # GQL -This is a GraphQL client for Python 3.6+. +This is a GraphQL client for Python 3.7+. Plays nicely with `graphene`, `graphql-core`, `graphql-js` and any other GraphQL implementation compatible with the spec. GQL architecture is inspired by `React-Relay` and `Apollo-Client`. diff --git a/docs/gql-cli/intro.rst b/docs/gql-cli/intro.rst index b4565b01..93f16d32 100644 --- a/docs/gql-cli/intro.rst +++ b/docs/gql-cli/intro.rst @@ -3,7 +3,7 @@ gql-cli ======= -GQL provides a python 3.6+ script, called `gql-cli` which allows you to execute +GQL provides a python 3.7+ script, called `gql-cli` which allows you to execute GraphQL queries directly from the terminal. 
This script supports http(s) or websockets protocols. diff --git a/docs/intro.rst b/docs/intro.rst index 9685a980..bbe1cbf6 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -1,7 +1,7 @@ Introduction ============ -`GQL 3`_ is a `GraphQL`_ Client for Python 3.6+ which plays nicely with other +`GQL 3`_ is a `GraphQL`_ Client for Python 3.7+ which plays nicely with other graphql implementations compatible with the spec. Under the hood, it uses `GraphQL-core`_ which is a Python port of `GraphQL.js`_, diff --git a/setup.py b/setup.py index 835f8abc..9615906f 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.2,<3.3", + "graphql-core>=3.3.0a2,<3.4", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", ] @@ -48,8 +48,7 @@ ] install_websockets_requires = [ - "websockets>=9,<10;python_version<='3.6'", - "websockets>=10,<11;python_version>'3.6'", + "websockets>=10,<11", ] install_botocore_requires = [ @@ -82,7 +81,6 @@ "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", diff --git a/tests/custom_scalars/test_json.py b/tests/custom_scalars/test_json.py index 4c4da588..6276b408 100644 --- a/tests/custom_scalars/test_json.py +++ b/tests/custom_scalars/test_json.py @@ -203,7 +203,7 @@ def test_json_value_input_in_dsl_argument(): assert ( str(query) == """addPlayer( - player: {name: "Tim", level: 0, is_connected: false, score: 5, friends: ["Lea"]} + player: { name: "Tim", level: 0, is_connected: false, score: 5, friends: ["Lea"] } )""" ) @@ -237,6 +237,6 @@ def test_json_value_input_with_none_list_in_dsl_argument(): assert ( str(query) == """addPlayer( - player: {name: "Bob", level: 9001, is_connected: true, score: 666.66, friends: null} + player: { name: "Bob", 
level: 9001, is_connected: true, score: 666.66, friends: null } )""" ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index d021e122..714e713a 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -205,7 +205,7 @@ def test_add_variable_definitions_with_default_value_input_object(ds): assert ( print_ast(query) == """ -mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { +mutation ($review: ReviewInput = { stars: 5, commentary: "Wow!" }, $episode: Episode) { createReview(review: $review, episode: $episode) { stars commentary @@ -229,7 +229,7 @@ def test_add_variable_definitions_in_input_object(ds): print_ast(query) == """mutation ($stars: Int, $commentary: String, $episode: Episode) { createReview( - review: {stars: $stars, commentary: $commentary} + review: { stars: $stars, commentary: $commentary } episode: $episode ) { stars @@ -554,7 +554,7 @@ def test_multiple_operations(ds): mutation CreateReviewMutation { createReview( episode: JEDI - review: {stars: 5, commentary: "This is a great movie!"} + review: { stars: 5, commentary: "This is a great movie!" } ) { stars commentary diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index d78e4333..9a62a65c 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -588,15 +588,15 @@ def test_code(): file_upload_mutation_1 = """ mutation($file: Upload!) { - uploadFile(input:{other_var:$other_var, file:$file}) { + uploadFile(input:{ other_var:$other_var, file:$file }) { success } } """ file_upload_mutation_1_operations = ( - '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' - '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' + '{"query": "mutation ($file: Upload!) 
{\\n uploadFile(input: { other_var: ' + '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -863,7 +863,7 @@ async def file_sender(file_name): file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -951,14 +951,15 @@ async def handler(request): file_upload_mutation_3 = """ mutation($files: [Upload!]!) { - uploadFiles(input:{files:$files}) { + uploadFiles(input:{ files:$files }) { success } } """ file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(input: {files: $files})' + '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(' + "input: { files: $files })" ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) diff --git a/tests/test_requests.py b/tests/test_requests.py index 4a193cc9..141bb756 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -399,15 +399,15 @@ def test_code(): file_upload_mutation_1 = """ mutation($file: Upload!) { - uploadFile(input:{other_var:$other_var, file:$file}) { + uploadFile(input:{ other_var:$other_var, file:$file }) { success } } """ file_upload_mutation_1_operations = ( - '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' - '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' + '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: { other_var: ' + '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -611,7 +611,7 @@ def test_code(): file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -710,7 +710,8 @@ def test_code(): file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(input: {files: $files})' + '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' + "(input: { files: $files })" ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) diff --git a/tox.ini b/tox.ini index e75b8fac..070b5bf2 100644 --- a/tox.ini +++ b/tox.ini @@ -1,16 +1,15 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{36,37,38,39,310,py3} + py{37,38,39,310,py3} [gh-actions] python = - 3.6: py36 3.7: py37 3.8: py38 3.9: py39 3.10: py310 - pypy3: pypy3 + pypy-3: pypy3 [testenv] conda_channels = conda-forge @@ -28,7 +27,7 @@ deps = -e.[test] commands = pip install -U setuptools ; run "tox -- tests -s" to show output for debugging - py{36,37,39,310,py3}: pytest {posargs:tests} + py{37,39,310,py3}: pytest {posargs:tests} py{38}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] From f0150a85e1669cfc028a77e27a3873b326cc93ee Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 7 Nov 2022 23:46:13 +0100 Subject: [PATCH 095/239] Bump version number to 3.5.0a0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 903a158a..36c60f16 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.4.0" +__version__ = "3.5.0a0" From 0819418046ca684d92872327da8ee72a184baef6 Mon Sep 17 00:00:00 2001 From: Helder Correia <174525+helderco@users.noreply.github.com> Date: Sat, 26 Nov 2022 13:40:22 -0100 Subject: [PATCH 096/239] Add HTTPX transport (#370) --- .github/workflows/tests.yml | 2 +- Makefile | 5 +- docs/code_examples/httpx_async.py | 34 + 
docs/code_examples/httpx_sync.py | 20 + docs/intro.rst | 4 + docs/modules/gql.rst | 1 + docs/modules/transport_httpx.rst | 7 + docs/transports/async_transports.rst | 1 + docs/transports/httpx.rst | 13 + docs/transports/httpx_async.rst | 39 + docs/transports/sync_transports.rst | 1 + docs/usage/file_upload.rst | 9 +- gql/transport/httpx.py | 306 ++++++ setup.py | 7 +- tests/conftest.py | 4 +- tests/test_httpx.py | 850 ++++++++++++++++ tests/test_httpx_async.py | 1391 ++++++++++++++++++++++++++ tests/test_httpx_online.py | 148 +++ 18 files changed, 2836 insertions(+), 6 deletions(-) create mode 100644 docs/code_examples/httpx_async.py create mode 100644 docs/code_examples/httpx_sync.py create mode 100644 docs/modules/transport_httpx.rst create mode 100644 docs/transports/httpx.rst create mode 100644 docs/transports/httpx_async.rst create mode 100644 gql/transport/httpx.py create mode 100644 tests/test_httpx.py create mode 100644 tests/test_httpx_async.py create mode 100644 tests/test_httpx_online.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3716767d..366a953b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,7 +40,7 @@ jobs: strategy: fail-fast: false matrix: - dependency: ["aiohttp", "requests", "websockets"] + dependency: ["aiohttp", "requests", "httpx", "websockets"] steps: - uses: actions/checkout@v3 diff --git a/Makefile b/Makefile index 2275092c..59d08bac 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ SRC_PYTHON := gql tests docs/code_examples dev-setup: - python pip install -e ".[test]" + python -m pip install -e ".[test]" tests: pytest tests --cov=gql --cov-report=term-missing -vv @@ -17,6 +17,9 @@ tests_aiohttp: tests_requests: pytest tests --requests-only +tests_httpx: + pytest tests --httpx-only + tests_websockets: pytest tests --websockets-only diff --git a/docs/code_examples/httpx_async.py b/docs/code_examples/httpx_async.py new file mode 100644 index 00000000..9a01232d --- 
/dev/null +++ b/docs/code_examples/httpx_async.py @@ -0,0 +1,34 @@ +import asyncio + +from gql import Client, gql +from gql.transport.httpx import HTTPXAsyncTransport + + +async def main(): + + transport = HTTPXAsyncTransport(url="https://countries.trevorblades.com/graphql") + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with Client( + transport=transport, + fetch_schema_from_transport=True, + ) as session: + + # Execute single query + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + result = await session.execute(query) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/httpx_sync.py b/docs/code_examples/httpx_sync.py new file mode 100644 index 00000000..bd26f658 --- /dev/null +++ b/docs/code_examples/httpx_sync.py @@ -0,0 +1,20 @@ +from gql import Client, gql +from gql.transport.httpx import HTTPXTransport + +transport = HTTPXTransport(url="https://countries.trevorblades.com/") + +client = Client(transport=transport, fetch_schema_from_transport=True) + +query = gql( + """ + query getContinents { + continents { + code + name + } + } +""" +) + +result = client.execute(query) +print(result) diff --git a/docs/intro.rst b/docs/intro.rst index bbe1cbf6..f7a4b71d 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -45,6 +45,10 @@ The corresponding between extra dependencies required and the GQL classes is: +---------------------+----------------------------------------------------------------+ | requests | :ref:`RequestsHTTPTransport ` | +---------------------+----------------------------------------------------------------+ +| httpx | :ref:`HTTPTXTransport ` | +| | | +| | :ref:`HTTPXAsyncTransport ` | ++---------------------+----------------------------------------------------------------+ | botocore | 
:ref:`AppSyncIAMAuthentication ` | +---------------------+----------------------------------------------------------------+ diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index be6f904b..5f9edebe 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -26,6 +26,7 @@ Sub-Packages transport_exceptions transport_phoenix_channel_websockets transport_requests + transport_httpx transport_websockets transport_websockets_base dsl diff --git a/docs/modules/transport_httpx.rst b/docs/modules/transport_httpx.rst new file mode 100644 index 00000000..bf2da116 --- /dev/null +++ b/docs/modules/transport_httpx.rst @@ -0,0 +1,7 @@ +gql.transport.httpx +=================== + +.. currentmodule:: gql.transport.httpx + +.. automodule:: gql.transport.httpx + :member-order: bysource diff --git a/docs/transports/async_transports.rst b/docs/transports/async_transports.rst index df8c23cf..7d751df0 100644 --- a/docs/transports/async_transports.rst +++ b/docs/transports/async_transports.rst @@ -10,6 +10,7 @@ Async transports are transports which are using an underlying async library. The :maxdepth: 1 aiohttp + httpx_async websockets phoenix appsync diff --git a/docs/transports/httpx.rst b/docs/transports/httpx.rst new file mode 100644 index 00000000..25796621 --- /dev/null +++ b/docs/transports/httpx.rst @@ -0,0 +1,13 @@ +.. _httpx_transport: + +HTTPXTransport +============== + +The HTTPXTransport is a sync transport using the `httpx`_ library +and allows you to send GraphQL queries using the HTTP protocol. + +Reference: :class:`gql.transport.httpx.HTTPXTransport` + +.. literalinclude:: ../code_examples/httpx_sync.py + +.. _httpx: https://www.python-httpx.org diff --git a/docs/transports/httpx_async.rst b/docs/transports/httpx_async.rst new file mode 100644 index 00000000..c09d0cdc --- /dev/null +++ b/docs/transports/httpx_async.rst @@ -0,0 +1,39 @@ +.. 
_httpx_async_transport: + +HTTPXAsyncTransport +=================== + +This transport uses the `httpx`_ library and allows you to send GraphQL queries using the HTTP protocol. + +Reference: :class:`gql.transport.httpx.HTTPXAsyncTransport` + +.. note:: + + GraphQL subscriptions are not supported on the HTTP transport. + For subscriptions you should use the :ref:`websockets transport `. + +.. literalinclude:: ../code_examples/httpx_async.py + +Authentication +-------------- + +There are multiple ways to authenticate depending on the server configuration. + +1. Using HTTP Headers + +.. code-block:: python + + transport = HTTPXAsyncTransport( + url='https://SERVER_URL:SERVER_PORT/graphql', + headers={'Authorization': 'token'} + ) + +2. Using HTTP Cookies + +You can manually set the cookies which will be sent with each connection: + +.. code-block:: python + + transport = HTTPXAsyncTransport(url=url, cookies={"cookie1": "val1"}) + +.. _httpx: https://www.python-httpx.org diff --git a/docs/transports/sync_transports.rst b/docs/transports/sync_transports.rst index 3ed566d3..e0ec51a4 100644 --- a/docs/transports/sync_transports.rst +++ b/docs/transports/sync_transports.rst @@ -10,3 +10,4 @@ They cannot be used asynchronously. :maxdepth: 1 requests + httpx diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index 8062f317..f3769d41 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -1,8 +1,9 @@ File uploads ============ -GQL supports file uploads with the :ref:`aiohttp transport ` -and the :ref:`requests transport ` +GQL supports file uploads with the :ref:`aiohttp transport `, the +:ref:`requests transport `, the :ref:`httpx transport `, +and the :ref:`httpx async transport `, using the `GraphQL multipart request spec`_. .. 
_GraphQL multipart request spec: https://github.com/jaydenseric/graphql-multipart-request-spec @@ -20,6 +21,8 @@ In order to upload a single file, you need to: transport = AIOHTTPTransport(url='YOUR_URL') # Or transport = RequestsHTTPTransport(url='YOUR_URL') + # Or transport = HTTPXTransport(url='YOUR_URL') + # Or transport = HTTPXAsyncTransport(url='YOUR_URL') client = Client(transport=transport) @@ -48,6 +51,8 @@ It is also possible to upload multiple files using a list. transport = AIOHTTPTransport(url='YOUR_URL') # Or transport = RequestsHTTPTransport(url='YOUR_URL') + # Or transport = HTTPXTransport(url='YOUR_URL') + # Or transport = HTTPXAsyncTransport(url='YOUR_URL') client = Client(transport=transport) diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py new file mode 100644 index 00000000..6e844775 --- /dev/null +++ b/gql/transport/httpx.py @@ -0,0 +1,306 @@ +import io +import json +import logging +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + cast, +) + +import httpx +from graphql import DocumentNode, ExecutionResult, print_ast + +from ..utils import extract_files +from . import AsyncTransport, Transport +from .exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportServerError, +) + +log = logging.getLogger(__name__) + + +class _HTTPXTransport: + file_classes: Tuple[Type[Any], ...] = (io.IOBase,) + + reponse_headers: Optional[httpx.Headers] = None + + def __init__( + self, + url: Union[str, httpx.URL], + json_serialize: Callable = json.dumps, + **kwargs, + ): + """Initialize the transport with the given httpx parameters. + + :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'. + :param json_serialize: Json serializer callable. + By default json.dumps() function. + :param kwargs: Extra args passed to the `httpx` client. 
+ """ + self.url = url + self.json_serialize = json_serialize + self.kwargs = kwargs + + def _prepare_request( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> Dict[str, Any]: + query_str = print_ast(document) + + payload: Dict[str, Any] = { + "query": query_str, + } + + if operation_name: + payload["operationName"] = operation_name + + if upload_files: + # If the upload_files flag is set, then we need variable_values + assert variable_values is not None + + post_args = self._prepare_file_uploads(variable_values, payload) + else: + if variable_values: + payload["variables"] = variable_values + + post_args = {"json": payload} + + # Log the payload + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(payload)) + + # Pass post_args to httpx post method + if extra_args: + post_args.update(extra_args) + + return post_args + + def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values + nulled_variable_values, files = extract_files( + variables=variable_values, + file_classes=self.file_classes, + ) + + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values + + # Prepare to send multipart-encoded data + data: Dict[str, Any] = {} + file_map: Dict[str, List[str]] = {} + file_streams: Dict[str, Tuple[str, Any]] = {} + + for i, (path, val) in enumerate(files.items()): + key = str(i) + + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. 
+ # Will generate something like {"0": ["variables.file"]} + file_map[key] = [path] + + # Generate the file streams + # Will generate something like + # {"0": ("variables.file", <_io.BufferedReader ...>)} + filename = cast(str, getattr(val, "name", key)) + file_streams[key] = (filename, val) + + # Add the payload to the operations field + operations_str = self.json_serialize(payload) + log.debug("operations %s", operations_str) + data["operations"] = operations_str + + # Add the file map field + file_map_str = self.json_serialize(file_map) + log.debug("file_map %s", file_map_str) + data["map"] = file_map_str + + return {"data": data, "files": file_streams} + + def _prepare_result(self, response: httpx.Response) -> ExecutionResult: + # Save latest response headers in transport + self.response_headers = response.headers + + if log.isEnabledFor(logging.DEBUG): + log.debug("<<< %s", response.text) + + try: + result: Dict[str, Any] = response.json() + + except Exception: + self._raise_response_error(response, "Not a JSON answer") + + if "errors" not in result and "data" not in result: + self._raise_response_error(response, 'No "data" or "errors" keys in answer') + + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) + + def _raise_response_error(self, response: httpx.Response, reason: str): + # We raise a TransportServerError if the status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + try: + # Raise a HTTPError if response status is 400 or higher + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise TransportServerError(str(e), e.response.status_code) from e + + raise TransportProtocolError( + f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}" + ) + + +class HTTPXTransport(Transport, _HTTPXTransport): + """:ref:`Sync Transport ` used to execute GraphQL queries + on remote servers. 
+ + The transport uses the httpx library to send HTTP POST requests. + """ + + client: Optional[httpx.Client] = None + + def connect(self): + if self.client: + raise TransportAlreadyConnected("Transport is already connected") + + log.debug("Connecting transport") + + self.client = httpx.Client(**self.kwargs) + + def execute( # type: ignore + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> ExecutionResult: + """Execute GraphQL query. + + Execute the provided document AST against the configured remote server. This + uses the httpx library to perform a HTTP POST request to the remote server. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters (Default: None). + :param operation_name: Name of the operation that shall be executed. + Only required in multi-operation documents (Default: None). + :param extra_args: additional arguments to send to the httpx post method + :param upload_files: Set to True if you want to put files in the variable values + :return: The result of execution. + `data` is the result of executing the query, `errors` is null + if no errors occurred, and is a non-empty array if an error occurred. + """ + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_request( + document, + variable_values, + operation_name, + extra_args, + upload_files, + ) + + response = self.client.post(self.url, **post_args) + + return self._prepare_result(response) + + def close(self): + """Closing the transport by closing the inner session""" + if self.client: + self.client.close() + self.client = None + + +class HTTPXAsyncTransport(AsyncTransport, _HTTPXTransport): + """:ref:`Async Transport ` used to execute GraphQL queries + on remote servers. + + The transport uses the httpx library with anyio. 
+ """ + + client: Optional[httpx.AsyncClient] = None + + async def connect(self): + if self.client: + raise TransportAlreadyConnected("Transport is already connected") + + log.debug("Connecting transport") + + self.client = httpx.AsyncClient(**self.kwargs) + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> ExecutionResult: + """Execute GraphQL query. + + Execute the provided document AST against the configured remote server. This + uses the httpx library to perform a HTTP POST request asynchronously to the + remote server. + + :param document: GraphQL query as AST Node object. + :param variable_values: Dictionary of input parameters (Default: None). + :param operation_name: Name of the operation that shall be executed. + Only required in multi-operation documents (Default: None). + :param extra_args: additional arguments to send to the httpx post method + :param upload_files: Set to True if you want to put files in the variable values + :return: The result of execution. + `data` is the result of executing the query, `errors` is null + if no errors occurred, and is a non-empty array if an error occurred. + """ + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_request( + document, + variable_values, + operation_name, + extra_args, + upload_files, + ) + + response = await self.client.post(self.url, **post_args) + + return self._prepare_result(response) + + async def close(self): + """Closing the transport by closing the inner session""" + if self.client: + await self.client.aclose() + self.client = None + + def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> AsyncGenerator[ExecutionResult, None]: + """Subscribe is not supported on HTTP. 
+ + :meta private: + """ + raise NotImplementedError("The HTTP transport does not support subscriptions") diff --git a/setup.py b/setup.py index 9615906f..30817ec4 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,10 @@ "urllib3>=1.26", ] +install_httpx_requires = [ + "httpx>=0.23.1,<1", +] + install_websockets_requires = [ "websockets>=10,<11", ] @@ -56,7 +60,7 @@ ] install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_botocore_requires + install_aiohttp_requires + install_requests_requires + install_httpx_requires + install_websockets_requires + install_botocore_requires ) # Get version from __version__.py file @@ -100,6 +104,7 @@ "dev": install_all_requires + dev_requires, "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, + "httpx": install_httpx_requires, "websockets": install_websockets_requires, "botocore": install_botocore_requires, }, diff --git a/tests/conftest.py b/tests/conftest.py index 518d0d3a..b880cff4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from gql import Client -all_transport_dependencies = ["aiohttp", "requests", "websockets", "botocore"] +all_transport_dependencies = ["aiohttp", "requests", "httpx", "websockets", "botocore"] def pytest_addoption(parser): @@ -55,6 +55,7 @@ def pytest_collection_modifyitems(config, items): # --aiohttp-only # --requests-only + # --httpx-only # --websockets-only for transport in all_transport_dependencies: @@ -119,6 +120,7 @@ async def ssl_aiohttp_server(): "gql.transport.appsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", + "gql.transport.httpx", "gql.transport.websockets", "gql.dsl", "gql.utilities.parse_result", diff --git a/tests/test_httpx.py b/tests/test_httpx.py new file mode 100644 index 00000000..13f487dd --- /dev/null +++ b/tests/test_httpx.py @@ -0,0 +1,850 @@ +from typing import Mapping + +import pytest + +from gql import Client, gql +from 
gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) +from tests.conftest import TemporaryFile + +# Marking all tests in this file with the httpx marker +pytestmark = pytest.mark.httpx + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer = ( + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_query(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_cookies(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + assert 
"COOKIE" in request.headers + assert "cookie1=val1" == request.headers["COOKIE"] + + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url, cookies={"cookie1": "val1"}) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_code_401(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + session.execute(query) + + assert "Client error '401 Unauthorized'" in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_code_429(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + # Will generate http error code 429 + return web.Response( + text=""" + + + Too Many Requests + + +

Too Many Requests

+

I only allow 50 requests per hour to this Web site per + logged in user. Try again soon.

+ +""", + content_type="text/html", + status=429, + headers={"Retry-After": "3600"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + session.execute(query) + + assert "429, message='Too Many Requests'" in str(exc_info.value) + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["Retry-After"] == "3600" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_code_500(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + # Will generate http error code 500 + raise Exception("Server error") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError): + session.execute(query) + + await run_sync_test(event_loop, server, test_code) + + +query1_server_error_answer = '{"errors": ["Error 1", "Error 2"]}' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_code(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = 
str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportQueryError): + session.execute(query) + + await run_sync_test(event_loop, server, test_code) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', +] + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_httpx_invalid_protocol( + event_loop, aiohttp_server, response, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportProtocolError): + session.execute(query) + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + with pytest.raises(TransportAlreadyConnected): + session.transport.connect() + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_cannot_execute_if_not_connected( 
+ event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + query = gql(query1_str) + + with pytest.raises(TransportClosed): + transport.execute(query) + + await run_sync_test(event_loop, server, test_code) + + +query1_server_answer_with_extensions = ( + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}" +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_query_with_extensions(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + execution_result = session.execute(query, get_execution_result=True) + + assert execution_result.extensions["key1"] == "val1" + + await run_sync_test(event_loop, server, test_code) + + +file_upload_server_answer = '{"data":{"success":true}}' + +file_upload_mutation_1 = """ + mutation($file: Upload!) 
{ + uploadFile(input:{ other_var:$other_var, file:$file }) { + success + } + } +""" + +file_upload_mutation_1_operations = ( + '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: { other_var: ' + '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' + '{"file": null, "other_var": 42}}' +) + +file_upload_mutation_1_map = '{"0": ["variables.file"]}' + +file_1_content = """ +This is a test file +This file will be sent in the GraphQL mutation +""" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def single_upload_handler(request): + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with TemporaryFile(file_1_content) as test_file: + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert 
execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_additional_headers( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def single_upload_handler(request): + from aiohttp import web + + assert request.headers["X-Auth"] == "foobar" + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url, headers={"X-Auth": "foobar"}) + + with TemporaryFile(file_1_content) as test_file: + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_binary_file_upload(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + # 
This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + + async def binary_upload_handler(request): + + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_binary = await field_2.read() + assert field_2_binary == binary_file_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", binary_upload_handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXTransport(url=url) + + def test_code(): + with TemporaryFile(binary_file_content) as test_file: + with Client(transport=transport) as session: + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + +file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_two_files(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) { + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } + """ + + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' + + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ + + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_2_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_2_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + + with Client(transport=transport) as session: + + query = gql(file_upload_mutation_2) + + file_path_1 = test_file_1.filename + file_path_2 = 
test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = { + "file1": f1, + "file2": f2, + } + + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + f1.close() + f2.close() + + await run_sync_test(event_loop, server, test_code) + + +file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' + "(input: { files: $files })" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_list_of_two_files( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + file_upload_mutation_3 = """ + mutation($files: [Upload!]!) { + uploadFiles(input:{files:$files}) { + success + } + } + """ + + file_upload_mutation_3_map = ( + '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + ) + + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ + + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_3_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_3_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + 
app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + with Client(transport=transport) as session: + + query = gql(file_upload_mutation_3) + + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = {"files": [f1, f2]} + + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + f1.close() + f2.close() + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_fetching_schema(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + error_answer = """ +{ + "errors": [ + { + "errorType": "UnauthorizedException", + "message": "Permission denied" + } + ] +} +""" + + async def handler(request): + return web.Response( + text=error_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with pytest.raises(TransportQueryError) as exc_info: + with Client(transport=transport, fetch_schema_from_transport=True): + pass + + expected_error = ( + "Error while fetching schema: " + "{'errorType': 'UnauthorizedException', 'message': 'Permission denied'}" + ) + + assert expected_error in str(exc_info.value) + assert transport.client is None + + await run_sync_test(event_loop, server, test_code) diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py new file mode 100644 index 00000000..362875de --- /dev/null +++ b/tests/test_httpx_async.py @@ 
-0,0 +1,1391 @@ +import io +import json +from typing import Mapping + +import pytest + +from gql import Client, gql +from gql.cli import get_parser, main +from gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +from .conftest import TemporaryFile, get_localhost_ssl_context + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_data = ( + '{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}' +) + + +query1_server_answer = f'{{"data":{query1_server_answer_data}}}' + +# Marking all tests in this file with the httpx marker +pytestmark = pytest.mark.httpx + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_query(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.aiohttp 
+@pytest.mark.asyncio +async def test_httpx_ignore_backend_content_type(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="text/plain") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_cookies(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + assert "COOKIE" in request.headers + assert "cookie1=val1" == request.headers["COOKIE"] + + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, cookies={"cookie1": "val1"}) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_code_401(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, 
+ ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "Client error '401 Unauthorized'" in str(exc_info.value) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_code_429(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + # Will generate http error code 429 + return web.Response( + text=""" + + + Too Many Requests + + +
<h1>Too Many Requests</h1>
+     <p>I only allow 50 requests per hour to this Web site per
+        logged in user. Try again soon.</p>
+  </body>
+</html>
+ +""", + content_type="text/html", + status=429, + headers={"Retry-After": "3600"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "Client error '429 Too Many Requests'" in str(exc_info.value) + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["Retry-After"] == "3600" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_code_500(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + # Will generate http error code 500 + raise Exception("Server error") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "Server error '500 Internal Server Error'" in str(exc_info.value) + + +transport_query_error_responses = [ + '{"errors": ["Error 1", "Error 2"]}', + '{"errors": {"error_1": "Something"}}', + '{"errors": 5}', +] + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("query_error", transport_query_error_responses) +async def test_httpx_error_code(event_loop, aiohttp_server, query_error): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query_error, 
content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportQueryError): + await session.execute(query) + + +invalid_protocol_responses = [ + { + "response": "{}", + "expected_exception": ( + "Server did not return a GraphQL result: " + 'No "data" or "errors" keys in answer: {}' + ), + }, + { + "response": "qlsjfqsdlkj", + "expected_exception": ( + "Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj" + ), + }, + { + "response": '{"not_data_or_errors": 35}', + "expected_exception": ( + "Server did not return a GraphQL result: " + 'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}' + ), + }, + { + "response": "", + "expected_exception": ( + "Server did not return a GraphQL result: Not a JSON answer: " + ), + }, +] + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("param", invalid_protocol_responses) +async def test_httpx_invalid_protocol(event_loop, aiohttp_server, param): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + response = param["response"] + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(TransportProtocolError) as exc_info: + await session.execute(query) + + assert param["expected_exception"] in str(exc_info.value) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_subscribe_not_supported(event_loop, aiohttp_server): + from 
aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text="does not matter", content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.raises(NotImplementedError): + async for result in session.subscribe(query): + pass + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + with pytest.raises(TransportAlreadyConnected): + await session.transport.connect() + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_cannot_execute_if_not_connected(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + query = gql(query1_str) + + with pytest.raises(TransportClosed): + await transport.execute(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_extra_args(event_loop, aiohttp_server): + from aiohttp import web + from 
gql.transport.httpx import HTTPXAsyncTransport + import httpx + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + # passing extra arguments to httpx.AsyncClient + transport = httpx.AsyncHTTPTransport(retries=2) + transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=transport) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Passing extra arguments to the post method of aiohttp + result = await session.execute(query, extra_args={"follow_redirects": True}) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +query2_str = """ + query getEurope ($code: ID!) { + continent (code: $code) { + name + } + } +""" + +query2_server_answer = '{"data": {"continent": {"name": "Europe"}}}' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_query_variable_values(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query2_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + params = {"code": "EU"} + + query = gql(query2_str) + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, operation_name="getEurope" + ) + + continent = result["continent"] + + assert continent["name"] == "Europe" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_query_variable_values_fix_issue_292(event_loop, aiohttp_server): + """Allow to specify 
variable_values without keyword. + + See https://github.com/graphql-python/gql/issues/292""" + + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query2_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + params = {"code": "EU"} + + query = gql(query2_str) + + # Execute query asynchronously + result = await session.execute(query, params, operation_name="getEurope") + + continent = result["continent"] + + assert continent["name"] == "Europe" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_execute_running_in_thread( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXAsyncTransport(url=url) + + client = Client(transport=transport) + + query = gql(query1_str) + + client.execute(query) + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_subscribe_running_in_thread( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = 
str(server.make_url("/")) + + def test_code(): + transport = HTTPXAsyncTransport(url=url) + + client = Client(transport=transport) + + query = gql(query1_str) + + # Note: subscriptions are not supported on the httpx transport + # But we add this test in order to have 100% code coverage + # It is to check that we will correctly set an event loop + # in the subscribe function if there is none (in a Thread for example) + # We cannot test this with the websockets transport because + # the websockets transport will set an event loop in its init + + with pytest.raises(NotImplementedError): + for result in client.subscribe(query): + pass + + await run_sync_test(event_loop, server, test_code) + + +file_upload_server_answer = '{"data":{"success":true}}' + +file_upload_mutation_1 = """ + mutation($file: Upload!) { + uploadFile(input:{ other_var:$other_var, file:$file }) { + success + } + } +""" + +file_upload_mutation_1_operations = ( + '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: { other_var: ' + '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' + '{"file": null, "other_var": 42}}' +) + +file_upload_mutation_1_map = '{"0": ["variables.file"]}' + +file_1_content = """ +This is a test file +This file will be sent in the GraphQL mutation +""" + + +async def single_upload_handler(request): + + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response(text=file_upload_server_answer, 
content_type="application/json") + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file: + + async with Client(transport=transport) as session: + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + + assert success + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_without_session( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXAsyncTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file: + + client = Client(transport=transport) + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + result = client.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + + assert success + + await run_sync_test(event_loop, server, test_code) + + +# This is a sample binary file content containing all possible byte values +binary_file_content = bytes(range(0, 256)) + + +async def binary_upload_handler(request): + + from aiohttp import web + + reader = await 
request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_binary = await field_2.read() + assert field_2_binary == binary_file_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response(text=file_upload_server_answer, content_type="application/json") + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_binary_file_upload(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + app = web.Application() + app.router.add_route("POST", "/", binary_upload_handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + with TemporaryFile(binary_file_content) as test_file: + + async with Client(transport=transport) as session: + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + + assert success + + +file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) { + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } +""" + +file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' +) + +file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' + +file_2_content = """ +This is a second test file +This file will also be sent in the GraphQL mutation +""" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_two_files(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_2_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_2_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + + async with Client(transport=transport) as session: + + query = gql(file_upload_mutation_2) + + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = { + "file1": f1, + "file2": f2, + } + + result = await session.execute( + 
query, variable_values=params, upload_files=True + ) + + f1.close() + f2.close() + + success = result["success"] + + assert success + + +file_upload_mutation_3 = """ + mutation($files: [Upload!]!) { + uploadFiles(input:{ files:$files }) { + success + } + } +""" + +file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(' + "input: { files: $files })" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' +) + +file_upload_mutation_3_map = '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_list_of_two_files(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_3_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_3_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + + async with Client(transport=transport) as session: + + query = 
gql(file_upload_mutation_3) + + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = {"files": [f1, f2]} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + f1.close() + f2.close() + + success = result["success"] + + assert success + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_using_cli(event_loop, aiohttp_server, monkeypatch, capsys): + from aiohttp import web + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + parser = get_parser(with_examples=True) + args = parser.parse_args([url, "--verbose"]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + exit_code = await main(args) + + assert exit_code == 0 + + # Check that the result has been printed on stdout + captured = capsys.readouterr() + captured_out = str(captured.out).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.script_launch_mode("subprocess") +async def test_httpx_using_cli_ep( + event_loop, aiohttp_server, monkeypatch, script_runner, run_sync_test +): + from aiohttp import web + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + + monkeypatch.setattr("sys.stdin", 
io.StringIO(query1_str)) + + ret = script_runner.run( + "gql-cli", url, "--verbose", stdin=io.StringIO(query1_str) + ) + + assert ret.success + + # Check that the result has been printed on stdout + captured_out = str(ret.stdout).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_using_cli_invalid_param( + event_loop, aiohttp_server, monkeypatch, capsys +): + from aiohttp import web + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + parser = get_parser(with_examples=True) + args = parser.parse_args([url, "--variables", "invalid_param"]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + # Check that the exit_code is an error + exit_code = await main(args) + assert exit_code == 1 + + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") + + expected_error = "Error: Invalid variable: invalid_param" + + assert expected_error in captured_err + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_using_cli_invalid_query( + event_loop, aiohttp_server, monkeypatch, capsys +): + from aiohttp import web + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + parser = 
get_parser(with_examples=True) + args = parser.parse_args([url]) + + # Send invalid query on standard input + monkeypatch.setattr("sys.stdin", io.StringIO("BLAHBLAH")) + + exit_code = await main(args) + + assert exit_code == 1 + + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") + + expected_error = "Syntax Error: Unexpected Name 'BLAHBLAH'" + + assert expected_error in captured_err + + +query1_server_answer_with_extensions = ( + f'{{"data":{query1_server_answer_data}, "extensions":{{"key1": "val1"}}}}' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_query_with_extensions(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + execution_result = await session.execute(query, get_execution_result=True) + + assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_query_https(event_loop, ssl_aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = str(server.make_url("/")) + + assert url.startswith("https://") + + cert, _ = get_localhost_ssl_context() + + transport = 
HTTPXAsyncTransport(url=url, timeout=10, verify=cert.decode()) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_error_fetching_schema(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + error_answer = """ +{ + "errors": [ + { + "errorType": "UnauthorizedException", + "message": "Permission denied" + } + ] +} +""" + + async def handler(request): + return web.Response( + text=error_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + with pytest.raises(TransportQueryError) as exc_info: + async with Client(transport=transport, fetch_schema_from_transport=True): + pass + + expected_error = ( + "Error while fetching schema: " + "{'errorType': 'UnauthorizedException', 'message': 'Permission denied'}" + ) + + assert expected_error in str(exc_info.value) + assert transport.client is None + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_reconnecting_session(event_loop, aiohttp_server): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + client = Client(transport=transport) + + session = await client.connect_async(reconnecting=True) + + query = gql(query1_str) + + # Execute 
query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + await client.close_async() + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("retries", [False, lambda e: e]) +async def test_httpx_reconnecting_session_retries(event_loop, aiohttp_server, retries): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + client = Client(transport=transport) + + session = await client.connect_async( + reconnecting=True, retry_execute=retries, retry_connect=retries + ) + + assert session._execute_with_retries == session._execute_once + assert session._connect_with_retries == session.transport.connect + + await client.close_async() + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_reconnecting_session_start_connecting_task_twice( + event_loop, aiohttp_server, caplog +): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + client = Client(transport=transport) + + session = await client.connect_async(reconnecting=True) + + await session.start_connecting_task() + + print(f"Captured log: {caplog.text}") + + expected_warning = "connect task already started!" 
+ assert expected_warning in caplog.text + + await client.close_async() + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_json_serializer(event_loop, aiohttp_server, caplog): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + + request_text = await request.text() + print(f"Received on backend: {request_text}") + + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport( + url=url, + timeout=10, + json_serialize=lambda e: json.dumps(e, separators=(",", ":")), + ) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking that there is no space after the colon in the log + expected_log = '"query":"query getContinents' + assert expected_log in caplog.text diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py new file mode 100644 index 00000000..ee08e2b1 --- /dev/null +++ b/tests/test_httpx_online.py @@ -0,0 +1,148 @@ +import asyncio +import sys +from typing import Dict + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import TransportQueryError + + +@pytest.mark.httpx +@pytest.mark.online +@pytest.mark.asyncio +@pytest.mark.parametrize("protocol", ["http", "https"]) +async def test_httpx_simple_query(event_loop, protocol): + + from gql.transport.httpx import HTTPXAsyncTransport + + # Create http or https url + url = f"{protocol}://countries.trevorblades.com/graphql" + + # Get transport + sample_transport = HTTPXAsyncTransport(url=url) + + # Instanciate client + async with Client(transport=sample_transport) as session: + + 
query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + # Fetch schema + await session.fetch_schema() + + # Execute query + result = await session.execute(query) + + # Verify result + assert isinstance(result, Dict) + + print(result) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.httpx +@pytest.mark.online +@pytest.mark.asyncio +async def test_httpx_invalid_query(event_loop): + + from gql.transport.httpx import HTTPXAsyncTransport + + sample_transport = HTTPXAsyncTransport( + url="https://countries.trevorblades.com/graphql" + ) + + async with Client(transport=sample_transport) as session: + + query = gql( + """ + query getContinents { + continents { + code + bloh + } + } + """ + ) + + with pytest.raises(TransportQueryError): + await session.execute(query) + + +@pytest.mark.httpx +@pytest.mark.online +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +@pytest.mark.asyncio +async def test_httpx_two_queries_in_parallel_using_two_tasks(event_loop): + + from gql.transport.httpx import HTTPXAsyncTransport + + sample_transport = HTTPXAsyncTransport( + url="https://countries.trevorblades.com/graphql", + ) + + # Instanciate client + async with Client(transport=sample_transport) as session: + + query1 = gql( + """ + query getContinents { + continents { + code + } + } + """ + ) + + query2 = gql( + """ + query getContinents { + continents { + name + } + } + """ + ) + + async def query_task1(): + result = await session.execute(query1) + + assert isinstance(result, Dict) + + print(result) + + continents = result["continents"] + + africa = continents[0] + assert africa["code"] == "AF" + + async def query_task2(): + result = await session.execute(query2) + + assert isinstance(result, Dict) + + print(result) + + continents = result["continents"] + + africa 
= continents[0] + assert africa["name"] == "Africa" + + task1 = asyncio.create_task(query_task1()) + task2 = asyncio.create_task(query_task2()) + + await task1 + await task2 From 08e51968cd5d4307adfc5868940b62caaf431c22 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 26 Nov 2022 16:06:18 +0100 Subject: [PATCH 097/239] Update Sphinx dev dependency to 5.3.0 (#371) --- docs/requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index b7880231..d9ce8ad1 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,4 @@ -sphinx>=3.0.0,<4 +sphinx>=5.3.0,<6 sphinx_rtd_theme>=0.4,<1 sphinx-argparse==0.2.5 multidict<5.0,>=4.5 diff --git a/setup.py b/setup.py index 30817ec4..1c39d3a1 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "flake8==3.8.1", "isort==4.3.21", "mypy==0.910", - "sphinx>=3.0.0,<4", + "sphinx>=5.3.0,<6", "sphinx_rtd_theme>=0.4,<1", "sphinx-argparse==0.2.5", "types-aiofiles", From f28670da7656d195487cbf6417ac16571a4951f2 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 26 Nov 2022 16:13:57 +0100 Subject: [PATCH 098/239] Bump version number to 3.5.0b0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 36c60f16..53e48109 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0a0" +__version__ = "3.5.0b0" From 2981ce30b8c9a1c3d91dda8977d64ee3488d6ce4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 30 Jan 2023 17:55:55 +0100 Subject: [PATCH 099/239] Don't try to close the aiohttp session if connector_owner is False (#382) --- gql/transport/aiohttp.py | 21 +++++++++++++------ tests/test_aiohttp.py | 44 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 2b155870..6dc0a409 100644 --- 
a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -183,12 +183,21 @@ async def close(self) -> None: log.debug("Closing transport") - closed_event = self.create_aiohttp_closed_event(self.session) - await self.session.close() - try: - await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) - except asyncio.TimeoutError: - pass + if ( + self.client_session_args + and self.client_session_args.get("connector_owner") is False + ): + + log.debug("connector_owner is False -> not closing connector") + + else: + closed_event = self.create_aiohttp_closed_event(self.session) + await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass + self.session = None async def execute( diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 9a62a65c..27af1438 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1441,3 +1441,47 @@ async def handler(request): # Checking that there is no space after the colon in the log expected_log = '"query":"query getContinents' assert expected_log in caplog.text + + +@pytest.mark.asyncio +async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): + from aiohttp import web, TCPConnector + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + connector = TCPConnector() + transport = AIOHTTPTransport( + url=url, + timeout=10, + client_session_args={ + "connector": connector, + "connector_owner": False, + }, + ) + + for _ in range(2): + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + 
assert africa["code"] == "AF" + + await connector.close() From 930fca579220fd4810053491878366fa69862e58 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 16 Feb 2023 18:25:52 +0100 Subject: [PATCH 100/239] Feature allow to set the content-type of file uploads (#386) --- docs/usage/file_upload.rst | 19 +++++++++++ gql/transport/aiohttp.py | 7 ++-- gql/transport/httpx.py | 13 +++++--- gql/transport/requests.py | 10 ++++-- tests/test_aiohttp.py | 68 ++++++++++++++++++++++++++++++++++++++ tests/test_httpx.py | 68 ++++++++++++++++++++++++++++++++++++++ tests/test_requests.py | 68 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 245 insertions(+), 8 deletions(-) diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index f3769d41..10903585 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -42,6 +42,25 @@ In order to upload a single file, you need to: query, variable_values=params, upload_files=True ) +Setting the content-type +^^^^^^^^^^^^^^^^^^^^^^^^ + +If you need to set a specific Content-Type attribute to a file, +you can set the :code:`content_type` attribute of the file like this: + +.. 
code-block:: python + + with open("YOUR_FILE_PATH", "rb") as f: + + # Setting the content-type to a pdf file for example + f.content_type = "application/pdf" + + params = {"file": f} + + result = client.execute( + query, variable_values=params, upload_files=True + ) + File list --------- diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 6dc0a409..2fd92a72 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -274,8 +274,11 @@ async def execute( data.add_field("map", file_map_str, content_type="application/json") # Add the extracted files as remaining fields - for k, v in file_streams.items(): - data.add_field(k, v, filename=getattr(v, "name", k)) + for k, f in file_streams.items(): + name = getattr(f, "name", k) + content_type = getattr(f, "content_type", None) + + data.add_field(k, f, filename=name, content_type=content_type) post_args: Dict[str, Any] = {"data": data} diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 6e844775..4c1d4f0f 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -103,9 +103,9 @@ def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: # Prepare to send multipart-encoded data data: Dict[str, Any] = {} file_map: Dict[str, List[str]] = {} - file_streams: Dict[str, Tuple[str, Any]] = {} + file_streams: Dict[str, Tuple[str, ...]] = {} - for i, (path, val) in enumerate(files.items()): + for i, (path, f) in enumerate(files.items()): key = str(i) # Generate the file map @@ -117,8 +117,13 @@ def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: # Generate the file streams # Will generate something like # {"0": ("variables.file", <_io.BufferedReader ...>)} - filename = cast(str, getattr(val, "name", key)) - file_streams[key] = (filename, val) + name = cast(str, getattr(f, "name", key)) + content_type = getattr(f, "content_type", None) + + if content_type is None: + file_streams[key] = (name, f) + else: + file_streams[key] = (name, f, 
content_type) # Add the payload to the operations field operations_str = self.json_serialize(payload) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 690615b4..fa60c38c 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -183,8 +183,14 @@ def execute( # type: ignore fields = {"operations": operations_str, "map": file_map_str} # Add the extracted files as remaining fields - for k, v in file_streams.items(): - fields[k] = (getattr(v, "name", k), v) + for k, f in file_streams.items(): + name = getattr(f, "name", k) + content_type = getattr(f, "content_type", None) + + if content_type is None: + fields[k] = (name, f) + else: + fields[k] = (name, f, content_type) # Prepare requests http to send multipart-encoded data data = MultipartEncoder(fields=fields) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 27af1438..a9b3bda6 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -670,6 +670,74 @@ async def test_aiohttp_file_upload(event_loop, aiohttp_server): assert success +async def single_upload_handler_with_content_type(request): + + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + # Verifying the content_type + assert field_2.headers["Content-Type"] == "application/pdf" + + field_3 = await reader.next() + assert field_3 is None + + return web.Response(text=file_upload_server_answer, content_type="application/json") + + +@pytest.mark.asyncio +async def test_aiohttp_file_upload_with_content_type(event_loop, aiohttp_server): 
+ from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler_with_content_type) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file: + + async with Client(transport=transport) as session: + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + # Setting the content_type + f.content_type = "application/pdf" + + params = {"file": f, "other_var": 42} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + + assert success + + @pytest.mark.asyncio async def test_aiohttp_file_upload_without_session( event_loop, aiohttp_server, run_sync_test diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 13f487dd..56a984a4 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -477,6 +477,74 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_with_content_type( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def single_upload_handler(request): + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + # Verifying the content_type + assert 
field_2.headers["Content-Type"] == "application/pdf" + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with TemporaryFile(file_1_content) as test_file: + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + # Setting the content_type + f.content_type = "application/pdf" + + params = {"file": f, "other_var": 42} + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_file_upload_additional_headers( diff --git a/tests/test_requests.py b/tests/test_requests.py index 141bb756..a5ff0d8b 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -479,6 +479,74 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload_with_content_type( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def single_upload_handler(request): + from aiohttp import web + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert 
field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + # Verifying the content_type + assert field_2.headers["Content-Type"] == "application/pdf" + + field_3 = await reader.next() + assert field_3 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with TemporaryFile(file_1_content) as test_file: + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + # Setting the content_type + f.content_type = "application/pdf" + + params = {"file": f, "other_var": 42} + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_file_upload_additional_headers( From 2827d887db4c6951899a8e242af55863328f68a2 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 22 Feb 2023 16:31:36 +0100 Subject: [PATCH 101/239] Bump version number to 3.5.0b1 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 53e48109..94637e71 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b0" +__version__ = "3.5.0b1" From d28ee614ff8a8e281d30c0a19f5c97a1f82f9499 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 23 Feb 2023 17:25:30 +0100 Subject: [PATCH 102/239] TransportQueryError should extend TransportError (#392) --- gql/transport/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/exceptions.py 
b/gql/transport/exceptions.py index 89ae992b..48e9d96b 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -27,7 +27,7 @@ def __init__(self, message: str, code: Optional[int] = None): self.code = code -class TransportQueryError(Exception): +class TransportQueryError(TransportError): """The server returned an error for a specific query. This exception should not close the transport connection. From 9b7b8474a83d4201d113e9c0521764265095eceb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 23 Feb 2023 17:45:11 +0100 Subject: [PATCH 103/239] Update aiohttp minimum dependency version to 3.8.0 (#393) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1c39d3a1..17fda63b 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ ] + tests_requires install_aiohttp_requires = [ - "aiohttp>=3.7.1,<3.9.0", + "aiohttp>=3.8.0,<3.9.0", ] install_requests_requires = [ From 5e37e6a43baa901391592a238aacd688401eb434 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 23 Feb 2023 17:59:15 +0100 Subject: [PATCH 104/239] Bump version number to 3.5.0b2 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 94637e71..2f003410 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b1" +__version__ = "3.5.0b2" From 905b72470733ffe7ddca7e3aba1aba58b41505ec Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 23 Feb 2023 23:00:43 +0100 Subject: [PATCH 105/239] Fix sync subscribe graceful shutdown (#395) --- gql/client.py | 11 +++-- tests/test_websocket_subscription.py | 60 ++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/gql/client.py b/gql/client.py index 69804faa..690c2fce 100644 --- a/gql/client.py +++ b/gql/client.py @@ -593,15 +593,14 @@ def subscribe( except StopAsyncIteration: pass - except (KeyboardInterrupt, Exception): + except 
(KeyboardInterrupt, Exception, GeneratorExit): + + # Graceful shutdown + asyncio.ensure_future(async_generator.aclose(), loop=loop) - # Graceful shutdown by cancelling the task and waiting clean shutdown generator_task.cancel() - try: - loop.run_until_complete(generator_task) - except (StopAsyncIteration, asyncio.CancelledError): - pass + loop.run_until_complete(loop.shutdown_asyncgens()) # Then reraise the exception raise diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index f1d72dc8..4419783b 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -494,6 +494,66 @@ def test_websocket_subscription_sync(server, subscription_str): assert count == -1 +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_websocket_subscription_sync_user_exception(server, subscription_str): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(Exception) as exc_info: + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + raise Exception("This is an user exception") + + assert count == 5 + assert "This is an user exception" in str(exc_info.value) + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_websocket_subscription_sync_break(server, subscription_str): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + 
sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + break + + assert count == 5 + + @pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) From 389eb5c6ef869eac50f5a64982aad402260face2 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 23 Feb 2023 23:03:42 +0100 Subject: [PATCH 106/239] Bump version number to 3.5.0b3 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 2f003410..61bc1769 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b2" +__version__ = "3.5.0b3" From 6df7cf916d2f9d59929399a2c42aba4e2fea33d3 Mon Sep 17 00:00:00 2001 From: 0xTiger Date: Thu, 2 Mar 2023 17:27:00 +0000 Subject: [PATCH 107/239] DOC minor typo (#396) --- docs/advanced/logging.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/advanced/logging.rst b/docs/advanced/logging.rst index 7856fa8b..02fdf3fd 100644 --- a/docs/advanced/logging.rst +++ b/docs/advanced/logging.rst @@ -1,7 +1,7 @@ Logging ======= -GQL use the python `logging`_ module. +GQL uses the python `logging`_ module. In order to debug a problem, you can enable logging to see the messages exchanged between the client and the server. 
To do that, set the loglevel at **INFO** at the beginning of your code: From baa323cb0b298fbf12a0532d39665cd99e498114 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 30 Mar 2023 16:33:46 +0200 Subject: [PATCH 108/239] Allow to configure the introspection query sent to recover the schema (#402) * Implement possibility to change introspection query parameters * Add --schema-download argument to gql-cli --- docs/gql-cli/intro.rst | 10 +++++ docs/usage/validation.rst | 2 +- gql/cli.py | 63 +++++++++++++++++++++++++++- gql/client.py | 16 ++++--- tests/starwars/fixtures.py | 42 +++++++++++++++++++ tests/starwars/schema.py | 5 +++ tests/starwars/test_introspection.py | 62 +++++++++++++++++++++++++++ tests/test_cli.py | 44 +++++++++++++++++++ 8 files changed, 236 insertions(+), 8 deletions(-) create mode 100644 tests/starwars/test_introspection.py diff --git a/docs/gql-cli/intro.rst b/docs/gql-cli/intro.rst index 93f16d32..925958ee 100644 --- a/docs/gql-cli/intro.rst +++ b/docs/gql-cli/intro.rst @@ -78,3 +78,13 @@ Print the GraphQL schema in a file .. code-block:: shell $ gql-cli https://countries.trevorblades.com/graphql --print-schema > schema.graphql + +.. note:: + + By default, deprecated input fields are not requested from the backend. + You can add :code:`--schema-download input_value_deprecation:true` to request them. + +.. note:: + + You can add :code:`--schema-download descriptions:false` to request a compact schema + without comments. diff --git a/docs/usage/validation.rst b/docs/usage/validation.rst index 18b1cda1..f9711f31 100644 --- a/docs/usage/validation.rst +++ b/docs/usage/validation.rst @@ -24,7 +24,7 @@ The schema can be provided as a String (which is usually stored in a .graphql fi .. 
note:: You can download a schema from a server by using :ref:`gql-cli ` - :code:`$ gql-cli https://SERVER_URL/graphql --print-schema > schema.graphql` + :code:`$ gql-cli https://SERVER_URL/graphql --print-schema --schema-download input_value_deprecation:true > schema.graphql` OR can be created using python classes: diff --git a/gql/cli.py b/gql/cli.py index 2a6ff3f5..dd991546 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -3,7 +3,8 @@ import logging import signal as signal_module import sys -from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +import textwrap +from argparse import ArgumentParser, Namespace, RawTextHelpFormatter from typing import Any, Dict, Optional from graphql import GraphQLError, print_schema @@ -78,7 +79,7 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: parser = ArgumentParser( description=description, epilog=examples if with_examples else None, - formatter_class=RawDescriptionHelpFormatter, + formatter_class=RawTextHelpFormatter, ) parser.add_argument( "server", help="the server url starting with http://, https://, ws:// or wss://" @@ -122,6 +123,27 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: action="store_true", dest="print_schema", ) + parser.add_argument( + "--schema-download", + nargs="*", + help=textwrap.dedent( + """select the introspection query arguments to download the schema. + Only useful if --print-schema is used. 
+ By default, it will: + + - request field descriptions + - not request deprecated input fields + + Possible options: + + - descriptions:false for a compact schema without comments + - input_value_deprecation:true to download deprecated input fields + - specified_by_url:true + - schema_description:true + - directive_is_repeatable:true""" + ), + dest="schema_download", + ) parser.add_argument( "--execute-timeout", help="set the execute_timeout argument of the Client (default: 10)", @@ -362,6 +384,42 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: return None +def get_introspection_args(args: Namespace) -> Dict: + """Get the introspection args depending on the schema_download argument""" + + # Parse the headers argument + introspection_args = {} + + possible_args = [ + "descriptions", + "specified_by_url", + "directive_is_repeatable", + "schema_description", + "input_value_deprecation", + ] + + if args.schema_download is not None: + for arg in args.schema_download: + + try: + # Split only the first colon (throw a ValueError if no colon is present) + arg_key, arg_value = arg.split(":", 1) + + if arg_key not in possible_args: + raise ValueError(f"Invalid schema_download: {args.schema_download}") + + arg_value = arg_value.lower() + if arg_value not in ["true", "false"]: + raise ValueError(f"Invalid schema_download: {args.schema_download}") + + introspection_args[arg_key] = arg_value == "true" + + except ValueError: + raise ValueError(f"Invalid schema_download: {args.schema_download}") + + return introspection_args + + async def main(args: Namespace) -> int: """Main entrypoint of the gql-cli script @@ -395,6 +453,7 @@ async def main(args: Namespace) -> int: async with Client( transport=transport, fetch_schema_from_transport=args.print_schema, + introspection_args=get_introspection_args(args), execute_timeout=args.execute_timeout, ) as session: diff --git a/gql/client.py b/gql/client.py index 690c2fce..f6302987 100644 --- a/gql/client.py +++ 
b/gql/client.py @@ -76,6 +76,7 @@ def __init__( introspection: Optional[IntrospectionQuery] = None, transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, + introspection_args: Optional[Dict] = None, execute_timeout: Optional[Union[int, float]] = 10, serialize_variables: bool = False, parse_results: bool = False, @@ -86,7 +87,9 @@ def __init__( See :ref:`schema_validation` :param transport: The provided :ref:`transport `. :param fetch_schema_from_transport: Boolean to indicate that if we want to fetch - the schema from the transport using an introspection query + the schema from the transport using an introspection query. + :param introspection_args: arguments passed to the get_introspection_query + method of graphql-core. :param execute_timeout: The maximum time in seconds for the execution of a request before a TimeoutError is raised. Only used for async transports. Passing None results in waiting forever for a response. @@ -132,6 +135,9 @@ def __init__( # Flag to indicate that we need to fetch the schema from the transport # On async transports, we fetch the schema before executing the first query self.fetch_schema_from_transport: bool = fetch_schema_from_transport + self.introspection_args = ( + {} if introspection_args is None else introspection_args + ) # Enforced timeout of the execute function (only for async transports) self.execute_timeout = execute_timeout @@ -879,7 +885,8 @@ def fetch_schema(self) -> None: Don't use this function and instead set the fetch_schema_from_transport attribute to True""" - execution_result = self.transport.execute(parse(get_introspection_query())) + introspection_query = get_introspection_query(**self.client.introspection_args) + execution_result = self.transport.execute(parse(introspection_query)) self.client._build_schema_from_introspection(execution_result) @@ -1250,9 +1257,8 @@ async def fetch_schema(self) -> None: Don't use this function and instead set the 
fetch_schema_from_transport attribute to True""" - execution_result = await self.transport.execute( - parse(get_introspection_query()) - ) + introspection_query = get_introspection_query(**self.client.introspection_args) + execution_result = await self.transport.execute(parse(introspection_query)) self.client._build_schema_from_introspection(execution_result) diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index efbb1b0e..59d7ddfa 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -144,3 +144,45 @@ def create_review(episode, review): reviews[episode].append(review) review["episode"] = episode return review + + +async def make_starwars_backend(aiohttp_server): + from aiohttp import web + from .schema import StarWarsSchema + from graphql import graphql_sync + + async def handler(request): + data = await request.json() + source = data["query"] + + try: + variables = data["variables"] + except KeyError: + variables = None + + result = graphql_sync(StarWarsSchema, source, variable_values=variables) + + return web.json_response( + { + "data": result.data, + "errors": [str(e) for e in result.errors] if result.errors else None, + } + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + return server + + +async def make_starwars_transport(aiohttp_server): + from gql.transport.aiohttp import AIOHTTPTransport + + server = await make_starwars_backend(aiohttp_server) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + return transport diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index c3db0a3d..5f9a04b4 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -155,6 +155,11 @@ "commentary": GraphQLInputField( GraphQLString, description="Comment about the movie, optional" ), + "deprecated_input_field": GraphQLInputField( + GraphQLString, + description="deprecated field example", + 
deprecation_reason="deprecated for testing", + ), }, description="The input object sent when someone is creating a new review", ) diff --git a/tests/starwars/test_introspection.py b/tests/starwars/test_introspection.py new file mode 100644 index 00000000..c3063808 --- /dev/null +++ b/tests/starwars/test_introspection.py @@ -0,0 +1,62 @@ +import pytest +from graphql import print_schema + +from gql import Client + +from .fixtures import make_starwars_transport + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +@pytest.mark.asyncio +async def test_starwars_introspection_args(event_loop, aiohttp_server): + + transport = await make_starwars_transport(aiohttp_server) + + # First fetch the schema from transport using default introspection query + # We should receive descriptions in the schema but not deprecated input fields + async with Client( + transport=transport, + fetch_schema_from_transport=True, + ) as session: + + schema_str = print_schema(session.client.schema) + print(schema_str) + + assert '"""The number of stars this review gave, 1-5"""' in schema_str + assert "deprecated_input_field" not in schema_str + + # Then fetch the schema from transport using an introspection query + # without requesting descriptions + # We should NOT receive descriptions in the schema + async with Client( + transport=transport, + fetch_schema_from_transport=True, + introspection_args={ + "descriptions": False, + }, + ) as session: + + schema_str = print_schema(session.client.schema) + print(schema_str) + + assert '"""The number of stars this review gave, 1-5"""' not in schema_str + assert "deprecated_input_field" not in schema_str + + # Then fetch the schema from transport using and introspection query + # requiring deprecated input fields + # We should receive descriptions in the schema and deprecated input fields + async with Client( + transport=transport, + fetch_schema_from_transport=True, + introspection_args={ + 
"input_value_deprecation": True, + }, + ) as session: + + schema_str = print_schema(session.client.schema) + print(schema_str) + + assert '"""The number of stars this review gave, 1-5"""' in schema_str + assert "deprecated_input_field" in schema_str diff --git a/tests/test_cli.py b/tests/test_cli.py index 359e94fb..f0534957 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,6 +5,7 @@ from gql import __version__ from gql.cli import ( get_execute_args, + get_introspection_args, get_parser, get_transport, get_transport_args, @@ -376,3 +377,46 @@ def test_cli_ep_version(script_runner): assert ret.stdout == f"v{__version__}\n" assert ret.stderr == "" + + +def test_cli_parse_schema_download(parser): + + args = parser.parse_args( + [ + "https://your_server.com", + "--schema-download", + "descriptions:false", + "input_value_deprecation:true", + "specified_by_url:True", + "schema_description:true", + "directive_is_repeatable:true", + "--print-schema", + ] + ) + + introspection_args = get_introspection_args(args) + + expected_args = { + "descriptions": False, + "input_value_deprecation": True, + "specified_by_url": True, + "schema_description": True, + "directive_is_repeatable": True, + } + + assert introspection_args == expected_args + + +@pytest.mark.parametrize( + "invalid_args", + [ + ["https://your_server.com", "--schema-download", "ArgWithoutColon"], + ["https://your_server.com", "--schema-download", "blahblah:true"], + ["https://your_server.com", "--schema-download", "descriptions:invalid_bool"], + ], +) +def test_cli_parse_schema_download_invalid_arg(parser, invalid_args): + args = parser.parse_args(invalid_args) + + with pytest.raises(ValueError): + get_introspection_args(args) From df2b206caaa83bceac417c2bfa3fedf4c118d4df Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 31 Mar 
2023 16:41:03 +0200 Subject: [PATCH 109/239] Allow alias on DSLMetaField (#405) --- gql/dsl.py | 6 ------ tests/starwars/test_dsl.py | 24 +++++++++++++++++++----- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 7f09b928..1876742e 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -924,12 +924,6 @@ def __init__(self, name: str): super().__init__(name, self.meta_type, field) - def alias(self, alias: str) -> "DSLSelectableWithAlias": - """ - :meta private: - """ - return self - class DSLInlineFragment(DSLSelectable, DSLFragmentSelector): """DSLInlineFragment represents an inline fragment for the DSL code.""" diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 714e713a..8bbdf0c9 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -870,11 +870,6 @@ def test_invalid_meta_field_selection(ds): metafield = DSLMetaField("__typename") assert metafield.name == "__typename" - # alias does not work - metafield.alias("test") - - assert metafield.name == "__typename" - with pytest.raises(GraphQLError): DSLMetaField("__invalid_meta_field") @@ -936,3 +931,22 @@ def test_get_introspection_query_ast(option): ) assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + + +def test_typename_aliased(ds): + query = """ +hero { + name + typenameField: __typename +} +""".strip() + + query_dsl = ds.Query.hero.select( + ds.Character.name, typenameField=DSLMetaField("__typename") + ) + assert query == str(query_dsl) + + query_dsl = ds.Query.hero.select( + ds.Character.name, DSLMetaField("__typename").alias("typenameField") + ) + assert query == str(query_dsl) From 8e1c6f12b08f2855ecedf6afdb5e3ee01b443c37 Mon Sep 17 00:00:00 2001 From: Marvin Schlegel <5002075+cybniv@users.noreply.github.com> Date: Tue, 25 Apr 2023 14:48:05 +0200 Subject: [PATCH 110/239] Fix typo response_headers in httpx transport (#407) --- gql/transport/httpx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 4c1d4f0f..cfc25dc9 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -32,7 +32,7 @@ class _HTTPXTransport: file_classes: Tuple[Type[Any], ...] = (io.IOBase,) - reponse_headers: Optional[httpx.Headers] = None + response_headers: Optional[httpx.Headers] = None def __init__( self, From f3b0f2661ee204c7a3a4b87ba90e754439442468 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 27 Apr 2023 12:52:08 +0200 Subject: [PATCH 111/239] DSLSchema transform type attribute assert into AttributeError (#409) --- gql/dsl.py | 6 +++++- tests/starwars/test_dsl.py | 39 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 1876742e..adc48bea 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -295,7 +295,11 @@ def __getattr__(self, name: str) -> "DSLType": if type_def is None: raise AttributeError(f"Type '{name}' not found in the schema!") - assert isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)) + if not isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)): + raise AttributeError( + f'Type "{name} ({type_def!r})" is not valid as an attribute of' + " DSLSchema. Only Object types or Interface types are accepted." 
+ ) return DSLType(type_def, self) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 8bbdf0c9..098a2b50 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -14,6 +14,8 @@ NonNullTypeNode, NullValueNode, Undefined, + build_ast_schema, + parse, print_ast, ) from graphql.utilities import get_introspection_query @@ -774,8 +776,6 @@ def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): def test_dsl_root_type_not_default(): - from graphql import parse, build_ast_schema - schema_str = """ schema { query: QueryNotDefault @@ -827,6 +827,41 @@ def test_invalid_type(ds): ds.invalid_type +def test_invalid_type_union(): + schema_str = """ + type FloatValue { + floatValue: Float! + } + + type IntValue { + intValue: Int! + } + + union Value = FloatValue | IntValue + + type Entry { + name: String! + value: Value + } + + type Query { + values: [Entry!]! + } + """ + + schema = build_ast_schema(parse(schema_str)) + ds = DSLSchema(schema) + + with pytest.raises( + AttributeError, + match=( + "Type \"Value \\(\\)\" is not valid as an " + "attribute of DSLSchema. Only Object types or Interface types are accepted." 
+ ), + ): + ds.Value + + def test_hero_name_query_with_typename(ds): query = """ hero { From db6d277219d174be22a342cb17a0e79bc7cc6975 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 27 Apr 2023 13:02:47 +0200 Subject: [PATCH 112/239] Python 3.11 support (#410) --- .github/workflows/tests.yml | 4 +++- setup.py | 1 + tox.ini | 5 +++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 366a953b..77798ada 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "pypy3.8"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "pypy3.8"] os: [ubuntu-latest, windows-latest] exclude: - os: windows-latest @@ -17,6 +17,8 @@ jobs: python-version: "3.9" - os: windows-latest python-version: "3.10" + - os: windows-latest + python-version: "3.11" - os: windows-latest python-version: "pypy3.8" diff --git a/setup.py b/setup.py index 17fda63b..bc11f66c 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,7 @@ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay gql client", diff --git a/tox.ini b/tox.ini index 070b5bf2..df1e81f1 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{37,38,39,310,py3} + py{37,38,39,310,311,py3} [gh-actions] python = @@ -9,6 +9,7 @@ python = 3.8: py38 3.9: py39 3.10: py310 + 3.11: py311 pypy-3: pypy3 [testenv] @@ -27,7 +28,7 @@ deps = -e.[test] commands = pip install -U setuptools ; run "tox -- tests -s" to show output for debugging - py{37,39,310,py3}: pytest {posargs:tests} + py{37,39,310,311,py3}: pytest {posargs:tests} py{38}: pytest {posargs:tests --cov-report=term-missing 
--cov=gql} [testenv:black] From 0522288db100478fb17955e2ae5bd5346b50cd63 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 5 May 2023 18:44:52 +0200 Subject: [PATCH 113/239] Restrict urllib3 to versions 1.x (#413) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bc11f66c..dbdccb9c 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ install_requests_requires = [ "requests>=2.26,<3", "requests_toolbelt>=0.9.1,<1", - "urllib3>=1.26", + "urllib3>=1.26,<2", ] install_httpx_requires = [ From 218eacb40717703e7d210ac1441ad2885345d30e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 6 May 2023 21:22:21 +0200 Subject: [PATCH 114/239] Switch ubuntu-latest to ubuntu-20.04 to fix github actions --- .github/workflows/deploy.yml | 2 +- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 73778df5..da129836 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -7,7 +7,7 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0f9f0a07..39f5cf0c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,7 +4,7 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 77798ada..e6d42db7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ jobs: max-parallel: 4 matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "pypy3.8"] - os: [ubuntu-latest, windows-latest] + os: [ubuntu-20.04, windows-latest] exclude: - os: windows-latest python-version: "3.7" @@ -38,7 +38,7 @@ jobs: TOXENV: ${{ matrix.toxenv }} single_extra: - 
runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: fail-fast: false matrix: @@ -58,7 +58,7 @@ jobs: run: pytest tests --${{ matrix.dependency }}-only coverage: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 From d140b830cd47fea5513b6bc30b1289c6416d25df Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 6 May 2023 21:34:11 +0200 Subject: [PATCH 115/239] Bump version number to 3.5.0b4 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 61bc1769..c59380f1 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b3" +__version__ = "3.5.0b4" From 66a1c3c2229d38854aeccc24c4daf7660e803b6a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 5 Jun 2023 12:15:45 +0200 Subject: [PATCH 116/239] Bump graphql-core min version to 3.3.0a3 and remove await before subscribe (#417) --- gql/transport/local_schema.py | 2 +- setup.py | 2 +- tests/starwars/test_subscription.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 87395b19..b2423346 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -59,7 +59,7 @@ async def subscribe( The results are sent as an ExecutionResult object """ - subscribe_result = await subscribe(self.schema, document, *args, **kwargs) + subscribe_result = subscribe(self.schema, document, *args, **kwargs) if isinstance(subscribe_result, ExecutionResult): yield subscribe_result diff --git a/setup.py b/setup.py index dbdccb9c..45cf2682 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.3.0a2,<3.4", + "graphql-core>=3.3.0a3,<3.4", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", ] diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 2516701f..c5a50514 100644 --- 
a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -30,7 +30,7 @@ async def test_subscription_support(): params = {"ep": "JEDI"} expected = [{**review, "episode": "JEDI"} for review in reviews[6]] - ai = await subscribe(StarWarsSchema, subs, variable_values=params) + ai = subscribe(StarWarsSchema, subs, variable_values=params) result = [result.data["reviewAdded"] async for result in ai] From 8b5213433ead11912fad79f5899580406112ed4a Mon Sep 17 00:00:00 2001 From: Hugo Locurcio Date: Wed, 28 Jun 2023 15:03:00 +0200 Subject: [PATCH 117/239] Add quotes to the pip installation command in README (#420) This is required when using certain shells such as zsh, as zsh tries to expand it otherwise, leading to a syntax error. --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 12e34b01..a100e32d 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,10 @@ The complete documentation for GQL can be found at You can install GQL with all the optional dependencies using pip: - pip install gql[all] +```bash +# Quotes may be required on certain shells such as zsh. +pip install "gql[all]" +``` > **NOTE**: See also [the documentation](https://gql.readthedocs.io/en/latest/intro.html#less-dependencies) to install GQL with less extra dependencies depending on the transports you would like to use or for alternative installation methods. 
From f35970f17cd7290464afdd484a5b133188e3ba9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mert=20Tun=C3=A7?= Date: Thu, 13 Jul 2023 12:52:53 +0300 Subject: [PATCH 118/239] Make retry backoff and status codes customizable (#421) --- gql/transport/requests.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index fa60c38c..6b0bb60b 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,7 +1,7 @@ import io import json import logging -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Collection, Dict, Optional, Tuple, Type, Union import requests from graphql import DocumentNode, ExecutionResult, print_ast @@ -31,6 +31,7 @@ class RequestsHTTPTransport(Transport): """ file_classes: Tuple[Type[Any], ...] = (io.IOBase,) + _default_retry_codes = (429, 500, 502, 503, 504) def __init__( self, @@ -43,6 +44,8 @@ def __init__( verify: Union[bool, str] = True, retries: int = 0, method: str = "POST", + retry_backoff_factor: float = 0.1, + retry_status_forcelist: Collection[int] = _default_retry_codes, **kwargs: Any, ): """Initialize the transport with the given request parameters. @@ -62,6 +65,13 @@ def __init__( to a CA bundle to use. (Default: True). :param retries: Pre-setup of the requests' Session for performing retries :param method: HTTP method used for requests. (Default: POST). + :param retry_backoff_factor: A backoff factor to apply between attempts after + the second try. urllib3 will sleep for: + {backoff factor} * (2 ** ({number of previous retries})) + :param retry_status_forcelist: A set of integer HTTP status codes that we + should force a retry on. A retry is initiated if the request method is + in allowed_methods and the response status code is in status_forcelist. + (Default: [429, 500, 502, 503, 504]) :param kwargs: Optional arguments that ``request`` takes. 
These can be seen at the `requests`_ source code or the official `docs`_ @@ -77,6 +87,8 @@ def __init__( self.verify = verify self.retries = retries self.method = method + self.retry_backoff_factor = retry_backoff_factor + self.retry_status_forcelist = retry_status_forcelist self.kwargs = kwargs self.session = None @@ -95,8 +107,8 @@ def connect(self): adapter = HTTPAdapter( max_retries=Retry( total=self.retries, - backoff_factor=0.1, - status_forcelist=[500, 502, 503, 504], + backoff_factor=self.retry_backoff_factor, + status_forcelist=self.retry_status_forcelist, allowed_methods=None, ) ) From f7fcaf677cbb98fb99637baab4b7343542a6cd95 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 18 Jul 2023 21:50:03 +0200 Subject: [PATCH 119/239] Bump websockets dependency to allow 11.x versions (#424) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 45cf2682..173c469b 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ ] install_websockets_requires = [ - "websockets>=10,<11", + "websockets>=10,<12", ] install_botocore_requires = [ From f96ae5c2c94b51bdf39ad951c68905f19c17f4cf Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 26 Jul 2023 11:05:58 +0200 Subject: [PATCH 120/239] Adjust aiohttp pin (#425) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 173c469b..993af099 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ ] + tests_requires install_aiohttp_requires = [ - "aiohttp>=3.8.0,<3.9.0", + "aiohttp>=3.8.0,<4", ] install_requests_requires = [ From 013fa6aceb066be843849e83cc72747f51c07d0e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 26 Jul 2023 12:06:47 +0200 Subject: [PATCH 121/239] Bump version number to 3.5.0b5 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index c59380f1..986f222d 100644 --- 
a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b4" +__version__ = "3.5.0b5" From d4c975198ad617265e4abb4716e5643f31b75e08 Mon Sep 17 00:00:00 2001 From: Ignacio Tolosa Date: Tue, 5 Sep 2023 16:10:41 -0300 Subject: [PATCH 122/239] Add sync batching to requests sync transport (#431) * Add `execute_batch` method for requests sync transport --- gql/__init__.py | 2 + gql/client.py | 173 +++++++- gql/graphql_request.py | 37 ++ gql/transport/aiohttp.py | 2 +- gql/transport/requests.py | 152 ++++++- gql/transport/transport.py | 20 + tests/custom_scalars/test_money.py | 84 +++- .../fixtures/vcr_cassettes/queries_batch.yaml | 385 ++++++++++++++++++ tests/test_client.py | 73 +++- tests/test_graphql_request.py | 202 +++++++++ tests/test_requests_batch.py | 377 +++++++++++++++++ tests/test_transport_batch.py | 151 +++++++ 12 files changed, 1621 insertions(+), 37 deletions(-) create mode 100644 gql/graphql_request.py create mode 100644 tests/fixtures/vcr_cassettes/queries_batch.yaml create mode 100644 tests/test_graphql_request.py create mode 100644 tests/test_requests_batch.py create mode 100644 tests/test_transport_batch.py diff --git a/gql/__init__.py b/gql/__init__.py index a2449700..8eaa0b7c 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -10,9 +10,11 @@ from .__version__ import __version__ from .client import Client from .gql import gql +from .graphql_request import GraphQLRequest __all__ = [ "__version__", "gql", "Client", + "GraphQLRequest", ] diff --git a/gql/client.py b/gql/client.py index f6302987..326442e0 100644 --- a/gql/client.py +++ b/gql/client.py @@ -8,6 +8,7 @@ Callable, Dict, Generator, + List, Optional, TypeVar, Union, @@ -27,6 +28,7 @@ validate, ) +from .graphql_request import GraphQLRequest from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportClosed, TransportQueryError from .transport.local_schema import LocalSchemaTransport @@ -236,6 +238,24 @@ def execute_sync( **kwargs, ) 
+ def execute_batch_sync( + self, + reqs: List[GraphQLRequest], + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """:meta private:""" + with self as session: + return session.execute_batch( + reqs, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + @overload async def execute_async( self, @@ -375,7 +395,6 @@ def execute( """ if isinstance(self.transport, AsyncTransport): - # Get the current asyncio event loop # Or create a new event loop if there isn't one (in a new Thread) try: @@ -418,6 +437,48 @@ def execute( **kwargs, ) + def execute_batch( + self, + reqs: List[GraphQLRequest], + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """Execute multiple GraphQL requests in a batch against the remote server using + the transport provided during init. + + This function **WILL BLOCK** until the result is received from the server. + + Either the transport is sync and we execute the query synchronously directly + OR the transport is async and we execute the query in the asyncio loop + (blocking here until answer). + + This method will: + + - connect using the transport to get a session + - execute the GraphQL requests on the transport session + - close the session and close the connection to the server + + If you want to perform multiple executions, it is better to use + the context manager to keep a session active. + + The extra arguments passed in the method will be passed to the transport + execute method. 
+ """ + + if isinstance(self.transport, AsyncTransport): + raise NotImplementedError("Batching is not implemented for async yet.") + + else: # Sync transports + return self.execute_batch_sync( + reqs, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + @overload def subscribe_async( self, @@ -476,7 +537,6 @@ async def subscribe_async( ]: """:meta private:""" async with self as session: - generator = session.subscribe( document, variable_values=variable_values, @@ -600,7 +660,6 @@ def subscribe( pass except (KeyboardInterrupt, Exception, GeneratorExit): - # Graceful shutdown asyncio.ensure_future(async_generator.aclose(), loop=loop) @@ -661,11 +720,9 @@ async def close_async(self): await self.transport.close() async def __aenter__(self): - return await self.connect_async() async def __aexit__(self, exc_type, exc, tb): - await self.close_async() def connect_sync(self): @@ -705,7 +762,6 @@ def close_sync(self): self.transport.close() def __enter__(self): - return self.connect_sync() def __exit__(self, *args): @@ -880,6 +936,108 @@ def execute( return result.data + def _execute_batch( + self, + reqs: List[GraphQLRequest], + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + **kwargs, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch, using + the sync transport, returning a list of ExecutionResult objects. + + :param reqs: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results argument of the client. 
+ + The extra arguments are passed to the transport execute method.""" + + # Validate document + if self.client.schema: + for req in reqs: + self.client.validate(req.document) + + # Parse variable values for custom scalars if requested + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + reqs = [ + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req + for req in reqs + ] + + results = self.transport.execute_batch(reqs, **kwargs) + + # Unserialize the result if requested + if self.client.schema: + if parse_result or (parse_result is None and self.client.parse_results): + for result in results: + result.data = parse_result_fn( + self.client.schema, + req.document, + result.data, + operation_name=req.operation_name, + ) + + return results + + def execute_batch( + self, + reqs: List[GraphQLRequest], + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """Execute multiple GraphQL requests in a batch, using + the sync transport. This method sends the requests to the server all at once. + + Raises a TransportQueryError if an error has been returned in any + ExecutionResult. + + :param reqs: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will unserialize the result. + By default use the parse_results argument of the client. + :param get_execution_result: return the full ExecutionResult instance instead of + only the "data" field. Necessary if you want to get the "extensions" field. 
+ + The extra arguments are passed to the transport execute method.""" + + # Validate and execute on the transport + results = self._execute_batch( + reqs, + serialize_variables=serialize_variables, + parse_result=parse_result, + **kwargs, + ) + + for result in results: + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError( + str_first_element(result.errors), + errors=result.errors, + data=result.data, + extensions=result.extensions, + ) + + assert ( + result.data is not None + ), "Transport returned an ExecutionResult without data or errors" + + if get_execution_result: + return results + + return cast(List[Dict[str, Any]], [result.data for result in results]) + def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. @@ -966,7 +1124,6 @@ async def _subscribe( try: async for result in inner_generator: - if self.client.schema: if parse_result or ( parse_result is None and self.client.parse_results @@ -1070,7 +1227,6 @@ async def subscribe( try: # Validate and subscribe on the transport async for result in inner_generator: - # Raise an error if an error is returned in the ExecutionResult object if result.errors: raise TransportQueryError( @@ -1343,7 +1499,6 @@ async def _connection_loop(self): """ while True: - # Connect to the transport with the retry decorator # By default it should keep retrying until it connect await self._connect_with_retries() diff --git a/gql/graphql_request.py b/gql/graphql_request.py new file mode 100644 index 00000000..b0c68f5c --- /dev/null +++ b/gql/graphql_request.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from graphql import DocumentNode, GraphQLSchema + +from .utilities import serialize_variable_values + + +@dataclass(frozen=True) +class GraphQLRequest: + """GraphQL Request to be executed.""" + + document: DocumentNode + """GraphQL query as AST Node object.""" + + variable_values: 
Optional[Dict[str, Any]] = None + """Dictionary of input parameters (Default: None).""" + + operation_name: Optional[str] = None + """ + Name of the operation that shall be executed. + Only required in multi-operation documents (Default: None). + """ + + def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": + assert self.variable_values + + return GraphQLRequest( + document=self.document, + variable_values=serialize_variable_values( + schema=schema, + document=self.document, + variable_values=self.variable_values, + operation_name=self.operation_name, + ), + operation_name=self.operation_name, + ) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 2fd92a72..60f42c94 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -205,7 +205,7 @@ async def execute( document: DocumentNode, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - extra_args: Dict[str, Any] = None, + extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute the provided document AST against the configured remote server diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 6b0bb60b..1e464104 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,7 +1,7 @@ import io import json import logging -from typing import Any, Collection, Dict, Optional, Tuple, Type, Union +from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union import requests from graphql import DocumentNode, ExecutionResult, print_ast @@ -12,6 +12,7 @@ from gql.transport import Transport +from ..graphql_request import GraphQLRequest from ..utils import extract_files from .exceptions import ( TransportAlreadyConnected, @@ -96,9 +97,7 @@ def __init__( self.response_headers = None def connect(self): - if self.session is None: - # Creating a session that can later be re-use to configure custom mechanisms self.session = requests.Session() 
@@ -123,7 +122,7 @@ def execute( # type: ignore variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, timeout: Optional[int] = None, - extra_args: Dict[str, Any] = None, + extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. @@ -275,6 +274,151 @@ def raise_response_error(resp: requests.Response, reason: str): extensions=result.get("extensions"), ) + def execute_batch( # type: ignore + self, + reqs: List[GraphQLRequest], + timeout: Optional[int] = None, + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Execute the provided requests against the configured remote server. This + uses the requests library to perform a HTTP POST request to the remote server. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param timeout: Specifies a default timeout for requests (Default: None). + :param extra_args: additional arguments to send to the requests post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + if not self.session: + raise TransportClosed("Transport is not connected") + + # Using the created session to perform requests + response = self.session.request( + self.method, + self.url, + **self._build_batch_post_args(reqs, timeout, extra_args), + ) + self.response_headers = response.headers + + answers = self._extract_response(response) + + self._validate_answer_is_a_list(answers) + self._validate_num_of_answers_same_as_requests(reqs, answers) + self._validate_every_answer_is_a_dict(answers) + self._validate_data_and_errors_keys_in_answers(answers) + + return [self._answer_to_execution_result(answer) for answer in answers] + + def _answer_to_execution_result(self, result: Dict[str, Any]) -> ExecutionResult: + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) + + def _validate_answer_is_a_list(self, results: Any) -> None: + if not isinstance(results, list): + self._raise_invalid_result( + str(results), + "Answer is not a list", + ) + + def _validate_data_and_errors_keys_in_answers( + self, results: List[Dict[str, Any]] + ) -> None: + for result in results: + if "errors" not in result and "data" not in result: + self._raise_invalid_result( + str(results), + 'No "data" or "errors" keys in answer', + ) + + def _validate_every_answer_is_a_dict(self, results: List[Dict[str, Any]]) -> None: + for result in results: + if not isinstance(result, dict): + self._raise_invalid_result(str(results), "Not every answer is dict") + + def _validate_num_of_answers_same_as_requests( + self, + reqs: List[GraphQLRequest], + results: List[Dict[str, Any]], + ) -> None: + if len(reqs) != len(results): + self._raise_invalid_result( + str(results), + "Invalid answer length", + ) + + def _raise_invalid_result(self, result_text: str, reason: str) -> None: + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " + f"{reason}: " + f"{result_text}" + ) + + def _extract_response(self, 
response: requests.Response) -> Any: + try: + response.raise_for_status() + result = response.json() + + if log.isEnabledFor(logging.INFO): + log.info("<<< %s", response.text) + + except requests.HTTPError as e: + raise TransportServerError(str(e), e.response.status_code) from e + + except Exception: + self._raise_invalid_result(str(response.text), "Not a JSON answer") + + return result + + def _build_batch_post_args( + self, + reqs: List[GraphQLRequest], + timeout: Optional[int] = None, + extra_args: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + post_args: Dict[str, Any] = { + "headers": self.headers, + "auth": self.auth, + "cookies": self.cookies, + "timeout": timeout or self.default_timeout, + "verify": self.verify, + } + + data_key = "json" if self.use_json else "data" + post_args[data_key] = [self._build_data(req) for req in reqs] + + # Log the payload + if log.isEnabledFor(logging.INFO): + log.info(">>> %s", json.dumps(post_args[data_key])) + + # Pass kwargs to requests post method + post_args.update(self.kwargs) + + # Pass post_args to requests post method + if extra_args: + post_args.update(extra_args) + + return post_args + + def _build_data(self, req: GraphQLRequest) -> Dict[str, Any]: + query_str = print_ast(req.document) + payload: Dict[str, Any] = {"query": query_str} + + if req.operation_name: + payload["operationName"] = req.operation_name + + if req.variable_values: + payload["variables"] = req.variable_values + + return payload + def close(self): """Closing the transport by closing the inner session""" if self.session: diff --git a/gql/transport/transport.py b/gql/transport/transport.py index cf5e94da..a5bd7100 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,7 +1,10 @@ import abc +from typing import List from graphql import DocumentNode, ExecutionResult +from ..graphql_request import GraphQLRequest + class Transport(abc.ABC): @abc.abstractmethod @@ -17,6 +20,23 @@ def execute(self, document: DocumentNode, 
*args, **kwargs) -> ExecutionResult: "Any Transport subclass must implement execute method" ) # pragma: no cover + def execute_batch( + self, + reqs: List[GraphQLRequest], + *args, + **kwargs, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Execute the provided requests for either a remote or local GraphQL Schema. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :return: a list of ExecutionResult objects + """ + raise NotImplementedError( + "This Transport has not implemented the execute_batch method" + ) # pragma: no cover + def connect(self): """Establish a session with the transport.""" pass # pragma: no cover diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index e67a0bcd..374c70e6 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -3,7 +3,7 @@ from typing import Any, Dict, NamedTuple, Optional import pytest -from graphql import graphql_sync +from graphql import ExecutionResult, graphql_sync from graphql.error import GraphQLError from graphql.language import ValueNode from graphql.pyutils import inspect @@ -20,7 +20,7 @@ ) from graphql.utilities import value_from_ast_untyped -from gql import Client, gql +from gql import Client, GraphQLRequest, gql from gql.transport.exceptions import TransportQueryError from gql.utilities import serialize_value, update_schema_scalar, update_schema_scalars @@ -419,24 +419,45 @@ async def make_money_backend(aiohttp_server): from aiohttp import web async def handler(request): - data = await request.json() - source = data["query"] + req_data = await request.json() - try: - variables = data["variables"] - except KeyError: - variables = None + def handle_single(data: Dict[str, Any]) -> ExecutionResult: + source = data["query"] + try: + variables = data["variables"] + except KeyError: + variables = None - result = graphql_sync( - schema, source, variable_values=variables, root_value=root_value - ) + 
result = graphql_sync( + schema, source, variable_values=variables, root_value=root_value + ) - return web.json_response( - { - "data": result.data, - "errors": [str(e) for e in result.errors] if result.errors else None, - } - ) + return result + + if isinstance(req_data, list): + results = [handle_single(d) for d in req_data] + + return web.json_response( + [ + { + "data": result.data, + "errors": [str(e) for e in result.errors] + if result.errors + else None, + } + for result in results + ] + ) + else: + result = handle_single(req_data) + return web.json_response( + { + "data": result.data, + "errors": [str(e) for e in result.errors] + if result.errors + else None, + } + ) app = web.Application() app.router.add_route("POST", "/", handler) @@ -736,6 +757,35 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.asyncio +@pytest.mark.requests +async def test_custom_scalar_serialize_variables_sync_transport_2( + event_loop, aiohttp_server, run_sync_test +): + server, transport = await make_sync_money_transport(aiohttp_server) + + def test_code(): + with Client(schema=schema, transport=transport, parse_results=True) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + results = session.execute_batch( + [ + GraphQLRequest(document=query, variable_values=variable_values), + GraphQLRequest(document=query, variable_values=variable_values), + ], + serialize_variables=True, + ) + + print(f"result = {results!r}") + assert results[0]["toEuros"] == 5 + assert results[1]["toEuros"] == 5 + + await run_sync_test(event_loop, server, test_code) + + def test_serialize_value_with_invalid_type(): with pytest.raises(GraphQLError) as exc_info: diff --git a/tests/fixtures/vcr_cassettes/queries_batch.yaml b/tests/fixtures/vcr_cassettes/queries_batch.yaml new file mode 100644 index 00000000..0794cc47 --- /dev/null +++ b/tests/fixtures/vcr_cassettes/queries_batch.yaml @@ -0,0 
+1,385 @@ +interactions: +- request: + body: '{"query": "query IntrospectionQuery {\n __schema {\n queryType {\n name\n }\n mutationType + {\n name\n }\n subscriptionType {\n name\n }\n types {\n ...FullType\n }\n directives + {\n name\n description\n locations\n args {\n ...InputValue\n }\n }\n }\n}\n\nfragment + FullType on __Type {\n kind\n name\n description\n fields(includeDeprecated: + true) {\n name\n description\n args {\n ...InputValue\n }\n type + {\n ...TypeRef\n }\n isDeprecated\n deprecationReason\n }\n inputFields + {\n ...InputValue\n }\n interfaces {\n ...TypeRef\n }\n enumValues(includeDeprecated: + true) {\n name\n description\n isDeprecated\n deprecationReason\n }\n possibleTypes + {\n ...TypeRef\n }\n}\n\nfragment InputValue on __InputValue {\n name\n description\n type + {\n ...TypeRef\n }\n defaultValue\n}\n\nfragment TypeRef on __Type {\n kind\n name\n ofType + {\n kind\n name\n ofType {\n kind\n name\n ofType {\n kind\n name\n ofType + {\n kind\n name\n ofType {\n kind\n name\n ofType + {\n kind\n name\n ofType {\n kind\n name\n }\n }\n }\n }\n }\n }\n }\n}"}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '1417' + Content-Type: + - application/json + Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + User-Agent: + - python-requests/2.24.0 + x-csrftoken: + - kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + method: POST + uri: http://127.0.0.1:8000/graphql + response: + body: + string: 
'{"data":{"__schema":{"queryType":{"name":"Query"},"mutationType":{"name":"Mutation"},"subscriptionType":null,"types":[{"kind":"OBJECT","name":"Query","description":null,"fields":[{"name":"allFilms","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"episodeId_Gt","description":null,"type":{"kind":"SCALAR","name":"Float","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"FilmConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"allSpecies","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"SpecieConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"allCharacters","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":n
ull},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PersonConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"allVehicles","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"VehicleConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"allPlanets","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PlanetConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"allStarships","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String",
"ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"StarshipConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"allHeroes","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"HeroConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"film","description":"The + ID of the object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Film","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"specie","description":"The + ID of the 
object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Specie","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"character","description":"The + ID of the object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Person","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"vehicle","description":"The + ID of the object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Vehicle","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"planet","description":"The + ID of the object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Planet","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"starship","description":"The + ID of the object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Starship","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"hero","description":"The + ID of the object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Hero","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"node","description":"The + ID of the 
object","args":[{"name":"id","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"defaultValue":null}],"type":{"kind":"INTERFACE","name":"Node","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"viewer","description":null,"args":[],"type":{"kind":"OBJECT","name":"Query","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"_debug","description":null,"args":[],"type":{"kind":"OBJECT","name":"DjangoDebug","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"FilmConnection","description":null,"fields":[{"name":"pageInfo","description":"Pagination + data for this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"PageInfo","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"edges","description":"Contains + the nodes in this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"FilmEdge","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"totalCount","description":null,"args":[],"type":{"kind":"SCALAR","name":"Int","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"PageInfo","description":"The + Relay compliant `PageInfo` type, containing data necessary to paginate this + connection.","fields":[{"name":"hasNextPage","description":"When paginating + forwards, are there more items?","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"hasPreviousPage","description":"When + paginating backwards, are there more 
items?","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"startCursor","description":"When + paginating backwards, the cursor to continue.","args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"endCursor","description":"When + paginating forwards, the cursor to continue.","args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"SCALAR","name":"Boolean","description":"The + `Boolean` scalar type represents `true` or `false`.","fields":null,"inputFields":null,"interfaces":null,"enumValues":null,"possibleTypes":null},{"kind":"SCALAR","name":"String","description":"The + `String` scalar type represents textual data, represented as UTF-8 character + sequences. The String type is most often used by GraphQL to represent free-form + human-readable text.","fields":null,"inputFields":null,"interfaces":null,"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"FilmEdge","description":"A + Relay edge containing a `Film` and its cursor.","fields":[{"name":"node","description":"The + item at the end of the edge","args":[],"type":{"kind":"OBJECT","name":"Film","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"cursor","description":"A + cursor for use in pagination","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Film","description":"A + single film.","fields":[{"name":"id","description":"The ID of the 
object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"title","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"episodeId","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Int","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"openingCrawl","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"director","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"releaseDate","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Date","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"characters","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PersonConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"planets","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":n
ull,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PlanetConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"starships","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"StarshipConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"vehicles","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"Vehi
cleConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"species","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"SpecieConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"producers","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[{"kind":"INTERFACE","name":"Node","ofType":null}],"enumValues":null,"possibleTypes":null},{"kind":"INTERFACE","name":"Node","description":"An + object with an ID","fields":[{"name":"id","description":"The ID of the object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":null,"enumValues":null,"possibleTypes":[{"kind":"OBJECT","name":"Film","ofType":null},{"kind":"OBJECT","name":"Person","ofType":null},{"kind":"OBJECT","name":"Planet","ofType":null},{"kind":"OBJECT","name":"Specie","ofType":null},{"kind":"OBJECT","name":"Hero","ofType":null},{"kind":"OBJECT","name":"Starship","ofType":null},{"kind":"OBJECT","name":"Vehicle","ofType":null}]},{"kind":"SCALAR","name":"ID","description":"The + `ID` scalar type represents a 
unique identifier, often used to refetch an + object or as key for a cache. The ID type appears in a JSON response as a + String; however, it is not intended to be human-readable. When expected as + an input type, any string (such as `\"4\"`) or integer (such as `4`) input + value will be accepted as an ID.","fields":null,"inputFields":null,"interfaces":null,"enumValues":null,"possibleTypes":null},{"kind":"SCALAR","name":"Int","description":"The + `Int` scalar type represents non-fractional signed whole numeric values. Int + can represent values between -(2^31 - 1) and 2^31 - 1 since represented in + JSON as double-precision floating point numbers specifiedby [IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point).","fields":null,"inputFields":null,"interfaces":null,"enumValues":null,"possibleTypes":null},{"kind":"SCALAR","name":"Date","description":"The + `Date` scalar type represents a Date\nvalue as specified by\n[iso8601](https://en.wikipedia.org/wiki/ISO_8601).","fields":null,"inputFields":null,"interfaces":null,"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"PersonConnection","description":null,"fields":[{"name":"pageInfo","description":"Pagination + data for this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"PageInfo","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"edges","description":"Contains + the nodes in this 
connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"PersonEdge","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"totalCount","description":null,"args":[],"type":{"kind":"SCALAR","name":"Int","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"PersonEdge","description":"A + Relay edge containing a `Person` and its cursor.","fields":[{"name":"node","description":"The + item at the end of the edge","args":[],"type":{"kind":"OBJECT","name":"Person","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"cursor","description":"A + cursor for use in pagination","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Person","description":"An + individual person or character within the Star Wars universe.","fields":[{"name":"id","description":"The + ID of the 
object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"name","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"height","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"mass","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"hairColor","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"skinColor","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"eyeColor","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"birthYear","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"gender","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"homeworld","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"Planet","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"species","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue
":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"SpecieConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"films","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"episodeId_Gt","description":null,"type":{"kind":"SCALAR","name":"Float","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"FilmConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"starships","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null
},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"StarshipConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"vehicles","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"VehicleConnection","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[{"kind":"INTERFACE","name":"Node","ofType":null}],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Planet","description":"A + large mass, planet or planetoid in the Star Wars Universe,\nat the time of + 0 ABY.","fields":[{"name":"id","description":"The ID of the 
object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"name","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"rotationPeriod","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"orbitalPeriod","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"diameter","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"gravity","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"surfaceWater","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"population","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"speciesSet","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startsw
ith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"SpecieConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"films","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"episodeId_Gt","description":null,"type":{"kind":"SCALAR","name":"Float","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"FilmConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"heroes","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name_Startswith","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"name_Contains","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"HeroConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"residents","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","of
Type":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PersonConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"climates","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"terrains","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[{"kind":"INTERFACE","name":"Node","ofType":null}],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"SpecieConnection","description":null,"fields":[{"name":"pageInfo","description":"Pagination + data for this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"PageInfo","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"edges","description":"Contains + the nodes in this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"SpecieEdge","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"totalCount","description":null,"args":[],"type":{"kind":"SCALAR","name":"Int","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"SpecieEdge","description":"A + Relay edge containing a `Specie` and its 
cursor.","fields":[{"name":"node","description":"The + item at the end of the edge","args":[],"type":{"kind":"OBJECT","name":"Specie","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"cursor","description":"A + cursor for use in pagination","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Specie","description":"A + type of person or character within the Star Wars Universe.","fields":[{"name":"id","description":"The + ID of the object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"name","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"classification","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"designation","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"averageHeight","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"averageLifespan","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"homeworld","description":"","args":[],"type":{"kind":"OBJECT","name":"Planet","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"language","description":"","args":[],"type":{"kind":"NON_NULL","n
ame":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"people","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PersonConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"films","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"episodeId_Gt","description":null,"type":{"kind":"SCALAR","name":"Float","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"FilmConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"eyeColors","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"hairColors","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"skinColors","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{
"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[{"kind":"INTERFACE","name":"Node","ofType":null}],"enumValues":null,"possibleTypes":null},{"kind":"SCALAR","name":"Float","description":"The + `Float` scalar type represents signed double-precision fractional values as + specified by [IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point). + ","fields":null,"inputFields":null,"interfaces":null,"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"HeroConnection","description":null,"fields":[{"name":"pageInfo","description":"Pagination + data for this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"PageInfo","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"edges","description":"Contains + the nodes in this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"HeroEdge","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"totalCount","description":null,"args":[],"type":{"kind":"SCALAR","name":"Int","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"HeroEdge","description":"A + Relay edge containing a `Hero` and its cursor.","fields":[{"name":"node","description":"The + item at the end of the edge","args":[],"type":{"kind":"OBJECT","name":"Hero","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"cursor","description":"A + cursor for use in pagination","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Hero","description":"A + hero created by 
fans","fields":[{"name":"id","description":"The ID of the + object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"name","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"homeworld","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"Planet","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[{"kind":"INTERFACE","name":"Node","ofType":null}],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"StarshipConnection","description":null,"fields":[{"name":"pageInfo","description":"Pagination + data for this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"PageInfo","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"edges","description":"Contains + the nodes in this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"StarshipEdge","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"totalCount","description":null,"args":[],"type":{"kind":"SCALAR","name":"Int","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"StarshipEdge","description":"A + Relay edge containing a `Starship` and its cursor.","fields":[{"name":"node","description":"The + item at the end of the edge","args":[],"type":{"kind":"OBJECT","name":"Starship","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"cursor","description":"A + cursor for use in 
pagination","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Starship","description":"A + single transport craft that has hyperdrive capability.","fields":[{"name":"id","description":"The + ID of the object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"name","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"model","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"manufacturer","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"costInCredits","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"length","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"maxAtmospheringSpeed","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"crew","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"passengers","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SC
ALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"cargoCapacity","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"consumables","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"hyperdriveRating","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"MGLT","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"starshipClass","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"pilots","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PersonConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"films","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"de
faultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"episodeId_Gt","description":null,"type":{"kind":"SCALAR","name":"Float","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"FilmConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"manufacturers","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[{"kind":"INTERFACE","name":"Node","ofType":null}],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"VehicleConnection","description":null,"fields":[{"name":"pageInfo","description":"Pagination + data for this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"PageInfo","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"edges","description":"Contains + the nodes in this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"VehicleEdge","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"totalCount","description":null,"args":[],"type":{"kind":"SCALAR","name":"Int","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"VehicleEdge","description":"A + Relay edge containing a `Vehicle` and its cursor.","fields":[{"name":"node","description":"The + item at the end of the edge","args":[],"type":{"kind":"OBJECT","name":"Vehicle","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"cursor","description":"A + cursor for use in 
pagination","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Vehicle","description":"A + single transport craft that does not have hyperdrive capability","fields":[{"name":"id","description":"The + ID of the object.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"ID","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"name","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"model","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"manufacturer","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"costInCredits","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"length","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"maxAtmospheringSpeed","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"crew","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"passengers","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"k
ind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"cargoCapacity","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"consumables","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"vehicleClass","description":"","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"pilots","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"name","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT","name":"PersonConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"films","description":null,"args":[{"name":"before","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"after","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null},{"name":"first","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"last","description":null,"type":{"kind":"SCALAR","name":"Int","ofType":null},"defaultValue":null},{"name":"episodeId_Gt","description":null,"type":{"kind":"SCALAR","name":"Float","ofType":null},"defaultValue":null}],"type":{"kind":"OBJECT",
"name":"FilmConnection","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"manufacturers","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[{"kind":"INTERFACE","name":"Node","ofType":null}],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"PlanetConnection","description":null,"fields":[{"name":"pageInfo","description":"Pagination + data for this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"PageInfo","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"edges","description":"Contains + the nodes in this connection.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"PlanetEdge","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"totalCount","description":null,"args":[],"type":{"kind":"SCALAR","name":"Int","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"PlanetEdge","description":"A + Relay edge containing a `Planet` and its cursor.","fields":[{"name":"node","description":"The + item at the end of the edge","args":[],"type":{"kind":"OBJECT","name":"Planet","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"cursor","description":"A + cursor for use in 
pagination","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"DjangoDebug","description":null,"fields":[{"name":"sql","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"OBJECT","name":"DjangoDebugSQL","ofType":null}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"DjangoDebugSQL","description":null,"fields":[{"name":"vendor","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"alias","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"sql","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"duration","description":null,"args":[],"type":{"kind":"SCALAR","name":"Float","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"rawSql","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"params","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"startTime","description":null,"args":[],"type":{"kind":"SCALAR","name":"Float","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"stopTime","description":null,"args":[],"type":{"kind":"SCALAR","name":"Float","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"isSlow","description":null,"args":[],"type":{"kind":"SCALAR","name":"Boolean","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"isSelect","des
cription":null,"args":[],"type":{"kind":"SCALAR","name":"Boolean","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"transId","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"transStatus","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"isoLevel","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"encoding","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"Mutation","description":null,"fields":[{"name":"createHero","description":null,"args":[{"name":"input","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"INPUT_OBJECT","name":"CreateHeroInput","ofType":null}},"defaultValue":null}],"type":{"kind":"OBJECT","name":"CreateHeroPayload","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"CreateHeroPayload","description":null,"fields":[{"name":"hero","description":null,"args":[],"type":{"kind":"OBJECT","name":"Hero","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"ok","description":null,"args":[],"type":{"kind":"SCALAR","name":"Boolean","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"clientMutationId","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"INPUT_OBJECT","name":"CreateHeroInput","description":null,"fields":null,"inputFields":[{"name":"name","description":null
,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"defaultValue":null},{"name":"homeworldId","description":null,"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"defaultValue":null},{"name":"clientMutationId","description":null,"type":{"kind":"SCALAR","name":"String","ofType":null},"defaultValue":null}],"interfaces":null,"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"__Schema","description":"A + GraphQL Schema defines the capabilities of a GraphQL server. It exposes all + available types and directives on the server, as well as the entry points + for query, mutation and subscription operations.","fields":[{"name":"types","description":"A + list of all types supported by this server.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Type","ofType":null}}}},"isDeprecated":false,"deprecationReason":null},{"name":"queryType","description":"The + type that query operations will be rooted at.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Type","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"mutationType","description":"If + this server supports mutation, the type that mutation operations will be rooted + at.","args":[],"type":{"kind":"OBJECT","name":"__Type","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"subscriptionType","description":"If + this server support subscription, the type that subscription operations will + be rooted at.","args":[],"type":{"kind":"OBJECT","name":"__Type","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"directives","description":"A + list of all directives supported by this 
server.","args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Directive","ofType":null}}}},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"__Type","description":"The + fundamental unit of any GraphQL Schema is the type. There are many kinds of + types in GraphQL as represented by the `__TypeKind` enum.\n\nDepending on + the kind of a type, certain fields describe information about that type. Scalar + types provide no information beyond a name and description, while Enum types + provide their values. Object and Interface types provide the fields they describe. + Abstract types, Union and Interface, provide the Object types possible at + runtime. List and NonNull types compose other types.","fields":[{"name":"kind","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"ENUM","name":"__TypeKind","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"name","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"description","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"fields","description":null,"args":[{"name":"includeDeprecated","description":null,"type":{"kind":"SCALAR","name":"Boolean","ofType":null},"defaultValue":"false"}],"type":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Field","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"interfaces","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Type","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"possi
bleTypes","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Type","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"enumValues","description":null,"args":[{"name":"includeDeprecated","description":null,"type":{"kind":"SCALAR","name":"Boolean","ofType":null},"defaultValue":"false"}],"type":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__EnumValue","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"inputFields","description":null,"args":[],"type":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__InputValue","ofType":null}}},"isDeprecated":false,"deprecationReason":null},{"name":"ofType","description":null,"args":[],"type":{"kind":"OBJECT","name":"__Type","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"ENUM","name":"__TypeKind","description":"An + enum describing what kind of type a given `__Type` is","fields":null,"inputFields":null,"interfaces":null,"enumValues":[{"name":"SCALAR","description":"Indicates + this type is a scalar.","isDeprecated":false,"deprecationReason":null},{"name":"OBJECT","description":"Indicates + this type is an object. `fields` and `interfaces` are valid fields.","isDeprecated":false,"deprecationReason":null},{"name":"INTERFACE","description":"Indicates + this type is an interface. `fields` and `possibleTypes` are valid fields.","isDeprecated":false,"deprecationReason":null},{"name":"UNION","description":"Indicates + this type is a union. `possibleTypes` is a valid field.","isDeprecated":false,"deprecationReason":null},{"name":"ENUM","description":"Indicates + this type is an enum. 
`enumValues` is a valid field.","isDeprecated":false,"deprecationReason":null},{"name":"INPUT_OBJECT","description":"Indicates + this type is an input object. `inputFields` is a valid field.","isDeprecated":false,"deprecationReason":null},{"name":"LIST","description":"Indicates + this type is a list. `ofType` is a valid field.","isDeprecated":false,"deprecationReason":null},{"name":"NON_NULL","description":"Indicates + this type is a non-null. `ofType` is a valid field.","isDeprecated":false,"deprecationReason":null}],"possibleTypes":null},{"kind":"OBJECT","name":"__Field","description":"Object + and Interface types are described by a list of Fields, each of which has a + name, potentially a list of arguments, and a return type.","fields":[{"name":"name","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"description","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"args","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__InputValue","ofType":null}}}},"isDeprecated":false,"deprecationReason":null},{"name":"type","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Type","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"isDeprecated","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"deprecationReason","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","nam
e":"__InputValue","description":"Arguments + provided to Fields or Directives and the input fields of an InputObject are + represented as Input Values which describe their type and optionally a default + value.","fields":[{"name":"name","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"description","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"type","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__Type","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"defaultValue","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"__EnumValue","description":"One + possible value for a given Enum. Enum values are unique values, not a placeholder + for a string or numeric value. 
However an Enum value is returned in a JSON + response as a string.","fields":[{"name":"name","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"description","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"isDeprecated","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"deprecationReason","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"OBJECT","name":"__Directive","description":"A + Directive provides a way to describe alternate runtime execution and type + validation behavior in a GraphQL document.\n\nIn some cases, you need to provide + options to alter GraphQL''s execution behavior in ways field arguments will + not suffice, such as conditionally including or skipping a field. 
Directives + provide this by describing additional information to the executor.","fields":[{"name":"name","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"String","ofType":null}},"isDeprecated":false,"deprecationReason":null},{"name":"description","description":null,"args":[],"type":{"kind":"SCALAR","name":"String","ofType":null},"isDeprecated":false,"deprecationReason":null},{"name":"locations","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"ENUM","name":"__DirectiveLocation","ofType":null}}}},"isDeprecated":false,"deprecationReason":null},{"name":"args","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"LIST","name":null,"ofType":{"kind":"NON_NULL","name":null,"ofType":{"kind":"OBJECT","name":"__InputValue","ofType":null}}}},"isDeprecated":false,"deprecationReason":null},{"name":"onOperation","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"isDeprecated":true,"deprecationReason":"Use + `locations`."},{"name":"onFragment","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"isDeprecated":true,"deprecationReason":"Use + `locations`."},{"name":"onField","description":null,"args":[],"type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"isDeprecated":true,"deprecationReason":"Use + `locations`."}],"inputFields":null,"interfaces":[],"enumValues":null,"possibleTypes":null},{"kind":"ENUM","name":"__DirectiveLocation","description":"A + Directive can be adjacent to many parts of the GraphQL language, a __DirectiveLocation + describes one such possible adjacencies.","fields":null,"inputFields":null,"interfaces":null,"enumValues":[{"name":"QUERY","description":"Location + adjacent to a 
query operation.","isDeprecated":false,"deprecationReason":null},{"name":"MUTATION","description":"Location + adjacent to a mutation operation.","isDeprecated":false,"deprecationReason":null},{"name":"SUBSCRIPTION","description":"Location + adjacent to a subscription operation.","isDeprecated":false,"deprecationReason":null},{"name":"FIELD","description":"Location + adjacent to a field.","isDeprecated":false,"deprecationReason":null},{"name":"FRAGMENT_DEFINITION","description":"Location + adjacent to a fragment definition.","isDeprecated":false,"deprecationReason":null},{"name":"FRAGMENT_SPREAD","description":"Location + adjacent to a fragment spread.","isDeprecated":false,"deprecationReason":null},{"name":"INLINE_FRAGMENT","description":"Location + adjacent to an inline fragment.","isDeprecated":false,"deprecationReason":null},{"name":"SCHEMA","description":"Location + adjacent to a schema definition.","isDeprecated":false,"deprecationReason":null},{"name":"SCALAR","description":"Location + adjacent to a scalar definition.","isDeprecated":false,"deprecationReason":null},{"name":"OBJECT","description":"Location + adjacent to an object definition.","isDeprecated":false,"deprecationReason":null},{"name":"FIELD_DEFINITION","description":"Location + adjacent to a field definition.","isDeprecated":false,"deprecationReason":null},{"name":"ARGUMENT_DEFINITION","description":"Location + adjacent to an argument definition.","isDeprecated":false,"deprecationReason":null},{"name":"INTERFACE","description":"Location + adjacent to an interface definition.","isDeprecated":false,"deprecationReason":null},{"name":"UNION","description":"Location + adjacent to a union definition.","isDeprecated":false,"deprecationReason":null},{"name":"ENUM","description":"Location + adjacent to an enum definition.","isDeprecated":false,"deprecationReason":null},{"name":"ENUM_VALUE","description":"Location + adjacent to an enum value 
definition.","isDeprecated":false,"deprecationReason":null},{"name":"INPUT_OBJECT","description":"Location + adjacent to an input object definition.","isDeprecated":false,"deprecationReason":null},{"name":"INPUT_FIELD_DEFINITION","description":"Location + adjacent to an input object field definition.","isDeprecated":false,"deprecationReason":null}],"possibleTypes":null}],"directives":[{"name":"include","description":"Directs + the executor to include this field or fragment only when the `if` argument + is true.","locations":["FIELD","FRAGMENT_SPREAD","INLINE_FRAGMENT"],"args":[{"name":"if","description":"Included + when true.","type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"defaultValue":null}]},{"name":"skip","description":"Directs + the executor to skip this field or fragment when the `if` argument is true.","locations":["FIELD","FRAGMENT_SPREAD","INLINE_FRAGMENT"],"args":[{"name":"if","description":"Skipped + when true.","type":{"kind":"NON_NULL","name":null,"ofType":{"kind":"SCALAR","name":"Boolean","ofType":null}},"defaultValue":null}]}]}}}' + headers: + Content-Length: + - '69553' + Content-Type: + - application/json + Date: + - Fri, 06 Nov 2020 11:30:21 GMT + Server: + - WSGIServer/0.1 Python/2.7.18 + Set-Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + expires=Fri, 05-Nov-2021 11:30:21 GMT; Max-Age=31449600; Path=/ + Vary: + - Cookie + X-Frame-Options: + - SAMEORIGIN + status: + code: 200 + message: OK +- request: + body: '[{"query": "{\n myFavoriteFilm: film(id: \"RmlsbToz\") {\n id\n title\n episodeId\n characters(first: + 5) {\n edges {\n node {\n name\n }\n }\n }\n }\n}"}]' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '204' + Content-Type: + - application/json + Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + 
csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + User-Agent: + - python-requests/2.24.0 + x-csrftoken: + - kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + method: POST + uri: http://127.0.0.1:8000/graphql + response: + body: + string: '[{"data":{"myFavoriteFilm":{"id":"RmlsbToz","title":"Return of the Jedi","episodeId":6,"characters":{"edges":[{"node":{"name":"Luke + Skywalker"}},{"node":{"name":"C-3PO"}},{"node":{"name":"R2-D2"}},{"node":{"name":"Darth + Vader"}},{"node":{"name":"Leia Organa"}}]}}}}]' + headers: + Content-Length: + - '264' + Content-Type: + - application/json + Date: + - Fri, 06 Nov 2020 11:30:21 GMT + Server: + - WSGIServer/0.1 Python/2.7.18 + Set-Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + expires=Fri, 05-Nov-2021 11:30:21 GMT; Max-Age=31449600; Path=/ + Vary: + - Cookie + X-Frame-Options: + - SAMEORIGIN + status: + code: 200 + message: OK +- request: + body: '[{"query": "query Planet($id: ID!) 
{\n planet(id: $id) {\n id\n name\n }\n}", + "variables": {"id": "UGxhbmV0OjEw"}}]' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '123' + Content-Type: + - application/json + Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + User-Agent: + - python-requests/2.24.0 + x-csrftoken: + - kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + method: POST + uri: http://127.0.0.1:8000/graphql + response: + body: + string: '[{"data":{"planet":{"id":"UGxhbmV0OjEw","name":"Kamino"}}}]' + headers: + Content-Length: + - '57' + Content-Type: + - application/json + Date: + - Fri, 06 Nov 2020 11:30:21 GMT + Server: + - WSGIServer/0.1 Python/2.7.18 + Set-Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + expires=Fri, 05-Nov-2021 11:30:21 GMT; Max-Age=31449600; Path=/ + Vary: + - Cookie + X-Frame-Options: + - SAMEORIGIN + status: + code: 200 + message: OK +- request: + body: '[{"query": "query Planet1 {\n planet(id: \"UGxhbmV0OjEw\") {\n id\n name\n }\n}\n\nquery + Planet2 {\n planet(id: \"UGxhbmV0OjEx\") {\n id\n name\n }\n}", "operationName": + "Planet2"}]' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '197' + Content-Type: + - application/json + Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + User-Agent: + - python-requests/2.24.0 + x-csrftoken: + - kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + method: POST + uri: http://127.0.0.1:8000/graphql + response: + body: + string: '[{"data":{"planet":{"id":"UGxhbmV0OjEx","name":"Geonosis"}}}]' + headers: + 
Content-Length: + - '59' + Content-Type: + - application/json + Date: + - Fri, 06 Nov 2020 11:30:21 GMT + Server: + - WSGIServer/0.1 Python/2.7.18 + Set-Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + expires=Fri, 05-Nov-2021 11:30:21 GMT; Max-Age=31449600; Path=/ + Vary: + - Cookie + X-Frame-Options: + - SAMEORIGIN + status: + code: 200 + message: OK +- request: + body: '[{"query": "query Planet($id: ID!) {\n planet(id: $id) {\n id\n name\n }\n}"}]' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '86' + Content-Type: + - application/json + Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k + User-Agent: + - python-requests/2.26.0 + authorization: + - xxx-123 + method: POST + uri: http://127.0.0.1:8000/graphql + response: + body: + string: '[{"data":{"planet":{"id":"UGxhbmV0OjEx","name":"Geonosis"}}}]' + headers: + Content-Length: + - '59' + Content-Type: + - application/json + Date: + - Fri, 06 Nov 2020 11:30:21 GMT + Server: + - WSGIServer/0.1 Python/2.7.18 + Set-Cookie: + - csrftoken=kAyQyUjNOGXZfkKUtWtvUROaFfDe2GBiV7yIRsqs3r2j9aYchRDXTNo3lHp72h5k; + expires=Fri, 05-Nov-2021 11:30:21 GMT; Max-Age=31449600; Path=/ + Vary: + - Cookie + X-Frame-Options: + - SAMEORIGIN + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_client.py b/tests/test_client.py index 8b6575d7..2fb333a9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,7 +5,7 @@ import pytest from graphql import build_ast_schema, parse -from gql import Client, gql +from gql import Client, GraphQLRequest, gql from gql.transport import Transport from gql.transport.exceptions import TransportQueryError @@ -34,8 +34,30 @@ def execute(self): with pytest.raises(NotImplementedError) as exc_info: 
RandomTransport().execute() + assert "Any Transport subclass must implement execute method" == str(exc_info.value) + with pytest.raises(NotImplementedError) as exc_info: + RandomTransport().execute_batch([]) + + assert "This Transport has not implemented the execute_batch method" == str( + exc_info.value + ) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +def test_request_async_execute_batch_not_implemented_yet(): + from gql.transport.aiohttp import AIOHTTPTransport + + transport = AIOHTTPTransport(url="http://localhost/") + client = Client(transport=transport) + + with pytest.raises(NotImplementedError) as exc_info: + client.execute_batch([GraphQLRequest(document=gql("{dummy}"))]) + + assert "Batching is not implemented for async yet." == str(exc_info.value) + @pytest.mark.requests @mock.patch("urllib3.connection.HTTPConnection._new_conn") @@ -76,6 +98,17 @@ def test_retries_on_transport(execute_mock): # means you're actually doing 4 calls. assert execute_mock.call_count == expected_retries + 1 + execute_mock.reset_mock() + queries = map(lambda d: GraphQLRequest(document=d), [query, query, query]) + + with client as session: # We're using the client as context manager + with pytest.raises(Exception): + session.execute_batch(queries) + + # This might look strange compared to the previous test, but making 3 retries + # means you're actually doing 4 calls. + assert execute_mock.call_count == expected_retries + 1 + def test_no_schema_exception(): with pytest.raises(AssertionError) as exc_info: @@ -112,6 +145,10 @@ def test_execute_result_error(): client.execute(failing_query) assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value) + with pytest.raises(TransportQueryError) as exc_info: + client.execute_batch([GraphQLRequest(document=failing_query)]) + assert 'Cannot query field "id" on type "Continent".' 
in str(exc_info.value) + @pytest.mark.online @pytest.mark.requests @@ -127,7 +164,13 @@ def test_http_transport_raise_for_status_error(http_transport_query): ) as client: with pytest.raises(Exception) as exc_info: client.execute(http_transport_query) - assert "400 Client Error: Bad Request for url" in str(exc_info.value) + + assert "400 Client Error: Bad Request for url" in str(exc_info.value) + + with pytest.raises(Exception) as exc_info: + client.execute_batch([GraphQLRequest(document=http_transport_query)]) + + assert "400 Client Error: Bad Request for url" in str(exc_info.value) @pytest.mark.online @@ -143,8 +186,19 @@ def test_http_transport_verify_error(http_transport_query): ) as client: with pytest.warns(Warning) as record: client.execute(http_transport_query) - assert len(record) == 1 - assert "Unverified HTTPS request is being made to host" in str(record[0].message) + + assert len(record) == 1 + assert "Unverified HTTPS request is being made to host" in str( + record[0].message + ) + + with pytest.warns(Warning) as record: + client.execute_batch([GraphQLRequest(document=http_transport_query)]) + + assert len(record) == 1 + assert "Unverified HTTPS request is being made to host" in str( + record[0].message + ) @pytest.mark.online @@ -159,7 +213,10 @@ def test_http_transport_specify_method_valid(http_transport_query): ) ) as client: result = client.execute(http_transport_query) - assert result is not None + assert result is not None + + result = client.execute_batch([GraphQLRequest(document=http_transport_query)]) + assert result is not None @pytest.mark.online @@ -175,7 +232,11 @@ def test_http_transport_specify_method_invalid(http_transport_query): ) as client: with pytest.raises(Exception) as exc_info: client.execute(http_transport_query) - assert "400 Client Error: Bad Request for url" in str(exc_info.value) + assert "400 Client Error: Bad Request for url" in str(exc_info.value) + + with pytest.raises(Exception) as exc_info: + 
client.execute_batch([GraphQLRequest(document=http_transport_query)]) + assert "400 Client Error: Bad Request for url" in str(exc_info.value) def test_gql(): diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py new file mode 100644 index 00000000..4c9e7d76 --- /dev/null +++ b/tests/test_graphql_request.py @@ -0,0 +1,202 @@ +import asyncio +from math import isfinite +from typing import Any, Dict, NamedTuple, Optional + +import pytest +from graphql.error import GraphQLError +from graphql.language import ValueNode +from graphql.pyutils import inspect +from graphql.type import ( + GraphQLArgument, + GraphQLField, + GraphQLFloat, + GraphQLList, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, +) +from graphql.utilities import value_from_ast_untyped + +from gql import GraphQLRequest, gql + +from .conftest import MS + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + + +class Money(NamedTuple): + amount: float + currency: str + + +def is_finite(value: Any) -> bool: + """Return true if a value is a finite number.""" + return (isinstance(value, int) and not isinstance(value, bool)) or ( + isinstance(value, float) and isfinite(value) + ) + + +def serialize_money(output_value: Any) -> Dict[str, Any]: + if not isinstance(output_value, Money): + raise GraphQLError("Cannot serialize money value: " + inspect(output_value)) + return output_value._asdict() + + +def parse_money_value(input_value: Any) -> Money: + """Using Money custom scalar from graphql-core tests except here the + input value is supposed to be a dict instead of a Money object.""" + + """ + if isinstance(input_value, Money): + return input_value + """ + + if isinstance(input_value, dict): + amount = input_value.get("amount", None) + currency = input_value.get("currency", None) + + if not is_finite(amount) or not isinstance(currency, str): + raise GraphQLError("Cannot parse money value dict: " + inspect(input_value)) + + 
return Money(float(amount), currency) + else: + raise GraphQLError("Cannot parse money value: " + inspect(input_value)) + + +def parse_money_literal( + value_node: ValueNode, variables: Optional[Dict[str, Any]] = None +) -> Money: + money = value_from_ast_untyped(value_node, variables) + if variables is not None and ( + # variables are not set when checked with ValuesIOfCorrectTypeRule + not money + or not is_finite(money.get("amount")) + or not isinstance(money.get("currency"), str) + ): + raise GraphQLError("Cannot parse literal money value: " + inspect(money)) + return Money(**money) + + +MoneyScalar = GraphQLScalarType( + name="Money", + serialize=serialize_money, + parse_value=parse_money_value, + parse_literal=parse_money_literal, +) + +root_value = { + "balance": Money(42, "DM"), + "friends_balance": [Money(12, "EUR"), Money(24, "EUR"), Money(150, "DM")], + "countries_balance": { + "Belgium": Money(15000, "EUR"), + "Luxembourg": Money(99999, "EUR"), + }, +} + + +def resolve_balance(root, _info): + return root["balance"] + + +def resolve_friends_balance(root, _info): + return root["friends_balance"] + + +def resolve_countries_balance(root, _info): + return root["countries_balance"] + + +def resolve_belgium_balance(countries_balance, _info): + return countries_balance["Belgium"] + + +def resolve_luxembourg_balance(countries_balance, _info): + return countries_balance["Luxembourg"] + + +def resolve_to_euros(_root, _info, money): + amount = money.amount + currency = money.currency + if not amount or currency == "EUR": + return amount + if currency == "DM": + return amount * 0.5 + raise ValueError("Cannot convert to euros: " + inspect(money)) + + +countriesBalance = GraphQLObjectType( + name="CountriesBalance", + fields={ + "Belgium": GraphQLField( + GraphQLNonNull(MoneyScalar), resolve=resolve_belgium_balance + ), + "Luxembourg": GraphQLField( + GraphQLNonNull(MoneyScalar), resolve=resolve_luxembourg_balance + ), + }, +) + +queryType = GraphQLObjectType( + 
name="RootQueryType", + fields={ + "balance": GraphQLField(MoneyScalar, resolve=resolve_balance), + "toEuros": GraphQLField( + GraphQLFloat, + args={"money": GraphQLArgument(MoneyScalar)}, + resolve=resolve_to_euros, + ), + "friends_balance": GraphQLField( + GraphQLList(MoneyScalar), resolve=resolve_friends_balance + ), + "countries_balance": GraphQLField( + GraphQLNonNull(countriesBalance), + resolve=resolve_countries_balance, + ), + }, +) + + +def resolve_spent_money(spent_money, _info, **kwargs): + return spent_money + + +async def subscribe_spend_all(_root, _info, money): + while money.amount > 0: + money = Money(money.amount - 1, money.currency) + yield money + await asyncio.sleep(1 * MS) + + +subscriptionType = GraphQLObjectType( + "Subscription", + fields=lambda: { + "spend": GraphQLField( + MoneyScalar, + args={"money": GraphQLArgument(MoneyScalar)}, + subscribe=subscribe_spend_all, + resolve=resolve_spent_money, + ) + }, +) + +schema = GraphQLSchema( + query=queryType, + subscription=subscriptionType, +) + + +def test_serialize_variables_using_money_example(): + req = GraphQLRequest(document=gql("{balance}")) + + money_value = Money(10, "DM") + + req = GraphQLRequest( + document=gql("query myquery($money: Money) {toEuros(money: $money)}"), + variable_values={"money": money_value}, + ) + + req = req.serialize_variable_values(schema) + + assert req.variable_values == {"money": {"amount": 10, "currency": "DM"}} diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py new file mode 100644 index 00000000..23ab1254 --- /dev/null +++ b/tests/test_requests_batch.py @@ -0,0 +1,377 @@ +from typing import Mapping + +import pytest + +from gql import Client, GraphQLRequest, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +# Marking all tests in this file with the requests marker +pytestmark = pytest.mark.requests + +query1_str = """ + query getContinents { + 
continents { + code + name + } + } +""" + +query1_server_answer_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}]' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Execute query synchronously + results = session.execute_batch(query) + + continents = results[0]["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + assert "COOKIE" in request.headers + assert "cookie1=val1" == request.headers["COOKIE"] + + return web.Response( + text=query1_server_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server 
= await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url, cookies={"cookie1": "val1"}) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Execute query synchronously + results = session.execute_batch(query) + + continents = results[0]["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportServerError) as exc_info: + session.execute_batch(query) + + assert "401 Client Error: Unauthorized" in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + # Will generate http error code 429 + return web.Response( + text=""" + + + Too Many Requests + + +

Too Many Requests

+

I only allow 50 requests per hour to this Web site per + logged in user. Try again soon.

+ +""", + content_type="text/html", + status=429, + headers={"Retry-After": "3600"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportServerError) as exc_info: + session.execute_batch(query) + + assert "429, message='Too Many Requests'" in str(exc_info.value) + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["Retry-After"] == "3600" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + # Will generate http error code 500 + raise Exception("Server error") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportServerError): + session.execute_batch(query) + + await run_sync_test(event_loop, server, test_code) + + +query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer_list, content_type="application/json" + ) + + app = web.Application() + 
app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportQueryError): + session.execute_batch(query) + + await run_sync_test(event_loop, server, test_code) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', + "[{}]", + "[qlsjfqsdlkj]", + '[{"not_data_or_errors": 35}]', + "[]", + "[1]", +] + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_requests_invalid_protocol( + event_loop, aiohttp_server, response, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportProtocolError): + session.execute_batch(query) + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_cannot_execute_if_not_connected( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = 
RequestsHTTPTransport(url=url) + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportClosed): + transport.execute_batch(query) + + await run_sync_test(event_loop, server, test_code) + + +query1_server_answer_with_extensions_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}]" +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query_with_extensions( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + execution_results = session.execute_batch(query, get_execution_result=True) + + assert execution_results[0].extensions["key1"] == "val1" + + await run_sync_test(event_loop, server, test_code) diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py new file mode 100644 index 00000000..a9b21e6a --- /dev/null +++ b/tests/test_transport_batch.py @@ -0,0 +1,151 @@ +import os + +import pytest + +from gql import Client, GraphQLRequest, gql + +# We serve https://github.com/graphql-python/swapi-graphene locally: +URL = "http://127.0.0.1:8000/graphql" + +# Marking all tests in this file with the requests marker 
+pytestmark = pytest.mark.requests + + +def use_cassette(name): + import vcr + + query_vcr = vcr.VCR( + cassette_library_dir=os.path.join( + os.path.dirname(__file__), "fixtures", "vcr_cassettes" + ), + record_mode="new_episodes", + match_on=["uri", "method", "body"], + ) + + return query_vcr.use_cassette(name + ".yaml") + + +@pytest.fixture +def client(): + import requests + from gql.transport.requests import RequestsHTTPTransport + + with use_cassette("client"): + response = requests.get( + URL, headers={"Host": "swapi.graphene-python.org", "Accept": "text/html"} + ) + response.raise_for_status() + csrf = response.cookies["csrftoken"] + + return Client( + transport=RequestsHTTPTransport( + url=URL, cookies={"csrftoken": csrf}, headers={"x-csrftoken": csrf} + ), + fetch_schema_from_transport=True, + ) + + +def test_hero_name_query(client): + query = gql( + """ + { + myFavoriteFilm: film(id:"RmlsbToz") { + id + title + episodeId + characters(first:5) { + edges { + node { + name + } + } + } + } + } + """ + ) + expected = [ + { + "myFavoriteFilm": { + "id": "RmlsbToz", + "title": "Return of the Jedi", + "episodeId": 6, + "characters": { + "edges": [ + {"node": {"name": "Luke Skywalker"}}, + {"node": {"name": "C-3PO"}}, + {"node": {"name": "R2-D2"}}, + {"node": {"name": "Darth Vader"}}, + {"node": {"name": "Leia Organa"}}, + ] + }, + } + } + ] + with use_cassette("queries_batch"): + results = client.execute_batch([GraphQLRequest(document=query)]) + assert results == expected + + +def test_query_with_variable(client): + query = gql( + """ + query Planet($id: ID!) 
{ + planet(id: $id) { + id + name + } + } + """ + ) + expected = [{"planet": {"id": "UGxhbmV0OjEw", "name": "Kamino"}}] + with use_cassette("queries_batch"): + results = client.execute_batch( + [GraphQLRequest(document=query, variable_values={"id": "UGxhbmV0OjEw"})] + ) + assert results == expected + + +def test_named_query(client): + query = gql( + """ + query Planet1 { + planet(id: "UGxhbmV0OjEw") { + id + name + } + } + query Planet2 { + planet(id: "UGxhbmV0OjEx") { + id + name + } + } + """ + ) + expected = [{"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}}] + with use_cassette("queries_batch"): + results = client.execute_batch( + [GraphQLRequest(document=query, operation_name="Planet2")] + ) + assert results == expected + + +def test_header_query(client): + query = gql( + """ + query Planet($id: ID!) { + planet(id: $id) { + id + name + } + } + """ + ) + expected = [{"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}}] + with use_cassette("queries_batch"): + results = client.execute_batch( + [GraphQLRequest(document=query)], + extra_args={"headers": {"authorization": "xxx-123"}}, + ) + assert results == expected From dfbcb59957b63a1eee8093603de28411b6457870 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 9 Sep 2023 20:45:33 +0200 Subject: [PATCH 123/239] Validate the argument of the gql function (#435) --- gql/gql.py | 15 +++++++++++---- tests/starwars/test_query.py | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/gql/gql.py b/gql/gql.py index 903c9609..e35c8045 100644 --- a/gql/gql.py +++ b/gql/gql.py @@ -1,11 +1,13 @@ +from __future__ import annotations + from graphql import DocumentNode, Source, parse -def gql(request_string: str) -> DocumentNode: - """Given a String containing a GraphQL request, parse it into a Document. +def gql(request_string: str | Source) -> DocumentNode: + """Given a string containing a GraphQL request, parse it into a Document. 
:param request_string: the GraphQL request as a String - :type request_string: str + :type request_string: str | Source :return: a Document which can be later executed or subscribed by a :class:`Client `, by an :class:`async session ` or by a @@ -13,5 +15,10 @@ def gql(request_string: str) -> DocumentNode: :raises GraphQLError: if a syntax error is encountered. """ - source = Source(request_string, "GraphQL request") + if isinstance(request_string, Source): + source = request_string + elif isinstance(request_string, str): + source = Source(request_string, "GraphQL request") + else: + raise TypeError("Request must be passed as a string or Source object.") return parse(source) diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 430aa18e..bf15e11a 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -1,5 +1,5 @@ import pytest -from graphql import GraphQLError +from graphql import GraphQLError, Source from gql import Client, gql from tests.starwars.schema import StarWarsSchema @@ -323,3 +323,17 @@ def test_mutation_result(client): expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} result = client.execute(query, variable_values=params) assert result == expected + + +def test_query_from_source(client): + source = Source("{ hero { name } }") + query = gql(source) + expected = {"hero": {"name": "R2-D2"}} + result = client.execute(query) + assert result == expected + + +def test_already_parsed_query(client): + query = gql("{ hero { name } }") + with pytest.raises(TypeError, match="must be passed as a string"): + gql(query) From ff3082b2c0fb3cc15805897fc8436734af674fdb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 13 Sep 2023 22:43:10 +0200 Subject: [PATCH 124/239] Sync auto batching requests (#436) --- gql/client.py | 295 +++++++++++++++++++++++++++++++--- tests/test_requests_batch.py | 296 +++++++++++++++++++++++++++++++++++ 2 files changed, 566 insertions(+), 25 deletions(-) 
diff --git a/gql/client.py b/gql/client.py index 326442e0..5c1edffa 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,7 +1,11 @@ import asyncio import logging import sys +import time import warnings +from concurrent.futures import Future +from queue import Queue +from threading import Event, Thread from typing import ( Any, AsyncGenerator, @@ -10,6 +14,7 @@ Generator, List, Optional, + Tuple, TypeVar, Union, cast, @@ -82,6 +87,8 @@ def __init__( execute_timeout: Optional[Union[int, float]] = 10, serialize_variables: bool = False, parse_results: bool = False, + batch_interval: float = 0, + batch_max: int = 10, ): """Initialize the client with the given parameters. @@ -99,6 +106,9 @@ def __init__( serialized. Used for custom scalars and/or enums. Default: False. :param parse_results: Whether gql will try to parse the serialized output sent by the backend. Can be used to unserialize custom scalars or enums. + :param batch_interval: Time to wait in seconds for batching requests together. + Batching is disabled (by default) if 0. + :param batch_max: Maximum number of requests in a single batch. """ if introspection: @@ -146,6 +156,12 @@ def __init__( self.serialize_variables = serialize_variables self.parse_results = parse_results + self.batch_interval = batch_interval + self.batch_max = batch_max + + @property + def batching_enabled(self): + return self.batch_interval != 0 def validate(self, document: DocumentNode): """:meta private:""" @@ -238,9 +254,46 @@ def execute_sync( **kwargs, ) + @overload + def execute_batch_sync( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False], + **kwargs, + ) -> List[Dict[str, Any]]: + ... 
# pragma: no cover + + @overload + def execute_batch_sync( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs, + ) -> List[ExecutionResult]: + ... # pragma: no cover + + @overload def execute_batch_sync( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover + + def execute_batch_sync( + self, + requests: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -249,7 +302,7 @@ def execute_batch_sync( """:meta private:""" with self as session: return session.execute_batch( - reqs, + requests, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -437,9 +490,46 @@ def execute( **kwargs, ) + @overload def execute_batch( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False], + **kwargs, + ) -> List[Dict[str, Any]]: + ... # pragma: no cover + + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs, + ) -> List[ExecutionResult]: + ... # pragma: no cover + + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... 
# pragma: no cover + + def execute_batch( + self, + requests: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -472,7 +562,7 @@ def execute_batch( else: # Sync transports return self.execute_batch_sync( - reqs, + requests, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -739,11 +829,11 @@ def connect_sync(self): " Use 'async with Client(...) as session:' instead" ) - self.transport.connect() - if not hasattr(self, "session"): self.session = SyncClientSession(client=self) + self.session.connect() + # Get schema from transport if needed try: if self.fetch_schema_from_transport and not self.schema: @@ -752,14 +842,18 @@ def connect_sync(self): # we don't know what type of exception is thrown here because it # depends on the underlying transport; we just make sure that the # transport is closed and re-raise the exception - self.transport.close() + self.session.close() raise return self.session def close_sync(self): - """Close the sync transport.""" - self.transport.close() + """Close the sync session and the sync transport. + + If batching is enabled, this will block until the remaining queries in the + batching queue have been processed. 
+ """ + self.session.close() def __enter__(self): return self.connect_sync() @@ -818,12 +912,22 @@ def _execute( operation_name=operation_name, ) - result = self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, - ) + if self.client.batching_enabled: + request = GraphQLRequest( + document, + variable_values=variable_values, + operation_name=operation_name, + ) + future_result = self._execute_future(request) + result = future_result.result() + + else: + result = self.transport.execute( + document, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) # Unserialize the result if requested if self.client.schema: @@ -938,40 +1042,45 @@ def execute( def _execute_batch( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, + validate_document: Optional[bool] = True, **kwargs, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch, using the sync transport, returning a list of ExecutionResult objects. - :param reqs: List of requests that will be executed. + :param requests: List of requests that will be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. :param parse_result: Whether gql will unserialize the result. By default use the parse_results argument of the client. + :param validate_document: Whether we still need to validate the document. 
The extra arguments are passed to the transport execute method.""" # Validate document if self.client.schema: - for req in reqs: - self.client.validate(req.document) + + if validate_document: + for req in requests: + self.client.validate(req.document) # Parse variable values for custom scalars if requested if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - reqs = [ + requests = [ req.serialize_variable_values(self.client.schema) if req.variable_values is not None else req - for req in reqs + for req in requests ] - results = self.transport.execute_batch(reqs, **kwargs) + results = self.transport.execute_batch(requests, **kwargs) # Unserialize the result if requested if self.client.schema: @@ -986,9 +1095,46 @@ def _execute_batch( return results + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False], + **kwargs, + ) -> List[Dict[str, Any]]: + ... # pragma: no cover + + @overload def execute_batch( self, - reqs: List[GraphQLRequest], + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs, + ) -> List[ExecutionResult]: + ... # pragma: no cover + + @overload + def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + ... # pragma: no cover + + def execute_batch( + self, + requests: List[GraphQLRequest], + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -1000,7 +1146,7 @@ def execute_batch( Raises a TransportQueryError if an error has been returned in any ExecutionResult. 
- :param reqs: List of requests that will be executed. + :param requests: List of requests that will be executed. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1013,7 +1159,7 @@ def execute_batch( # Validate and execute on the transport results = self._execute_batch( - reqs, + requests, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1038,6 +1184,105 @@ def execute_batch( return cast(List[Dict[str, Any]], [result.data for result in results]) + def _batch_loop(self) -> None: + """main loop of the thread used to wait for requests + to execute them in a batch""" + + stop_loop = False + + while not stop_loop: + + # First wait for a first request in from the batch queue + requests_and_futures: List[Tuple[GraphQLRequest, Future]] = [] + request_and_future: Tuple[GraphQLRequest, Future] = self.batch_queue.get() + if request_and_future is None: + break + requests_and_futures.append(request_and_future) + + # Then wait the requested batch interval except if we already + # have the maximum number of requests in the queue + if self.batch_queue.qsize() < self.client.batch_max - 1: + time.sleep(self.client.batch_interval) + + # Then get the requests which had been made during that wait interval + for _ in range(self.client.batch_max - 1): + if self.batch_queue.empty(): + break + request_and_future = self.batch_queue.get() + if request_and_future is None: + stop_loop = True + break + requests_and_futures.append(request_and_future) + + requests = [request for request, _ in requests_and_futures] + futures = [future for _, future in requests_and_futures] + + # Manually execute the requests in a batch + try: + results: List[ExecutionResult] = self._execute_batch( + requests, + serialize_variables=False, # already done + parse_result=False, + validate_document=False, + ) + except Exception as exc: + for future in 
futures: + future.set_exception(exc) + continue + + # Fill in the future results + for result, future in zip(results, futures): + future.set_result(result) + + # Indicate that the Thread has stopped + self._batch_thread_stopped_event.set() + + def _execute_future( + self, + request: GraphQLRequest, + ) -> Future: + """If batching is enabled, this method will put a request in the batching queue + instead of executing it directly so that the requests could be put in a batch. + """ + + assert hasattr(self, "batch_queue"), "Batching is not enabled" + assert not self._batch_thread_stop_requested, "Batching thread has been stopped" + + future: Future = Future() + self.batch_queue.put((request, future)) + + return future + + def connect(self): + """Connect the transport and initialize the batch threading loop if batching + is enabled.""" + + if self.client.batching_enabled: + self.batch_queue: Queue = Queue() + self._batch_thread_stop_requested = False + self._batch_thread_stopped_event = Event() + self._batch_thread = Thread(target=self._batch_loop, daemon=True) + self._batch_thread.start() + + self.transport.connect() + + def close(self): + """Close the transport and cleanup the batching thread if batching is enabled. + + Will wait until all the remaining requests in the batch processing queue + have been executed. + """ + if hasattr(self, "_batch_thread_stopped_event"): + # Send a None in the queue to indicate that the batching Thread must stop + # after having processed the remaining requests in the queue + self._batch_thread_stop_requested = True + self.batch_queue.put(None) + + # Wait for the Thread to stop + self._batch_thread_stopped_event.wait() + + self.transport.close() + def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. 
diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 23ab1254..1f922db7 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -30,6 +30,21 @@ '{"code":"SA","name":"South America"}]}}]' ) +query1_server_answer_twice_list = ( + "[" + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}' + "]" +) + @pytest.mark.aiohttp @pytest.mark.asyncio @@ -74,6 +89,114 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query_auto_batch_enabled( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert 
isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_query_auto_batch_enabled_two_requests( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + from threading import Thread + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + threads = [] + + with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + def test_thread(): + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + for _ in range(2): + thread = Thread(target=test_thread) + thread.start() + threads.append(thread) + + for thread in threads: + thread.join() + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): @@ -148,6 +271,46 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_error_code_401_auto_batch_enabled( + event_loop, aiohttp_server, run_sync_test +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + 
async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + session.execute(query) + + assert "401 Client Error: Unauthorized" in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): @@ -375,3 +538,136 @@ def test_code(): assert execution_results[0].extensions["key1"] == "val1" await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.online +@pytest.mark.requests +def test_requests_sync_batch_auto(): + + from threading import Thread + from gql.transport.requests import RequestsHTTPTransport + + client = Client( + transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + batch_interval=0.01, + batch_max=3, + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + ) + + def get_continent_name(session, continent_code): + variables = { + "continent_code": continent_code, + } + + result = session.execute(query, variable_values=variables) + + name = result["continent"]["name"] + print(f"The continent with the code {continent_code} has the name: '{name}'") + + continent_codes = ["EU", "AF", "NA", "OC", "SA", "AS", "AN"] + + with client as session: + + for continent_code in continent_codes: + + thread = Thread( + target=get_continent_name, + args=( + session, + continent_code, + ), + ) + thread.start() + thread.join() + + # Doing it twice to check that everything is closing and reconnecting correctly + with client as session: + + for continent_code in continent_codes: + + thread = Thread( + target=get_continent_name, + args=( + session, + continent_code, + ), + ) + thread.start() + thread.join() + + +@pytest.mark.online +@pytest.mark.requests +def test_requests_sync_batch_auto_execute_future(): + + from gql.transport.requests import RequestsHTTPTransport + + client = Client( + transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + batch_interval=0.01, + batch_max=3, + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + ) + + with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + future_result_eu = session._execute_future(request_eu) + + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + future_result_af = session._execute_future(request_af) + + result_eu = future_result_eu.result().data + result_af = future_result_af.result().data + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" + + +@pytest.mark.online +@pytest.mark.requests +def test_requests_sync_batch_manual(): + + from gql.transport.requests import RequestsHTTPTransport + + client = Client( + transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) { + continent(code: $continent_code) { + name + } + } + """ + ) + + with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" From d959ef427fa5364728f11833374fbdebb512a80c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 4 Oct 2023 15:43:18 +0200 Subject: [PATCH 125/239] Bump vcrpy and requests_toolbelt (#440) * Bump vcrpy and remove urllib3 restriction * Restrict vcrpy to 4.4.0 to still support Python 3.7 * Bump requests_toolbelt to 1.0.0 to make it work with urllib3 2.x --- setup.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 993af099..ce289fe2 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ "pytest-console-scripts==1.3.1", "pytest-cov==3.0.0", "mock==4.0.2", - "vcrpy==4.0.2", + 
"vcrpy==4.4.0", "aiofiles", ] @@ -43,8 +43,7 @@ install_requests_requires = [ "requests>=2.26,<3", - "requests_toolbelt>=0.9.1,<1", - "urllib3>=1.26,<2", + "requests_toolbelt>=1.0.0,<2", ] install_httpx_requires = [ From c9395738c7d2846e4c2edc1d69c93c847975cae3 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 4 Oct 2023 15:47:20 +0200 Subject: [PATCH 126/239] Bump version number to 3.5.0b6 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 986f222d..8d4ea956 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b5" +__version__ = "3.5.0b6" From e0bd4979b2af7ac18c315d7546bcea8a2d6e71b6 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 22 Oct 2023 23:13:07 +0200 Subject: [PATCH 127/239] Fix tests with Python 3.12 (#442) * Adding Python 3.12 to automated tests * Bumping `pytest` and `pytest_asyncio` versions * Fixing test cleanup for Python 3.12 * Force aiohttp 3.9.0b0 version for Python version >= 3.12 --- .github/workflows/tests.yml | 4 +++- docs/code_examples/console_async.py | 1 - docs/code_examples/fastapi_async.py | 1 - setup.py | 8 +++++--- tests/conftest.py | 19 ++++++++++--------- tests/starwars/test_dsl.py | 3 ++- tests/test_appsync_websockets.py | 11 ++++++++++- tests/test_websocket_query.py | 2 +- tox.ini | 5 +++-- 9 files changed, 34 insertions(+), 20 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e6d42db7..30e8289c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "pypy3.8"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.8"] os: [ubuntu-20.04, windows-latest] exclude: - os: windows-latest @@ -19,6 +19,8 @@ jobs: python-version: "3.10" - os: windows-latest python-version: "3.11" + - os: windows-latest + python-version: 
"3.12" - os: windows-latest python-version: "pypy3.8" diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py index 5391f7bf..9a5e94e5 100644 --- a/docs/code_examples/console_async.py +++ b/docs/code_examples/console_async.py @@ -2,7 +2,6 @@ import logging from aioconsole import ainput - from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 3bedd187..80920252 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -10,7 +10,6 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse - from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/setup.py b/setup.py index ce289fe2..eb215b53 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,8 @@ tests_requires = [ "parse==1.15.0", - "pytest==6.2.5", - "pytest-asyncio==0.16.0", + "pytest==7.4.2", + "pytest-asyncio==0.21.1", "pytest-console-scripts==1.3.1", "pytest-cov==3.0.0", "mock==4.0.2", @@ -38,7 +38,8 @@ ] + tests_requires install_aiohttp_requires = [ - "aiohttp>=3.8.0,<4", + "aiohttp>=3.8.0,<4;python_version<='3.11'", + "aiohttp>=3.9.0b0,<4;python_version>'3.11'", ] install_requests_requires = [ @@ -89,6 +90,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay gql client", diff --git a/tests/conftest.py b/tests/conftest.py index b880cff4..30c0d6f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from typing import Union import pytest +import pytest_asyncio from gql import Client @@ -101,13 +102,13 @@ async def go(app, *, port=None, **kwargs): # type: ignore await servers.pop().close() -@pytest.fixture 
+@pytest_asyncio.fixture async def aiohttp_server(): async for server in aiohttp_server_base(): yield server -@pytest.fixture +@pytest_asyncio.fixture async def ssl_aiohttp_server(): async for server in aiohttp_server_base(with_ssl=True): yield server @@ -203,7 +204,7 @@ async def stop(self): try: await asyncio.wait_for(self.server.wait_closed(), timeout=5) except asyncio.TimeoutError: # pragma: no cover - assert False, "Server failed to stop" + pass print("Server stopped\n\n\n") @@ -349,7 +350,7 @@ async def default_server_handler(ws, path): return server_handler -@pytest.fixture +@pytest_asyncio.fixture async def ws_ssl_server(request): """Websockets server fixture using SSL. @@ -372,7 +373,7 @@ async def ws_ssl_server(request): await test_server.stop() -@pytest.fixture +@pytest_asyncio.fixture async def server(request): """Fixture used to start a dummy server to test the client behaviour. @@ -395,7 +396,7 @@ async def server(request): await test_server.stop() -@pytest.fixture +@pytest_asyncio.fixture async def graphqlws_server(request): """Fixture used to start a dummy server with the graphql-ws protocol. @@ -443,7 +444,7 @@ def process_subprotocol(self, headers, available_subprotocols): await test_server.stop() -@pytest.fixture +@pytest_asyncio.fixture async def client_and_server(server): """Helper fixture to start a server and a client connected to its port.""" @@ -460,7 +461,7 @@ async def client_and_server(server): yield session, server -@pytest.fixture +@pytest_asyncio.fixture async def client_and_graphqlws_server(graphqlws_server): """Helper fixture to start a server with the graphql-ws prototocol and a client connected to its port.""" @@ -481,7 +482,7 @@ async def client_and_graphqlws_server(graphqlws_server): yield session, graphqlws_server -@pytest.fixture +@pytest_asyncio.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): """This function will run the test in a different Thread. 
diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 098a2b50..9dc87910 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -138,7 +138,8 @@ def test_use_variable_definition_multiple_times(ds): assert ( print_ast(query) - == """mutation ($badReview: ReviewInput, $episode: Episode, $goodReview: ReviewInput) { + == """mutation \ +($badReview: ReviewInput, $episode: Episode, $goodReview: ReviewInput) { badReview: createReview(review: $badReview, episode: $episode) { stars commentary diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 62816cc9..14c40e75 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -333,7 +333,16 @@ async def receiving_coro(): print(f"\n Server: Exception received: {e!s}\n") finally: print(" Server: waiting for websocket connection to close") - await ws.wait_closed() + try: + await asyncio.wait_for(ws.wait_closed(), 1000 * MS) + except asyncio.TimeoutError: + pass + + try: + await asyncio.wait_for(ws.close(), 1000 * MS) + except asyncio.TimeoutError: + pass + print(" Server: connection closed") return realtime_appsync_server_template diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index f39409f5..e8b7a022 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -382,7 +382,7 @@ async def server_with_authentication_in_connection_init_payload(ws, path): '{"type":"connection_error", "payload": "No Authorization token"}' ) - await ws.wait_closed() + await ws.close() @pytest.mark.asyncio diff --git a/tox.ini b/tox.ini index df1e81f1..e4794be5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{37,38,39,310,311,py3} + py{37,38,39,310,311,312,py3} [gh-actions] python = @@ -10,6 +10,7 @@ python = 3.9: py39 3.10: py310 3.11: py311 + 3.12: py312 pypy-3: pypy3 [testenv] @@ -28,7 +29,7 @@ deps = -e.[test] commands = pip install -U 
setuptools ; run "tox -- tests -s" to show output for debugging - py{37,39,310,311,py3}: pytest {posargs:tests} + py{37,39,310,311,312,py3}: pytest {posargs:tests} py{38}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] From f273a2bb3a005dae6e36a350db9e92b51a8439d0 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 23 Oct 2023 16:50:38 +0200 Subject: [PATCH 128/239] DOC explain how to install pre-releases versions --- docs/intro.rst | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/intro.rst b/docs/intro.rst index f7a4b71d..8f59ed16 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -12,7 +12,11 @@ Installation You can install GQL 3 and all the extra dependencies using pip_:: - pip install gql[all] + pip install "gql[all]" + +To have the latest pre-releases versions of gql, you can use:: + + pip install --pre "gql[all]" After installation, you can start using GQL by importing from the top-level :mod:`gql` package. @@ -70,6 +74,11 @@ To install gql with less dependencies, you might want to instead install a combi following packages: :code:`gql-with-aiohttp`, :code:`gql-with-websockets`, :code:`gql-with-requests`, :code:`gql-with-botocore` +If you want to have the latest pre-releases version of gql and graphql-core, you can install +them with conda using:: + + conda install -c conda-forge -c conda-forge/label/graphql_core_alpha -c conda-forge/label/gql_beta gql-with-all + Reporting Issues and Contributing --------------------------------- From 87fac0f6093a31aee5ee7133a80ef1df2f321834 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 23 Oct 2023 16:59:09 +0200 Subject: [PATCH 129/239] readthedocs remove system_packages: true in config See https://blog.readthedocs.com/drop-support-system-packages/ --- .readthedocs.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 749771cf..63eed863 100644 --- 
a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -26,4 +26,3 @@ python: path: . extra_requirements: - all - system_packages: true From ff6352bd19ff22a36eef94b69e88a0a5281f8170 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 14 Nov 2023 18:59:33 +0100 Subject: [PATCH 130/239] Fix parse_results with fragments (#446) * Fix issue #445 --- gql/utilities/parse_result.py | 9 ++++--- tests/starwars/test_parse_results.py | 36 ++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index ede627ae..02355425 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -193,7 +193,7 @@ def enter_field( # Key not found in result. # Should never happen in theory with a correct GraphQL backend # Silently ignoring this field - log.debug(f"Key {name} not found in result --> REMOVE") + log.debug(f" Key {name} not found in result --> REMOVE") return REMOVE log.debug(f" result_value={result_value}") @@ -232,8 +232,11 @@ def enter_field( ) # Get parent SelectionSet node - new_node = ancestors[-1] - assert isinstance(new_node, SelectionSetNode) + selection_set_node = ancestors[-1] + assert isinstance(selection_set_node, SelectionSetNode) + + # Keep only the current node in a new selection set node + new_node = SelectionSetNode(selections=[node]) for item in result_value: diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py index 23073839..e8f3f8d4 100644 --- a/tests/starwars/test_parse_results.py +++ b/tests/starwars/test_parse_results.py @@ -37,6 +37,42 @@ def test_hero_name_and_friends_query(): assert result == parsed_result +def test_hero_name_and_friends_query_with_fragment(): + """Testing for issue #445""" + + query = gql( + """ + query HeroNameAndFriendsQuery { + hero { + ...HeroSummary + friends { + name + } + } + } + fragment HeroSummary on Character { + id + name + } + """ + ) + result = { + "hero": { + "id": "2001", + 
"friends": [ + {"name": "Luke Skywalker"}, + {"name": "Han Solo"}, + {"name": "Leia Organa"}, + ], + "name": "R2-D2", + } + } + + parsed_result = parse_result(StarWarsSchema, query, result) + + assert result == parsed_result + + def test_key_not_found_in_result(): query = gql( From c5a164c26ffa29ae1bc3803708a485b8557207b4 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 14 Nov 2023 19:03:42 +0100 Subject: [PATCH 131/239] Bump version number to 3.5.0b7 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 8d4ea956..a8c6dffe 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b6" +__version__ = "3.5.0b7" From 632ec966f04cd50d29e1225a96bb2aaacf38ce52 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 19 Nov 2023 17:34:35 +0100 Subject: [PATCH 132/239] Fix missing empty directives in DSL nodes (#448) --- gql/dsl.py | 13 +++- .../test_dsl_directives.py | 59 +++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) create mode 100644 tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py diff --git a/gql/dsl.py b/gql/dsl.py index adc48bea..0c834b33 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -503,6 +503,7 @@ def executable_ast(self) -> OperationDefinitionNode: selection_set=self.selection_set, variable_definitions=self.variable_definitions.get_ast_definitions(), **({"name": NameNode(value=self.name)} if self.name else {}), + directives=(), ) def __repr__(self) -> str: @@ -597,6 +598,7 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: default_value=None if var.default_value is None else ast_from_value(var.default_value, var.type), + directives=(), ) for var in self.variables.values() if var.type is not None # only variables used @@ -818,7 +820,11 @@ def __init__( """ self.parent_type = parent_type self.field = field - self.ast_field = FieldNode(name=NameNode(value=name), arguments=()) + 
self.ast_field = FieldNode( + name=NameNode(value=name), + arguments=(), + directives=(), + ) self.dsl_type = dsl_type log.debug(f"Creating {self!r}") @@ -950,7 +956,7 @@ def __init__( log.debug(f"Creating {self!r}") - self.ast_field = InlineFragmentNode() + self.ast_field = InlineFragmentNode(directives=()) DSLSelector.__init__(self, *fields, **fields_with_alias) @@ -1018,7 +1024,7 @@ def ast_field(self) -> FragmentSpreadNode: # type: ignore `issue #4125 of mypy `_. """ - spread_node = FragmentSpreadNode() + spread_node = FragmentSpreadNode(directives=()) spread_node.name = NameNode(value=self.name) return spread_node @@ -1067,6 +1073,7 @@ def executable_ast(self) -> FragmentDefinitionNode: selection_set=self.selection_set, variable_definitions=self.variable_definitions.get_ast_definitions(), name=NameNode(value=self.name), + directives=(), ) def __repr__(self) -> str: diff --git a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py new file mode 100644 index 00000000..61cc21e9 --- /dev/null +++ b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py @@ -0,0 +1,59 @@ +from gql import Client +from gql.dsl import DSLFragment, DSLQuery, DSLSchema, dsl_gql + +schema_str = """ +type MonsterForm { + sprites: MonsterFormSprites! +} + +union SpriteUnion = Sprite | CopyOf + +type Query { + monster: [Monster!]! +} + +type MonsterFormSprites { + actions: [SpriteUnion!]! +} + +type CopyOf { + action: String! +} + +type Monster { + manual(path: String!): MonsterForm +} + +type Sprite { + action: String! 
+} +""" + + +def test_issue_447(): + + client = Client(schema=schema_str) + ds = DSLSchema(client.schema) + + sprite = DSLFragment("SpriteUnionAsSprite") + sprite.on(ds.Sprite) + sprite.select( + ds.Sprite.action, + ) + copy_of = DSLFragment("SpriteUnionAsCopyOf") + copy_of.on(ds.CopyOf) + copy_of.select( + ds.CopyOf.action, + ) + + query = ds.Query.monster.select( + ds.Monster.manual(path="").select( + ds.MonsterForm.sprites.select( + ds.MonsterFormSprites.actions.select(sprite, copy_of), + ), + ), + ) + + q = dsl_gql(sprite, copy_of, DSLQuery(query)) + + client.validate(q) From 3a48a2f26975f6af35474385f5e3a5ee16253f28 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 19 Nov 2023 17:37:46 +0100 Subject: [PATCH 133/239] Bump version number to 3.5.0b8 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index a8c6dffe..fdaa43c1 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b7" +__version__ = "3.5.0b8" From a2f327fef4ebdcb6899279b41e0ff715f2b0858a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 21 Nov 2023 01:19:41 +0100 Subject: [PATCH 134/239] Adding node_tree method in utilities to debug and compare DocumentNode instances (#449) DSL: Set variable_definitions to None for Fragments by default instead of empty tuple --- gql/dsl.py | 18 ++- gql/utilities/__init__.py | 2 + gql/utilities/node_tree.py | 89 +++++++++++ .../test_dsl_directives.py | 19 ++- tests/starwars/test_dsl.py | 146 +++++++++++++++++- 5 files changed, 270 insertions(+), 4 deletions(-) create mode 100644 gql/utilities/node_tree.py diff --git a/gql/dsl.py b/gql/dsl.py index 0c834b33..536a8b8b 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1068,10 +1068,26 @@ def executable_ast(self) -> FragmentDefinitionNode: "Missing type condition. 
Please use .on(type_condition) method" ) + fragment_variable_definitions = self.variable_definitions.get_ast_definitions() + + if len(fragment_variable_definitions) == 0: + """Fragment variable definitions are obsolete and only supported on + graphql-core if the Parser is initialized with: + allow_legacy_fragment_variables=True. + + We will not provide variable_definitions instead of providing an empty + tuple to be coherent with how it works by default on graphql-core. + """ + variable_definition_kwargs = {} + else: + variable_definition_kwargs = { + "variable_definitions": fragment_variable_definitions + } + return FragmentDefinitionNode( type_condition=NamedTypeNode(name=NameNode(value=self._type.name)), selection_set=self.selection_set, - variable_definitions=self.variable_definitions.get_ast_definitions(), + **variable_definition_kwargs, name=NameNode(value=self.name), directives=(), ) diff --git a/gql/utilities/__init__.py b/gql/utilities/__init__.py index 3d29dfe3..302c226a 100644 --- a/gql/utilities/__init__.py +++ b/gql/utilities/__init__.py @@ -1,5 +1,6 @@ from .build_client_schema import build_client_schema from .get_introspection_query_ast import get_introspection_query_ast +from .node_tree import node_tree from .parse_result import parse_result from .serialize_variable_values import serialize_value, serialize_variable_values from .update_schema_enum import update_schema_enum @@ -7,6 +8,7 @@ __all__ = [ "build_client_schema", + "node_tree", "parse_result", "get_introspection_query_ast", "serialize_variable_values", diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py new file mode 100644 index 00000000..c307d937 --- /dev/null +++ b/gql/utilities/node_tree.py @@ -0,0 +1,89 @@ +from typing import Any, Iterable, List, Optional, Sized + +from graphql import Node + + +def _node_tree_recursive( + obj: Any, + *, + indent: int = 0, + ignored_keys: List, +): + + assert ignored_keys is not None + + results = [] + + if hasattr(obj, "__slots__"): + 
+ results.append(" " * indent + f"{type(obj).__name__}") + + try: + keys = obj.keys + except AttributeError: + # If the object has no keys attribute, print its repr and return. + results.append(" " * (indent + 1) + repr(obj)) + else: + for key in keys: + if key in ignored_keys: + continue + attr_value = getattr(obj, key, None) + results.append(" " * (indent + 1) + f"{key}:") + if isinstance(attr_value, Iterable) and not isinstance( + attr_value, (str, bytes) + ): + if isinstance(attr_value, Sized) and len(attr_value) == 0: + results.append( + " " * (indent + 2) + f"empty {type(attr_value).__name__}" + ) + else: + for item in attr_value: + results.append( + _node_tree_recursive( + item, + indent=indent + 2, + ignored_keys=ignored_keys, + ) + ) + else: + results.append( + _node_tree_recursive( + attr_value, + indent=indent + 2, + ignored_keys=ignored_keys, + ) + ) + else: + results.append(" " * indent + repr(obj)) + + return "\n".join(results) + + +def node_tree( + obj: Node, + *, + ignore_loc: bool = True, + ignore_block: bool = True, + ignored_keys: Optional[List] = None, +): + """Method which returns a tree of Node elements as a String. + + Useful to debug deep DocumentNode instances created by gql or dsl_gql. + + WARNING: the output of this method is not guaranteed and may change without notice. 
+ """ + + assert isinstance(obj, Node) + + if ignored_keys is None: + ignored_keys = [] + + if ignore_loc: + # We are ignoring loc attributes by default + ignored_keys.append("loc") + + if ignore_block: + # We are ignoring block attributes by default (in StringValueNode) + ignored_keys.append("block") + + return _node_tree_recursive(obj, ignored_keys=ignored_keys) diff --git a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py index 61cc21e9..b31ade7f 100644 --- a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py +++ b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py @@ -1,5 +1,6 @@ -from gql import Client -from gql.dsl import DSLFragment, DSLQuery, DSLSchema, dsl_gql +from gql import Client, gql +from gql.dsl import DSLFragment, DSLQuery, DSLSchema, dsl_gql, print_ast +from gql.utilities import node_tree schema_str = """ type MonsterForm { @@ -57,3 +58,17 @@ def test_issue_447(): q = dsl_gql(sprite, copy_of, DSLQuery(query)) client.validate(q) + + # Creating a tree from the DocumentNode created by dsl_gql + dsl_tree = node_tree(q) + + # Creating a tree from the DocumentNode created by gql + gql_tree = node_tree(gql(print_ast(q))) + + print("=======") + print(dsl_tree) + print("+++++++") + print(gql_tree) + print("=======") + + assert dsl_tree == gql_tree diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 9dc87910..4860e3a0 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -35,7 +35,7 @@ ast_from_value, dsl_gql, ) -from gql.utilities import get_introspection_query_ast +from gql.utilities import get_introspection_query_ast, node_tree from .schema import StarWarsSchema @@ -151,6 +151,8 @@ def test_use_variable_definition_multiple_times(ds): }""" ) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_add_variable_definitions(ds): var = 
DSLVariableDefinitions() @@ -172,6 +174,8 @@ def test_add_variable_definitions(ds): }""" ) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_add_variable_definitions_with_default_value_enum(ds): var = DSLVariableDefinitions() @@ -216,6 +220,8 @@ def test_add_variable_definitions_with_default_value_input_object(ds): }""".strip() ) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_add_variable_definitions_in_input_object(ds): var = DSLVariableDefinitions() @@ -241,6 +247,8 @@ def test_add_variable_definitions_in_input_object(ds): }""" ) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_invalid_field_on_type_query(ds): with pytest.raises(AttributeError) as exc_info: @@ -402,6 +410,7 @@ def test_hero_name_query_result(ds, client): result = client.execute(query) expected = {"hero": {"name": "R2-D2"}} assert result == expected + assert node_tree(query) == node_tree(gql(print_ast(query))) def test_arg_serializer_list(ds, client): @@ -421,6 +430,7 @@ def test_arg_serializer_list(ds, client): ] } assert result == expected + assert node_tree(query) == node_tree(gql(print_ast(query))) def test_arg_serializer_enum(ds, client): @@ -428,6 +438,7 @@ def test_arg_serializer_enum(ds, client): result = client.execute(query) expected = {"hero": {"name": "Luke Skywalker"}} assert result == expected + assert node_tree(query) == node_tree(gql(print_ast(query))) def test_create_review_mutation_result(ds, client): @@ -442,6 +453,7 @@ def test_create_review_mutation_result(ds, client): result = client.execute(query) expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} assert result == expected + assert node_tree(query) == node_tree(gql(print_ast(query))) def test_subscription(ds): @@ -463,6 +475,8 @@ def test_subscription(ds): }""" ) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_field_does_not_exit_in_type(ds): with pytest.raises( @@ -502,6 +516,7 @@ def 
test_multiple_root_fields(ds, client): "hero_of_episode_5": {"name": "Luke Skywalker"}, } assert result == expected + assert node_tree(query) == node_tree(gql(print_ast(query))) def test_root_fields_aliased(ds, client): @@ -517,6 +532,7 @@ def test_root_fields_aliased(ds, client): "hero_of_episode_5": {"name": "Luke Skywalker"}, } assert result == expected + assert node_tree(query) == node_tree(gql(print_ast(query))) def test_operation_name(ds): @@ -535,6 +551,8 @@ def test_operation_name(ds): }""" ) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_multiple_operations(ds): query = dsl_gql( @@ -565,6 +583,8 @@ def test_multiple_operations(ds): }""" ) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_inline_fragments(ds): query = """hero(episode: JEDI) { @@ -635,6 +655,7 @@ def test_fragments(ds): print(print_ast(document)) assert query == print_ast(document) + assert node_tree(document) == node_tree(gql(print_ast(document))) def test_fragment_without_type_condition_error(ds): @@ -731,6 +752,7 @@ def test_dsl_nested_query_with_fragment(ds): print(print_ast(document)) assert query == print_ast(document) + assert node_tree(document) == node_tree(gql(print_ast(document))) # Same thing, but incrementaly @@ -756,6 +778,7 @@ def test_dsl_nested_query_with_fragment(ds): print(print_ast(document)) assert query == print_ast(document) + assert node_tree(document) == node_tree(gql(print_ast(document))) def test_dsl_query_all_fields_should_be_instances_of_DSLField(): @@ -808,6 +831,8 @@ def test_dsl_root_type_not_default(): "Invalid field for : " ) in str(excinfo.value) + assert node_tree(query) == node_tree(gql(print_ast(query))) + def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( @@ -967,6 +992,9 @@ def test_get_introspection_query_ast(option): ) assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert node_tree(dsl_introspection_query) == node_tree( + 
gql(print_ast(dsl_introspection_query)) + ) def test_typename_aliased(ds): @@ -986,3 +1014,119 @@ def test_typename_aliased(ds): ds.Character.name, DSLMetaField("__typename").alias("typenameField") ) assert query == str(query_dsl) + + +def test_node_tree_with_loc(ds): + query = """query GetHeroName { + hero { + name + } +}""".strip() + + document = gql(query) + + node_tree_result = """ +DocumentNode + loc: + Location + + definitions: + OperationDefinitionNode + loc: + Location + + name: + NameNode + loc: + Location + + value: + 'GetHeroName' + directives: + empty tuple + variable_definitions: + empty tuple + selection_set: + SelectionSetNode + loc: + Location + + selections: + FieldNode + loc: + Location + + directives: + empty tuple + alias: + None + name: + NameNode + loc: + Location + + value: + 'hero' + arguments: + empty tuple + nullability_assertion: + None + selection_set: + SelectionSetNode + loc: + Location + + selections: + FieldNode + loc: + Location + + directives: + empty tuple + alias: + None + name: + NameNode + loc: + Location + + value: + 'name' + arguments: + empty tuple + nullability_assertion: + None + selection_set: + None + operation: + +""".strip() + + assert node_tree(document, ignore_loc=False) == node_tree_result + + +def test_legacy_fragment_with_variables(ds): + var = DSLVariableDefinitions() + + hero_fragment = ( + DSLFragment("heroFragment") + .on(ds.Query) + .select( + ds.Query.hero.args(episode=var.episode).select(ds.Character.name), + ) + ) + + print(hero_fragment) + + hero_fragment.variable_definitions = var + + query = dsl_gql(hero_fragment) + + expected = """ +fragment heroFragment($episode: Episode) on Query { + hero(episode: $episode) { + name + } +} +""".strip() + assert print_ast(query) == expected From 528636a1cb4c9c51175dd2364006fbad6d6e6588 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 15 Dec 2023 15:10:16 +0100 Subject: [PATCH 135/239] Feature httpx transport working with trio (#455) --- 
docs/code_examples/httpx_async_trio.py | 34 ++++++++++++++++++++++++++ gql/client.py | 9 +++---- setup.py | 1 + 3 files changed, 39 insertions(+), 5 deletions(-) create mode 100644 docs/code_examples/httpx_async_trio.py diff --git a/docs/code_examples/httpx_async_trio.py b/docs/code_examples/httpx_async_trio.py new file mode 100644 index 00000000..058b952b --- /dev/null +++ b/docs/code_examples/httpx_async_trio.py @@ -0,0 +1,34 @@ +import trio + +from gql import Client, gql +from gql.transport.httpx import HTTPXAsyncTransport + + +async def main(): + + transport = HTTPXAsyncTransport(url="https://countries.trevorblades.com/graphql") + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with Client( + transport=transport, + fetch_schema_from_transport=True, + ) as session: + + # Execute single query + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + result = await session.execute(query) + print(result) + + +trio.run(main) diff --git a/gql/client.py b/gql/client.py index 5c1edffa..a79d4b72 100644 --- a/gql/client.py +++ b/gql/client.py @@ -22,6 +22,7 @@ ) import backoff +from anyio import fail_after from graphql import ( DocumentNode, ExecutionResult, @@ -1532,15 +1533,13 @@ async def _execute( ) # Execute the query with the transport with a timeout - result = await asyncio.wait_for( - self.transport.execute( + with fail_after(self.client.execute_timeout): + result = await self.transport.execute( document, variable_values=variable_values, operation_name=operation_name, **kwargs, - ), - self.client.execute_timeout, - ) + ) # Unserialize the result if requested if self.client.schema: diff --git a/setup.py b/setup.py index eb215b53..773aacc5 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ "graphql-core>=3.3.0a3,<3.4", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", + 
"anyio>=3.0,<5", ] console_scripts = [ From 039236c0eca09e44d49e9ccc438f8310da514fd7 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 15 Dec 2023 15:11:17 +0100 Subject: [PATCH 136/239] Bump version number to 3.5.0b9 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index fdaa43c1..c6cc6fbf 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b8" +__version__ = "3.5.0b9" From c23d3e037b63aa11d9e67d30aecff17ba1921e30 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 3 Jan 2024 01:17:31 +0100 Subject: [PATCH 137/239] Fix online tests using the countries.trevorblades.com backend (#459) * Remove http online tests - only https is supported on backend now * Skip online websockets tests - backend does not support it anymore * Skip/remove batching online tests as backend does not support it anymore * Remove 2 flaky online tests --- tests/test_aiohttp_online.py | 7 ++--- tests/test_client.py | 56 ++++++++-------------------------- tests/test_http_async_sync.py | 23 ++++++-------- tests/test_httpx_online.py | 7 ++--- tests/test_requests_batch.py | 14 +++++++-- tests/test_websocket_online.py | 11 +++++++ 6 files changed, 49 insertions(+), 69 deletions(-) diff --git a/tests/test_aiohttp_online.py b/tests/test_aiohttp_online.py index 53c246ea..39b8a9d2 100644 --- a/tests/test_aiohttp_online.py +++ b/tests/test_aiohttp_online.py @@ -11,13 +11,12 @@ @pytest.mark.aiohttp @pytest.mark.online @pytest.mark.asyncio -@pytest.mark.parametrize("protocol", ["http", "https"]) -async def test_aiohttp_simple_query(event_loop, protocol): +async def test_aiohttp_simple_query(event_loop): from gql.transport.aiohttp import AIOHTTPTransport - # Create http or https url - url = f"{protocol}://countries.trevorblades.com/graphql" + # Create https url + url = "https://countries.trevorblades.com/graphql" # Get transport 
sample_transport = AIOHTTPTransport(url=url) diff --git a/tests/test_client.py b/tests/test_client.py index 2fb333a9..ada129c6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -46,7 +46,6 @@ def execute(self): @pytest.mark.aiohttp -@pytest.mark.asyncio def test_request_async_execute_batch_not_implemented_yet(): from gql.transport.aiohttp import AIOHTTPTransport @@ -145,32 +144,13 @@ def test_execute_result_error(): client.execute(failing_query) assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value) + """ + Batching is not supported anymore on countries backend + with pytest.raises(TransportQueryError) as exc_info: client.execute_batch([GraphQLRequest(document=failing_query)]) assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value) - - -@pytest.mark.online -@pytest.mark.requests -def test_http_transport_raise_for_status_error(http_transport_query): - from gql.transport.requests import RequestsHTTPTransport - - with Client( - transport=RequestsHTTPTransport( - url="https://countries.trevorblades.com/", - use_json=False, - headers={"Content-type": "application/json"}, - ) - ) as client: - with pytest.raises(Exception) as exc_info: - client.execute(http_transport_query) - - assert "400 Client Error: Bad Request for url" in str(exc_info.value) - - with pytest.raises(Exception) as exc_info: - client.execute_batch([GraphQLRequest(document=http_transport_query)]) - - assert "400 Client Error: Bad Request for url" in str(exc_info.value) + """ @pytest.mark.online @@ -192,6 +172,9 @@ def test_http_transport_verify_error(http_transport_query): record[0].message ) + """ + Batching is not supported anymore on countries backend + with pytest.warns(Warning) as record: client.execute_batch([GraphQLRequest(document=http_transport_query)]) @@ -199,6 +182,7 @@ def test_http_transport_verify_error(http_transport_query): assert "Unverified HTTPS request is being made to host" in str( 
record[0].message ) + """ @pytest.mark.online @@ -215,28 +199,12 @@ def test_http_transport_specify_method_valid(http_transport_query): result = client.execute(http_transport_query) assert result is not None + """ + Batching is not supported anymore on countries backend + result = client.execute_batch([GraphQLRequest(document=http_transport_query)]) assert result is not None - - -@pytest.mark.online -@pytest.mark.requests -def test_http_transport_specify_method_invalid(http_transport_query): - from gql.transport.requests import RequestsHTTPTransport - - with Client( - transport=RequestsHTTPTransport( - url="https://countries.trevorblades.com/", - method="GET", - ) - ) as client: - with pytest.raises(Exception) as exc_info: - client.execute(http_transport_query) - assert "400 Client Error: Bad Request for url" in str(exc_info.value) - - with pytest.raises(Exception) as exc_info: - client.execute_batch([GraphQLRequest(document=http_transport_query)]) - assert "400 Client Error: Bad Request for url" in str(exc_info.value) + """ def test_gql(): diff --git a/tests/test_http_async_sync.py b/tests/test_http_async_sync.py index a086d442..19b6cfa2 100644 --- a/tests/test_http_async_sync.py +++ b/tests/test_http_async_sync.py @@ -6,16 +6,13 @@ @pytest.mark.aiohttp @pytest.mark.online @pytest.mark.asyncio -@pytest.mark.parametrize("protocol", ["http", "https"]) @pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) -async def test_async_client_async_transport( - event_loop, protocol, fetch_schema_from_transport -): +async def test_async_client_async_transport(event_loop, fetch_schema_from_transport): from gql.transport.aiohttp import AIOHTTPTransport - # Create http or https url - url = f"{protocol}://countries.trevorblades.com/graphql" + # Create https url + url = "https://countries.trevorblades.com/graphql" # Get async transport sample_transport = AIOHTTPTransport(url=url) 
@@ -76,14 +73,13 @@ async def test_async_client_sync_transport(event_loop, fetch_schema_from_transpo @pytest.mark.aiohttp @pytest.mark.online -@pytest.mark.parametrize("protocol", ["http", "https"]) @pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) -def test_sync_client_async_transport(protocol, fetch_schema_from_transport): +def test_sync_client_async_transport(fetch_schema_from_transport): from gql.transport.aiohttp import AIOHTTPTransport - # Create http or https url - url = f"{protocol}://countries.trevorblades.com/graphql" + # Create https url + url = "https://countries.trevorblades.com/graphql" # Get async transport sample_transport = AIOHTTPTransport(url=url) @@ -120,14 +116,13 @@ def test_sync_client_async_transport(protocol, fetch_schema_from_transport): @pytest.mark.requests @pytest.mark.online -@pytest.mark.parametrize("protocol", ["http", "https"]) @pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) -def test_sync_client_sync_transport(protocol, fetch_schema_from_transport): +def test_sync_client_sync_transport(fetch_schema_from_transport): from gql.transport.requests import RequestsHTTPTransport - # Create http or https url - url = f"{protocol}://countries.trevorblades.com/graphql" + # Create https url + url = "https://countries.trevorblades.com/graphql" # Get sync transport sample_transport = RequestsHTTPTransport(url=url, use_json=True) diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py index ee08e2b1..23d28dcc 100644 --- a/tests/test_httpx_online.py +++ b/tests/test_httpx_online.py @@ -11,13 +11,12 @@ @pytest.mark.httpx @pytest.mark.online @pytest.mark.asyncio -@pytest.mark.parametrize("protocol", ["http", "https"]) -async def test_httpx_simple_query(event_loop, protocol): +async def test_httpx_simple_query(event_loop): from gql.transport.httpx import HTTPXAsyncTransport - # Create http or https url - url = 
f"{protocol}://countries.trevorblades.com/graphql" + # Create https url + url = "https://countries.trevorblades.com/graphql" # Get transport sample_transport = HTTPXAsyncTransport(url=url) diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 1f922db7..4d8bf27e 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -540,15 +540,21 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +ONLINE_URL = "https://countries.trevorblades.com/" + +skip_reason = "backend does not support batching anymore..." + + @pytest.mark.online @pytest.mark.requests +@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto(): from threading import Thread from gql.transport.requests import RequestsHTTPTransport client = Client( - transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + transport=RequestsHTTPTransport(url=ONLINE_URL), batch_interval=0.01, batch_max=3, ) @@ -607,12 +613,13 @@ def get_continent_name(session, continent_code): @pytest.mark.online @pytest.mark.requests +@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto_execute_future(): from gql.transport.requests import RequestsHTTPTransport client = Client( - transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + transport=RequestsHTTPTransport(url=ONLINE_URL), batch_interval=0.01, batch_max=3, ) @@ -644,12 +651,13 @@ def test_requests_sync_batch_auto_execute_future(): @pytest.mark.online @pytest.mark.requests +@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_manual(): from gql.transport.requests import RequestsHTTPTransport client = Client( - transport=RequestsHTTPTransport(url="https://countries.trevorblades.com/"), + 
transport=RequestsHTTPTransport(url=ONLINE_URL), ) query = gql( diff --git a/tests/test_websocket_online.py b/tests/test_websocket_online.py index 7aa869a9..b5fca837 100644 --- a/tests/test_websocket_online.py +++ b/tests/test_websocket_online.py @@ -15,8 +15,14 @@ logging.basicConfig(level=logging.INFO) +skip_reason = ( + "backend does not support websockets anymore: " + "https://github.com/trevorblades/countries/issues/42" +) + @pytest.mark.online +@pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def test_websocket_simple_query(): from gql.transport.websockets import WebsocketsTransport @@ -57,6 +63,7 @@ async def test_websocket_simple_query(): @pytest.mark.online +@pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def test_websocket_invalid_query(): from gql.transport.websockets import WebsocketsTransport @@ -86,6 +93,7 @@ async def test_websocket_invalid_query(): @pytest.mark.online +@pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def test_websocket_sending_invalid_data(): from gql.transport.websockets import WebsocketsTransport @@ -121,6 +129,7 @@ async def test_websocket_sending_invalid_data(): @pytest.mark.online +@pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def test_websocket_sending_invalid_payload(): from gql.transport.websockets import WebsocketsTransport @@ -143,6 +152,7 @@ async def test_websocket_sending_invalid_payload(): @pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +@pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def test_websocket_sending_invalid_data_while_other_query_is_running(): from gql.transport.websockets import WebsocketsTransport @@ -194,6 +204,7 @@ async def query_task2(): @pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +@pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def 
test_websocket_two_queries_in_parallel_using_two_tasks(): from gql.transport.websockets import WebsocketsTransport From d14d1f4a543c3af5820a2a49d38310632abc432a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 3 Jan 2024 15:15:56 +0100 Subject: [PATCH 138/239] Modify tests to work with multiple versions of graphql-core (#460) * Tests modified to work with multiple graphql-core versions regarding braces spaces * Modify gql to work before and after graphql-core 3.3.0a3 subscribe changes --- gql/transport/local_schema.py | 11 +++- tests/conftest.py | 13 ++++ tests/custom_scalars/test_json.py | 10 ++-- tests/starwars/test_dsl.py | 92 ++++++++++++++++++++++++++--- tests/starwars/test_subscription.py | 22 +++++-- tests/test_aiohttp.py | 24 ++++---- tests/test_httpx.py | 25 ++++---- tests/test_httpx_async.py | 22 +++---- tests/test_requests.py | 25 ++++---- 9 files changed, 180 insertions(+), 64 deletions(-) diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index b2423346..04ed4ff1 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,3 +1,4 @@ +import asyncio from inspect import isawaitable from typing import AsyncGenerator, Awaitable, cast @@ -48,6 +49,12 @@ async def execute( return execution_result + @staticmethod + async def _await_if_necessary(obj): + """This method is necessary to work with + graphql-core versions < and >= 3.3.0a3""" + return await obj if asyncio.iscoroutine(obj) else obj + async def subscribe( self, document: DocumentNode, @@ -59,7 +66,9 @@ async def subscribe( The results are sent as an ExecutionResult object """ - subscribe_result = subscribe(self.schema, document, *args, **kwargs) + subscribe_result = await self._await_if_necessary( + subscribe(self.schema, document, *args, **kwargs) + ) if isinstance(subscribe_result, ExecutionResult): yield subscribe_result diff --git a/tests/conftest.py b/tests/conftest.py index 30c0d6f0..6a37a5d3 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -3,6 +3,7 @@ import logging import os import pathlib +import re import ssl import sys import tempfile @@ -506,3 +507,15 @@ async def run_sync_test_inner(event_loop, server, test_function): "tests.fixtures.aws.fake_session", "tests.fixtures.aws.fake_signer", ] + + +def strip_braces_spaces(s): + """Allow to ignore differences in graphql-core syntax between versions""" + + # Strip spaces after starting braces + strip_front = s.replace("{ ", "{") + + # Strip spaces before closing braces only if one space is present + strip_back = re.sub(r"([^\s]) }", r"\1}", strip_front) + + return strip_back diff --git a/tests/custom_scalars/test_json.py b/tests/custom_scalars/test_json.py index 6276b408..d3eae3b8 100644 --- a/tests/custom_scalars/test_json.py +++ b/tests/custom_scalars/test_json.py @@ -18,6 +18,8 @@ from gql import Client, gql from gql.dsl import DSLSchema +from ..conftest import strip_braces_spaces + # Marking all tests in this file with the aiohttp marker pytestmark = pytest.mark.aiohttp @@ -201,9 +203,9 @@ def test_json_value_input_in_dsl_argument(): print(str(query)) assert ( - str(query) + strip_braces_spaces(str(query)) == """addPlayer( - player: { name: "Tim", level: 0, is_connected: false, score: 5, friends: ["Lea"] } + player: {name: "Tim", level: 0, is_connected: false, score: 5, friends: ["Lea"]} )""" ) @@ -235,8 +237,8 @@ def test_json_value_input_with_none_list_in_dsl_argument(): print(str(query)) assert ( - str(query) + strip_braces_spaces(str(query)) == """addPlayer( - player: { name: "Bob", level: 9001, is_connected: true, score: 666.66, friends: null } + player: {name: "Bob", level: 9001, is_connected: true, score: 666.66, friends: null} )""" ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 4860e3a0..2aadf92f 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -37,6 +37,7 @@ ) from gql.utilities import get_introspection_query_ast, node_tree +from ..conftest import 
strip_braces_spaces from .schema import StarWarsSchema @@ -210,9 +211,9 @@ def test_add_variable_definitions_with_default_value_input_object(ds): query = dsl_gql(op) assert ( - print_ast(query) + strip_braces_spaces(print_ast(query)) == """ -mutation ($review: ReviewInput = { stars: 5, commentary: "Wow!" }, $episode: Episode) { +mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { createReview(review: $review, episode: $episode) { stars commentary @@ -235,10 +236,10 @@ def test_add_variable_definitions_in_input_object(ds): query = dsl_gql(op) assert ( - print_ast(query) + strip_braces_spaces(print_ast(query)) == """mutation ($stars: Int, $commentary: String, $episode: Episode) { createReview( - review: { stars: $stars, commentary: $commentary } + review: {stars: $stars, commentary: $commentary} episode: $episode ) { stars @@ -565,7 +566,7 @@ def test_multiple_operations(ds): ) assert ( - print_ast(query) + strip_braces_spaces(print_ast(query)) == """query GetHeroName { hero { name @@ -575,7 +576,7 @@ def test_multiple_operations(ds): mutation CreateReviewMutation { createReview( episode: JEDI - review: { stars: 5, commentary: "This is a great movie!" 
} + review: {stars: 5, commentary: "This is a great movie!"} ) { stars commentary @@ -1102,7 +1103,84 @@ def test_node_tree_with_loc(ds): """.strip() - assert node_tree(document, ignore_loc=False) == node_tree_result + node_tree_result_stable = """ +DocumentNode + loc: + Location + + definitions: + OperationDefinitionNode + loc: + Location + + name: + NameNode + loc: + Location + + value: + 'GetHeroName' + directives: + empty tuple + variable_definitions: + empty tuple + selection_set: + SelectionSetNode + loc: + Location + + selections: + FieldNode + loc: + Location + + directives: + empty tuple + alias: + None + name: + NameNode + loc: + Location + + value: + 'hero' + arguments: + empty tuple + selection_set: + SelectionSetNode + loc: + Location + + selections: + FieldNode + loc: + Location + + directives: + empty tuple + alias: + None + name: + NameNode + loc: + Location + + value: + 'name' + arguments: + empty tuple + selection_set: + None + operation: + +""".strip() + + try: + assert node_tree(document, ignore_loc=False) == node_tree_result + except AssertionError: + # graphql-core version 3.2.3 + assert node_tree(document, ignore_loc=False) == node_tree_result_stable def test_legacy_fragment_with_variables(ds): diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index c5a50514..0f412acc 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from graphql import ExecutionResult, GraphQLError, subscribe @@ -17,6 +19,14 @@ """ +async def await_if_coroutine(obj): + """Function to make tests work for graphql-core versions before and after 3.3.0a3""" + if asyncio.iscoroutine(obj): + return await obj + + return obj + + @pytest.mark.asyncio async def test_subscription_support(): # reset review data for this test @@ -30,7 +40,9 @@ async def test_subscription_support(): params = {"ep": "JEDI"} expected = [{**review, "episode": "JEDI"} for review in 
reviews[6]] - ai = subscribe(StarWarsSchema, subs, variable_values=params) + ai = await await_if_coroutine( + subscribe(StarWarsSchema, subs, variable_values=params) + ) result = [result.data["reviewAdded"] async for result in ai] @@ -53,8 +65,8 @@ async def test_subscription_support_using_client(): async with Client(schema=StarWarsSchema) as session: results = [ result["reviewAdded"] - async for result in session.subscribe( - subs, variable_values=params, parse_result=False + async for result in await await_if_coroutine( + session.subscribe(subs, variable_values=params, parse_result=False) ) ] @@ -80,8 +92,8 @@ async def test_subscription_support_using_client_invalid_field(): # We subscribe directly from the transport to avoid local validation results = [ result - async for result in session.transport.subscribe( - subs, variable_values=params + async for result in await await_if_coroutine( + session.transport.subscribe(subs, variable_values=params) ) ] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index a9b3bda6..09259e51 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -14,7 +14,7 @@ TransportServerError, ) -from .conftest import TemporaryFile +from .conftest import TemporaryFile, strip_braces_spaces query1_str = """ query getContinents { @@ -588,15 +588,15 @@ def test_code(): file_upload_mutation_1 = """ mutation($file: Upload!) { - uploadFile(input:{ other_var:$other_var, file:$file }) { + uploadFile(input:{other_var:$other_var, file:$file}) { success } } """ file_upload_mutation_1_operations = ( - '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: { other_var: ' - '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' + '{"query": "mutation ($file: Upload!) 
{\\n uploadFile(input: {other_var: ' + '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -617,7 +617,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -679,7 +679,7 @@ async def single_upload_handler_with_content_type(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -790,7 +790,7 @@ async def binary_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -931,7 +931,7 @@ async def file_sender(file_name): file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' - 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -955,7 +955,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_2_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations field_1 = await reader.next() assert field_1.name == "map" @@ -1019,7 +1019,7 @@ async def handler(request): file_upload_mutation_3 = """ mutation($files: [Upload!]!) { - uploadFiles(input:{ files:$files }) { + uploadFiles(input:{files:$files}) { success } } @@ -1027,7 +1027,7 @@ async def handler(request): file_upload_mutation_3_operations = ( '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(' - "input: { files: $files })" + "input: {files: $files})" ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) @@ -1046,7 +1046,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_3_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations field_1 = await reader.next() assert field_1.name == "map" diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 56a984a4..af12f717 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -10,7 +10,8 @@ TransportQueryError, TransportServerError, ) -from tests.conftest import TemporaryFile + +from .conftest import TemporaryFile, strip_braces_spaces # Marking all tests in this file with the httpx marker pytestmark = pytest.mark.httpx @@ -397,15 +398,15 @@ def test_code(): file_upload_mutation_1 = """ mutation($file: Upload!) 
{ - uploadFile(input:{ other_var:$other_var, file:$file }) { + uploadFile(input:{other_var:$other_var, file:$file}) { success } } """ file_upload_mutation_1_operations = ( - '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: { other_var: ' - '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' + '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' + '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -431,7 +432,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -493,7 +494,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -563,7 +564,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -627,7 +628,7 @@ async def binary_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -677,7 +678,7 @@ def test_code(): 
file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -710,7 +711,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_2_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations field_1 = await reader.next() assert field_1.name == "map" @@ -775,7 +776,7 @@ def test_code(): file_upload_mutation_3_operations = ( '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' - "(input: { files: $files })" + "(input: {files: $files})" ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) @@ -812,7 +813,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_3_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations field_1 = await reader.next() assert field_1.name == "map" diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 362875de..e5be73ec 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -14,7 +14,7 @@ TransportServerError, ) -from .conftest import TemporaryFile, get_localhost_ssl_context +from .conftest import TemporaryFile, get_localhost_ssl_context, strip_braces_spaces query1_str = """ query getContinents { @@ -601,15 +601,15 @@ def test_code(): file_upload_mutation_1 = """ mutation($file: Upload!) { - uploadFile(input:{ other_var:$other_var, file:$file }) { + uploadFile(input:{other_var:$other_var, file:$file}) { success } } """ file_upload_mutation_1_operations = ( - '{"query": "mutation ($file: Upload!) 
{\\n uploadFile(input: { other_var: ' - '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' + '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' + '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -630,7 +630,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -737,7 +737,7 @@ async def binary_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -801,7 +801,7 @@ async def test_httpx_binary_file_upload(event_loop, aiohttp_server): file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -826,7 +826,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_2_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations field_1 = await reader.next() assert field_1.name == "map" @@ -890,7 +890,7 @@ async def handler(request): file_upload_mutation_3 = """ mutation($files: [Upload!]!) 
{ - uploadFiles(input:{ files:$files }) { + uploadFiles(input:{files:$files}) { success } } @@ -898,7 +898,7 @@ async def handler(request): file_upload_mutation_3_operations = ( '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(' - "input: { files: $files })" + "input: {files: $files})" ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) @@ -918,7 +918,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_3_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations field_1 = await reader.next() assert field_1.name == "map" diff --git a/tests/test_requests.py b/tests/test_requests.py index a5ff0d8b..639d2b73 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -10,7 +10,8 @@ TransportQueryError, TransportServerError, ) -from tests.conftest import TemporaryFile + +from .conftest import TemporaryFile, strip_braces_spaces # Marking all tests in this file with the requests marker pytestmark = pytest.mark.requests @@ -399,15 +400,15 @@ def test_code(): file_upload_mutation_1 = """ mutation($file: Upload!) { - uploadFile(input:{ other_var:$other_var, file:$file }) { + uploadFile(input:{other_var:$other_var, file:$file}) { success } } """ file_upload_mutation_1_operations = ( - '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: { other_var: ' - '$other_var, file: $file }) {\\n success\\n }\\n}", "variables": ' + '{"query": "mutation ($file: Upload!) 
{\\n uploadFile(input: {other_var: ' + '$other_var, file: $file}) {\\n success\\n }\\n}", "variables": ' '{"file": null, "other_var": 42}}' ) @@ -433,7 +434,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -495,7 +496,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -565,7 +566,7 @@ async def single_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -629,7 +630,7 @@ async def binary_upload_handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_1_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations field_1 = await reader.next() assert field_1.name == "map" @@ -679,7 +680,7 @@ def test_code(): file_upload_mutation_2_operations = ( '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' - 'uploadFile(input: { file1: $file, file2: $file }) {\\n success\\n }\\n}", ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' '"variables": {"file1": null, "file2": null}}' ) @@ -714,7 +715,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_2_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations field_1 = await reader.next() assert field_1.name == "map" @@ -779,7 +780,7 @@ def test_code(): file_upload_mutation_3_operations = ( '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' - "(input: { files: $files })" + "(input: {files: $files})" ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' ) @@ -816,7 +817,7 @@ async def handler(request): field_0 = await reader.next() assert field_0.name == "operations" field_0_text = await field_0.text() - assert field_0_text == file_upload_mutation_3_operations + assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations field_1 = await reader.next() assert field_1.name == "map" From aa1ffadc49f94ba0e1dacd2e75678cbca3ed02a4 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 3 Jan 2024 15:25:30 +0100 Subject: [PATCH 139/239] Revert graphql-core to stable versions 3.2.x --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 773aacc5..233900d2 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.3.0a3,<3.4", + "graphql-core>=3.2,<3.3", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", From 96041132296849264e31601b5a30f54aa0f94185 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 3 Jan 2024 15:32:05 +0100 Subject: [PATCH 140/239] Bump version number to 3.5.0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/gql/__version__.py b/gql/__version__.py index c6cc6fbf..dcbfb52f 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b9" +__version__ = "3.5.0" From e6a7873dea7d23855ddd00c96f2125974241fe1e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 3 Jan 2024 15:55:28 +0100 Subject: [PATCH 141/239] Bump version number to 3.6.0b0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index c6cc6fbf..59890c93 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0b9" +__version__ = "3.6.0b0" From 3a641b13098e956125382822f274e7f4bc222d09 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 24 Jan 2024 19:25:33 +0100 Subject: [PATCH 142/239] Empty commit to try to fix codecov From a3f0bd93d21cf0f7b7219d88173738d38df2c42e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 8 Feb 2024 20:51:17 +0100 Subject: [PATCH 143/239] Adding json_deserialize parameter to aiohttp and httpx transports (#465) --- gql/client.py | 18 +++++++------- gql/transport/aiohttp.py | 6 ++++- gql/transport/httpx.py | 6 ++++- tests/test_aiohttp.py | 50 ++++++++++++++++++++++++++++++++++++++ tests/test_httpx_async.py | 51 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 120 insertions(+), 11 deletions(-) diff --git a/gql/client.py b/gql/client.py index a79d4b72..0d9e36c7 100644 --- a/gql/client.py +++ b/gql/client.py @@ -106,7 +106,7 @@ def __init__( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. :param parse_results: Whether gql will try to parse the serialized output - sent by the backend. Can be used to unserialize custom scalars or enums. + sent by the backend. Can be used to deserialize custom scalars or enums. :param batch_interval: Time to wait in seconds for batching requests together. Batching is disabled (by default) if 0. 
:param batch_max: Maximum number of requests in a single batch. @@ -892,7 +892,7 @@ def _execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. The extra arguments are passed to the transport execute method.""" @@ -1006,7 +1006,7 @@ def execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -1057,7 +1057,7 @@ def _execute_batch( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param validate_document: Whether we still need to validate the document. @@ -1151,7 +1151,7 @@ def execute_batch( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. 
:param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -1333,7 +1333,7 @@ async def _subscribe( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. The extra arguments are passed to the transport subscribe method.""" @@ -1454,7 +1454,7 @@ async def subscribe( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param get_execution_result: yield the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -1511,7 +1511,7 @@ async def _execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. The extra arguments are passed to the transport execute method.""" @@ -1617,7 +1617,7 @@ async def execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. 
+ :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 60f42c94..be22ce9c 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -50,6 +50,7 @@ def __init__( timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, json_serialize: Callable = json.dumps, + json_deserialize: Callable = json.loads, client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -64,6 +65,8 @@ def __init__( to close properly :param json_serialize: Json serializer callable. By default json.dumps() function + :param json_deserialize: Json deserializer callable. + By default json.loads() function :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ @@ -81,6 +84,7 @@ def __init__( self.session: Optional[aiohttp.ClientSession] = None self.response_headers: Optional[CIMultiDictProxy[str]] self.json_serialize: Callable = json_serialize + self.json_deserialize: Callable = json_deserialize async def connect(self) -> None: """Coroutine which will create an aiohttp ClientSession() as self.session. 
@@ -328,7 +332,7 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): ) try: - result = await resp.json(content_type=None) + result = await resp.json(loads=self.json_deserialize, content_type=None) if log.isEnabledFor(logging.INFO): result_text = await resp.text() diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index cfc25dc9..811601b8 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -38,6 +38,7 @@ def __init__( self, url: Union[str, httpx.URL], json_serialize: Callable = json.dumps, + json_deserialize: Callable = json.loads, **kwargs, ): """Initialize the transport with the given httpx parameters. @@ -45,10 +46,13 @@ def __init__( :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'. :param json_serialize: Json serializer callable. By default json.dumps() function. + :param json_deserialize: Json deserializer callable. + By default json.loads() function. :param kwargs: Extra args passed to the `httpx` client. 
""" self.url = url self.json_serialize = json_serialize + self.json_deserialize = json_deserialize self.kwargs = kwargs def _prepare_request( @@ -145,7 +149,7 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: log.debug("<<< %s", response.text) try: - result: Dict[str, Any] = response.json() + result: Dict[str, Any] = self.json_deserialize(response.content) except Exception: self._raise_response_error(response, "Not a JSON answer") diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 09259e51..b16964d0 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1511,6 +1511,56 @@ async def handler(request): assert expected_log in caplog.text +query_float_str = """ + query getPi { + pi + } +""" + +query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}' + +query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}' + + +@pytest.mark.asyncio +async def test_aiohttp_json_deserializer(event_loop, aiohttp_server): + from aiohttp import web + from decimal import Decimal + from functools import partial + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query_float_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + json_loads = partial(json.loads, parse_float=Decimal) + + transport = AIOHTTPTransport( + url=url, + timeout=10, + json_deserialize=json_loads, + ) + + async with Client(transport=transport) as session: + + query = gql(query_float_str) + + # Execute query asynchronously + result = await session.execute(query) + + pi = result["pi"] + + assert pi == Decimal("3.141592653589793238462643383279502884197") + + @pytest.mark.asyncio async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): from aiohttp import web, TCPConnector diff --git 
a/tests/test_httpx_async.py b/tests/test_httpx_async.py index e5be73ec..3665f5d8 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1389,3 +1389,54 @@ async def handler(request): # Checking that there is no space after the colon in the log expected_log = '"query":"query getContinents' assert expected_log in caplog.text + + +query_float_str = """ + query getPi { + pi + } +""" + +query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}' + +query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_json_deserializer(event_loop, aiohttp_server): + from aiohttp import web + from decimal import Decimal + from functools import partial + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query_float_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + json_loads = partial(json.loads, parse_float=Decimal) + + transport = HTTPXAsyncTransport( + url=url, + timeout=10, + json_deserialize=json_loads, + ) + + async with Client(transport=transport) as session: + + query = gql(query_float_str) + + # Execute query asynchronously + result = await session.execute(query) + + pi = result["pi"] + + assert pi == Decimal("3.141592653589793238462643383279502884197") From e5c7c8f3d498fe3397caae370f8ad35cd5140962 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 8 Feb 2024 21:33:24 +0100 Subject: [PATCH 144/239] Adding json_serialize and json_deserialize to requests transport (#466) --- gql/transport/requests.py | 23 ++++++--- tests/test_requests.py | 106 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 6 deletions(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 1e464104..0c6eb3fc 100644 
--- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,7 +1,7 @@ import io import json import logging -from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union import requests from graphql import DocumentNode, ExecutionResult, print_ast @@ -47,6 +47,8 @@ def __init__( method: str = "POST", retry_backoff_factor: float = 0.1, retry_status_forcelist: Collection[int] = _default_retry_codes, + json_serialize: Callable = json.dumps, + json_deserialize: Callable = json.loads, **kwargs: Any, ): """Initialize the transport with the given request parameters. @@ -73,6 +75,10 @@ def __init__( should force a retry on. A retry is initiated if the request method is in allowed_methods and the response status code is in status_forcelist. (Default: [429, 500, 502, 503, 504]) + :param json_serialize: Json serializer callable. + By default json.dumps() function + :param json_deserialize: Json deserializer callable. + By default json.loads() function :param kwargs: Optional arguments that ``request`` takes. 
These can be seen at the `requests`_ source code or the official `docs`_ @@ -90,6 +96,8 @@ def __init__( self.method = method self.retry_backoff_factor = retry_backoff_factor self.retry_status_forcelist = retry_status_forcelist + self.json_serialize: Callable = json_serialize + self.json_deserialize: Callable = json_deserialize self.kwargs = kwargs self.session = None @@ -174,7 +182,7 @@ def execute( # type: ignore payload["variables"] = nulled_variable_values # Add the payload to the operations field - operations_str = json.dumps(payload) + operations_str = self.json_serialize(payload) log.debug("operations %s", operations_str) # Generate the file map @@ -188,7 +196,7 @@ def execute( # type: ignore file_streams = {str(i): files[path] for i, path in enumerate(files)} # Add the file map field - file_map_str = json.dumps(file_map) + file_map_str = self.json_serialize(file_map) log.debug("file_map %s", file_map_str) fields = {"operations": operations_str, "map": file_map_str} @@ -224,7 +232,7 @@ def execute( # type: ignore # Log the payload if log.isEnabledFor(logging.INFO): - log.info(">>> %s", json.dumps(payload)) + log.info(">>> %s", self.json_serialize(payload)) # Pass kwargs to requests post method post_args.update(self.kwargs) @@ -257,7 +265,10 @@ def raise_response_error(resp: requests.Response, reason: str): ) try: - result = response.json() + if self.json_deserialize == json.loads: + result = response.json() + else: + result = self.json_deserialize(response.text) if log.isEnabledFor(logging.INFO): log.info("<<< %s", response.text) @@ -396,7 +407,7 @@ def _build_batch_post_args( # Log the payload if log.isEnabledFor(logging.INFO): - log.info(">>> %s", json.dumps(post_args[data_key])) + log.info(">>> %s", self.json_serialize(post_args[data_key])) # Pass kwargs to requests post method post_args.update(self.kwargs) diff --git a/tests/test_requests.py b/tests/test_requests.py index 639d2b73..ba666243 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py 
@@ -923,3 +923,109 @@ def test_code(): assert transport.session is None await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_json_serializer( + event_loop, aiohttp_server, run_sync_test, caplog +): + import json + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + + request_text = await request.text() + print("Received on backend: " + request_text) + + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport( + url=url, + json_serialize=lambda e: json.dumps(e, separators=(",", ":")), + ) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking that there is no space after the colon in the log + expected_log = '"query":"query getContinents' + assert expected_log in caplog.text + + await run_sync_test(event_loop, server, test_code) + + +query_float_str = """ + query getPi { + pi + } +""" + +query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}' + +query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_test): + import json + from aiohttp import web + from decimal import Decimal + from functools import partial + from gql.transport.requests import RequestsHTTPTransport + + async def handler(request): + return web.Response( + text=query_float_server_answer, + content_type="application/json", + ) + + app = web.Application() + 
app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + + json_loads = partial(json.loads, parse_float=Decimal) + + transport = RequestsHTTPTransport( + url=url, + json_deserialize=json_loads, + ) + + with Client(transport=transport) as session: + + query = gql(query_float_str) + + # Execute query asynchronously + result = session.execute(query) + + pi = result["pi"] + + assert pi == Decimal("3.141592653589793238462643383279502884197") + + await run_sync_test(event_loop, server, test_code) From 48bb94cc4fc0755419c9edc7e8ef1470036c193d Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 8 Feb 2024 23:44:54 +0100 Subject: [PATCH 145/239] Bump version number to 3.6.0b1 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 59890c93..372ccd70 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.6.0b0" +__version__ = "3.6.0b1" From 40f07cfaf5d88e0949c9838480b12d56402c64dd Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 6 Mar 2024 02:14:49 +0100 Subject: [PATCH 146/239] Doc Fix confusion about TransportQueryError execute retries (#473) --- docs/advanced/async_permanent_session.rst | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/advanced/async_permanent_session.rst b/docs/advanced/async_permanent_session.rst index 240d8b4f..e42010cf 100644 --- a/docs/advanced/async_permanent_session.rst +++ b/docs/advanced/async_permanent_session.rst @@ -75,7 +75,6 @@ backoff decorator to the :code:`retry_execute` argument. backoff.expo, Exception, max_tries=3, - giveup=lambda e: isinstance(e, TransportQueryError), ) session = await client.connect_async( reconnecting=True, @@ -84,6 +83,18 @@ backoff decorator to the :code:`retry_execute` argument. 
If you don't want any retry on the execute calls, you can disable the retries with :code:`retry_execute=False` +.. note:: + If you want to retry even with :code:`TransportQueryError` exceptions, + then you need to make your own backoff decorator on your own method: + + .. code-block:: python + + @backoff.on_exception(backoff.expo, + Exception, + max_tries=3) + async def execute_with_retry(session, query): + return await session.execute(query) + Subscription retries ^^^^^^^^^^^^^^^^^^^^ From 90524038c63294bc3843928c092792700ad9e7af Mon Sep 17 00:00:00 2001 From: Paul Heasley Date: Sun, 17 Mar 2024 04:21:48 +1100 Subject: [PATCH 147/239] chore(tests): Fix spelling of _source (#474) --- tests/starwars/schema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 5f9a04b4..4b672ad3 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -176,7 +176,7 @@ "provided, returns the hero of that particular episode.", ) }, - resolve=lambda _souce, _info, episode=None: get_hero_async(episode), + resolve=lambda _source, _info, episode=None: get_hero_async(episode), ), "human": GraphQLField( human_type, @@ -186,7 +186,7 @@ type_=GraphQLNonNull(GraphQLString), ) }, - resolve=lambda _souce, _info, id: get_human(id), + resolve=lambda _source, _info, id: get_human(id), ), "droid": GraphQLField( droid_type, From ba53126360086230ccba61107e0c92bea1c5ca59 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 14 Apr 2024 21:51:45 +0200 Subject: [PATCH 148/239] Fix importing DirectiveLocation directly from graphql (#477) Should fix running gql with graphql-core 3.3.0a5 --- gql/utilities/build_client_schema.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/gql/utilities/build_client_schema.py b/gql/utilities/build_client_schema.py index 048ed80d..30402868 100644 --- a/gql/utilities/build_client_schema.py +++ b/gql/utilities/build_client_schema.py @@ -1,10 +1,7 @@ -from 
graphql import GraphQLSchema, IntrospectionQuery +from graphql import DirectiveLocation, GraphQLSchema, IntrospectionQuery from graphql import build_client_schema as build_client_schema_orig from graphql.pyutils import inspect -from graphql.utilities.get_introspection_query import ( - DirectiveLocation, - IntrospectionDirective, -) +from graphql.utilities.get_introspection_query import IntrospectionDirective __all__ = ["build_client_schema"] From 23636985857affd9b35bfc895f4bafdf2dc0801c Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 14 Apr 2024 21:53:49 +0200 Subject: [PATCH 149/239] Bump version number to 3.6.0b2 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 372ccd70..dc9e18d0 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.6.0b1" +__version__ = "3.6.0b2" From 8c33e8f8684b158b434df3e486879c54768c880c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 13 Jul 2024 16:40:33 +0200 Subject: [PATCH 150/239] Bump mypy to 1.10 (#485) --- gql/transport/requests.py | 4 +++- setup.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 0c6eb3fc..fd9759ed 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -381,7 +381,9 @@ def _extract_response(self, response: requests.Response) -> Any: log.info("<<< %s", response.text) except requests.HTTPError as e: - raise TransportServerError(str(e), e.response.status_code) from e + raise TransportServerError( + str(e), e.response.status_code if e.response is not None else None + ) from e except Exception: self._raise_invalid_result(str(response.text), "Not a JSON answer") diff --git a/setup.py b/setup.py index 773aacc5..0a2fd418 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "check-manifest>=0.42,<1", "flake8==3.8.1", "isort==4.3.21", - "mypy==0.910", + "mypy==1.10", "sphinx>=5.3.0,<6", 
"sphinx_rtd_theme>=0.4,<1", "sphinx-argparse==0.2.5", From 9f932159e084fe2686febced979ca21f4b75b745 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 14 Jul 2024 18:27:01 +0200 Subject: [PATCH 151/239] Fix properly exiting the WebsocketsTransport when a connection_error is received during init (#486) --- gql/transport/websockets_base.py | 6 +++++- tests/test_websocket_query.py | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index 45c96d3e..5c7713e9 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -512,7 +512,11 @@ async def connect(self) -> None: await self._initialize() except ConnectionClosed as e: raise e - except (TransportProtocolError, asyncio.TimeoutError) as e: + except ( + TransportProtocolError, + TransportServerError, + asyncio.TimeoutError, + ) as e: await self._fail(e, clean_close=False) raise e diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index e8b7a022..d2270e7d 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -441,6 +441,9 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( await session.execute(query1) + await asyncio.sleep(1) + assert transport.websocket is None + @pytest.mark.parametrize("server", [server1_answers], indirect=True) def test_websocket_execute_sync(server): From 85605939053e37da976b7827303a821adcf59dc1 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 14 Jul 2024 18:40:15 +0200 Subject: [PATCH 152/239] Always close transport when an exception appears during the transport connect (#488) --- gql/client.py | 6 +++++- tests/test_websocket_query.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/gql/client.py b/gql/client.py index 0d9e36c7..a9a2c7e2 100644 --- a/gql/client.py +++ b/gql/client.py @@ -786,7 +786,11 @@ async def connect_async(self, reconnecting=False, **kwargs): self.session 
= ReconnectingAsyncClientSession(client=self, **kwargs) await self.session.start_connecting_task() else: - await self.transport.connect() + try: + await self.transport.connect() + except Exception as e: + await self.transport.close() + raise e self.session = AsyncClientSession(client=self) # Get schema from transport if needed diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index d2270e7d..9e6fd4ab 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -441,7 +441,6 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( await session.execute(query1) - await asyncio.sleep(1) assert transport.websocket is None From e63ed0faaae675143a494d7cff6e4b081dcd9d1b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 14 Jul 2024 18:40:52 +0200 Subject: [PATCH 153/239] Remove Python 3.7 support (#489) --- .github/workflows/tests.yml | 4 +--- setup.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 30e8289c..7588a997 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,11 +8,9 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.8"] os: [ubuntu-20.04, windows-latest] exclude: - - os: windows-latest - python-version: "3.7" - os: windows-latest python-version: "3.9" - os: windows-latest diff --git a/setup.py b/setup.py index 0a2fd418..e5dbb8ed 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,6 @@ "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", From ede1350b61bf94f98ba92a57da43936e4f2b45ee Mon Sep 17 00:00:00 2001 
From: Hanusz Leszek Date: Sun, 14 Jul 2024 18:47:08 +0200 Subject: [PATCH 154/239] Bump pytest-cov dev-dependency to 5.0.0 (#487) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e5dbb8ed..8828f8f0 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ "pytest==7.4.2", "pytest-asyncio==0.21.1", "pytest-console-scripts==1.3.1", - "pytest-cov==3.0.0", + "pytest-cov==5.0.0", "mock==4.0.2", "vcrpy==4.4.0", "aiofiles", From 00b61d51ef25394e03803fd5a4c05ef910fa6a0b Mon Sep 17 00:00:00 2001 From: tlowery-scwx <150165182+tlowery-scwx@users.noreply.github.com> Date: Thu, 25 Jul 2024 12:15:41 -0700 Subject: [PATCH 155/239] New transport AIOHTTPWebsocketsTransport (#478) --- .../code_examples/aiohttp_websockets_async.py | 50 + docs/intro.rst | 40 +- docs/modules/gql.rst | 1 + docs/modules/transport_aiohttp_websockets.rst | 7 + docs/transports/aiohttp.rst | 4 +- docs/transports/aiohttp_websockets.rst | 31 + docs/transports/async_transports.rst | 1 + gql/cli.py | 13 +- gql/transport/aiohttp_websockets.py | 1196 +++++++++++++++++ tests/conftest.py | 220 +++ tests/test_aiohttp_websocket_exceptions.py | 406 ++++++ ..._aiohttp_websocket_graphqlws_exceptions.py | 276 ++++ ...iohttp_websocket_graphqlws_subscription.py | 879 ++++++++++++ tests/test_aiohttp_websocket_query.py | 707 ++++++++++ tests/test_aiohttp_websocket_subscription.py | 809 +++++++++++ 15 files changed, 4619 insertions(+), 21 deletions(-) create mode 100644 docs/code_examples/aiohttp_websockets_async.py create mode 100644 docs/modules/transport_aiohttp_websockets.rst create mode 100644 docs/transports/aiohttp_websockets.rst create mode 100644 gql/transport/aiohttp_websockets.py create mode 100644 tests/test_aiohttp_websocket_exceptions.py create mode 100644 tests/test_aiohttp_websocket_graphqlws_exceptions.py create mode 100644 tests/test_aiohttp_websocket_graphqlws_subscription.py create mode 100644 tests/test_aiohttp_websocket_query.py create mode 100644 
tests/test_aiohttp_websocket_subscription.py diff --git a/docs/code_examples/aiohttp_websockets_async.py b/docs/code_examples/aiohttp_websockets_async.py new file mode 100644 index 00000000..69520053 --- /dev/null +++ b/docs/code_examples/aiohttp_websockets_async.py @@ -0,0 +1,50 @@ +import asyncio +import logging + +from gql import Client, gql +from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + +logging.basicConfig(level=logging.INFO) + + +async def main(): + + transport = AIOHTTPWebsocketsTransport( + url="wss://countries.trevorblades.com/graphql" + ) + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with Client( + transport=transport, + ) as session: + + # Execute single query + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + result = await session.execute(query) + print(result) + + # Request subscription + subscription = gql( + """ + subscription { + somethingChanged { + id + } + } + """ + ) + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/intro.rst b/docs/intro.rst index 8f59ed16..21de16bd 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -36,25 +36,27 @@ which needs the :code:`aiohttp` dependency, then you can install GQL with:: The corresponding between extra dependencies required and the GQL classes is: -+---------------------+----------------------------------------------------------------+ -| Extra dependencies | Classes | -+=====================+================================================================+ -| aiohttp | :ref:`AIOHTTPTransport ` | -+---------------------+----------------------------------------------------------------+ -| websockets | :ref:`WebsocketsTransport ` | -| | | -| | :ref:`PhoenixChannelWebsocketsTransport ` | -| | | -| | :ref:`AppSyncWebsocketsTransport ` | 
-+---------------------+----------------------------------------------------------------+ -| requests | :ref:`RequestsHTTPTransport ` | -+---------------------+----------------------------------------------------------------+ -| httpx | :ref:`HTTPTXTransport ` | -| | | -| | :ref:`HTTPXAsyncTransport ` | -+---------------------+----------------------------------------------------------------+ -| botocore | :ref:`AppSyncIAMAuthentication ` | -+---------------------+----------------------------------------------------------------+ ++---------------------+------------------------------------------------------------------+ +| Extra dependencies | Classes | ++=====================+==================================================================+ +| aiohttp | :ref:`AIOHTTPTransport ` | +| | | +| | :ref:`AIOHTTPWebsocketsTransport ` | ++---------------------+------------------------------------------------------------------+ +| websockets | :ref:`WebsocketsTransport ` | +| | | +| | :ref:`PhoenixChannelWebsocketsTransport ` | +| | | +| | :ref:`AppSyncWebsocketsTransport ` | ++---------------------+------------------------------------------------------------------+ +| requests | :ref:`RequestsHTTPTransport ` | ++---------------------+------------------------------------------------------------------+ +| httpx | :ref:`HTTPTXTransport ` | +| | | +| | :ref:`HTTPXAsyncTransport ` | ++---------------------+------------------------------------------------------------------+ +| botocore | :ref:`AppSyncIAMAuthentication ` | ++---------------------+------------------------------------------------------------------+ .. 
note:: diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 5f9edebe..b7c13c7c 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,6 +21,7 @@ Sub-Packages client transport transport_aiohttp + transport_aiohttp_websockets transport_appsync_auth transport_appsync_websockets transport_exceptions diff --git a/docs/modules/transport_aiohttp_websockets.rst b/docs/modules/transport_aiohttp_websockets.rst new file mode 100644 index 00000000..efa7e1bc --- /dev/null +++ b/docs/modules/transport_aiohttp_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.aiohttp_websockets +================================ + +.. currentmodule:: gql.transport.aiohttp_websockets + +.. automodule:: gql.transport.aiohttp_websockets + :member-order: bysource diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index 68b3eb99..b852108b 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -10,7 +10,9 @@ Reference: :class:`gql.transport.aiohttp.AIOHTTPTransport` .. note:: GraphQL subscriptions are not supported on the HTTP transport. - For subscriptions you should use the :ref:`websockets transport `. + For subscriptions you should use a websockets transport: + :ref:`WebsocketsTransport ` or + :ref:`AIOHTTPWebsocketsTransport `. .. literalinclude:: ../code_examples/aiohttp_async.py diff --git a/docs/transports/aiohttp_websockets.rst b/docs/transports/aiohttp_websockets.rst new file mode 100644 index 00000000..def3372e --- /dev/null +++ b/docs/transports/aiohttp_websockets.rst @@ -0,0 +1,31 @@ +.. _aiohttp_websockets_transport: + +AIOHTTPWebsocketsTransport +========================== + +The AIOHTTPWebsocketsTransport is an alternative to the :ref:`websockets_transport`, +using the `aiohttp` dependency instead of the `websockets` dependency. + +It also supports both: + + - the `Apollo websockets transport protocol`_. 
+ - the `GraphQL-ws websockets transport protocol`_ + +It will propose both subprotocols to the backend and detect the supported protocol +from the response http headers returned by the backend. + +.. note:: + For some backends (graphql-ws before `version 5.6.1`_ without backwards compatibility), it may be necessary to specify + only one subprotocol to the backend. It can be done by using + :code:`subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL]` + or :code:`subprotocols=[AIOHTTPWebsocketsTransport.APOLLO_SUBPROTOCOL]` in the transport arguments. + +This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection. + +Reference: :class:`gql.transport.aiohttp_websockets.AIOHTTPWebsocketsTransport` + +.. literalinclude:: ../code_examples/aiohttp_websockets_async.py + +.. _version 5.6.1: https://github.com/enisdenjo/graphql-ws/releases/tag/v5.6.1 +.. _Apollo websockets transport protocol: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +.. _GraphQL-ws websockets transport protocol: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md diff --git a/docs/transports/async_transports.rst b/docs/transports/async_transports.rst index 7d751df0..ba5ca136 100644 --- a/docs/transports/async_transports.rst +++ b/docs/transports/async_transports.rst @@ -12,5 +12,6 @@ Async transports are transports which are using an underlying async library. 
The aiohttp httpx_async websockets + aiohttp_websockets phoenix appsync diff --git a/gql/cli.py b/gql/cli.py index dd991546..a7d129e2 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -159,6 +159,7 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: "aiohttp", "phoenix", "websockets", + "aiohttp_websockets", "appsync_http", "appsync_websockets", ], @@ -286,7 +287,12 @@ def autodetect_transport(url: URL) -> str: """Detects which transport should be used depending on url.""" if url.scheme in ["ws", "wss"]: - transport_name = "websockets" + try: + import websockets # noqa: F401 + + transport_name = "websockets" + except ImportError: # pragma: no cover + transport_name = "aiohttp_websockets" else: assert url.scheme in ["http", "https"] @@ -338,6 +344,11 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: return WebsocketsTransport(url=args.server, **transport_args) + elif transport_name == "aiohttp_websockets": + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + return AIOHTTPWebsocketsTransport(url=args.server, **transport_args) + else: from gql.transport.appsync_auth import AppSyncAuthentication diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py new file mode 100644 index 00000000..ff310a82 --- /dev/null +++ b/gql/transport/aiohttp_websockets.py @@ -0,0 +1,1196 @@ +import asyncio +import json +import logging +import sys +import warnings +from contextlib import suppress +from ssl import SSLContext +from typing import ( + Any, + AsyncGenerator, + Collection, + Dict, + Mapping, + Optional, + Tuple, + Union, +) + +import aiohttp +from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp.typedefs import LooseHeaders, StrOrURL +from graphql import DocumentNode, ExecutionResult, print_ast +from multidict import CIMultiDictProxy + +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.async_transport import AsyncTransport +from gql.transport.exceptions import ( + 
TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +""" +Load the appropriate instance of the Literal type +Note: we cannot use try: except ImportError because of the following mypy issue: +https://github.com/python/mypy/issues/8520 +""" +if sys.version_info[:2] >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal # pragma: no cover + +log = logging.getLogger("gql.transport.aiohttp_websockets") + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + 
self.send_stop = False + self._closed = True + + +class AIOHTTPWebsocketsTransport(AsyncTransport): + + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL: str = "graphql-ws" + GRAPHQLWS_SUBPROTOCOL: str = "graphql-transport-ws" + + def __init__( + self, + url: StrOrURL, + *, + subprotocols: Optional[Collection[str]] = None, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[LooseHeaders] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + websocket_close_timeout: float = 10.0, + receive_timeout: Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + init_payload: Dict[str, Any] = {}, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, + client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. + :param float heartbeat: Send low level `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + :param auth: An object that represents HTTP Basic Authorization. 
+ :class:`~aiohttp.BasicAuth` (optional) + :param str origin: Origin header to send to server(optional) + :param params: Mapping, iterable of tuple of *key*/*value* pairs or + string to be sent as parameters in the query + string of the new request. Ignored for subsequent + redirected requests (optional) + + Allowed values are: + + - :class:`collections.abc.Mapping` e.g. :class:`dict`, + :class:`multidict.MultiDict` or + :class:`multidict.MultiDictProxy` + - :class:`collections.abc.Iterable` e.g. :class:`tuple` or + :class:`list` + - :class:`str` with preferably url-encoded content + (**Warning:** content will not be encoded by *aiohttp*) + :param headers: HTTP Headers that sent with every request + May be either *iterable of key-value pairs* or + :class:`~collections.abc.Mapping` + (e.g. :class:`dict`, + :class:`~multidict.CIMultiDict`). + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) + :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP + Basic Authorization (optional) + :param ssl: SSL validation mode. ``True`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + :param float websocket_close_timeout: Timeout for websocket to close. + ``10`` seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. ``None`` (unlimited) + seconds by default + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. 
If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param init_payload: Dict of the payload sent in the connection_init message. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param client_session_args: Dict of extra args passed to + `aiohttp.ClientSession`_ + :param connect_args: Dict of extra args passed to + `aiohttp.ClientSession.ws_connect`_ + + .. _aiohttp.ClientSession.ws_connect: + https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession.ws_connect + .. 
_aiohttp.ClientSession: + https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession + """ + self.url: StrOrURL = url + self.heartbeat: Optional[float] = heartbeat + self.auth: Optional[BasicAuth] = auth + self.origin: Optional[str] = origin + self.params: Optional[Mapping[str, str]] = params + self.headers: Optional[LooseHeaders] = headers + + self.proxy: Optional[StrOrURL] = proxy + self.proxy_auth: Optional[BasicAuth] = proxy_auth + self.proxy_headers: Optional[LooseHeaders] = proxy_headers + + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl + + self.websocket_close_timeout: float = websocket_close_timeout + self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + + self.init_payload: Dict[str, Any] = init_payload + + # We need to set an event loop here if there is none + # Or else we will not be able to create an asyncio.Event() + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + + self.session: Optional[aiohttp.ClientSession] = None + self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None + self.next_query_id: int = 1 + self.listeners: Dict[int, ListenerQueue] = {} + self._connecting: bool = False + self.response_headers: Optional[CIMultiDictProxy[str]] = None + + self.receive_data_task: 
Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None + self.close_task: Optional[asyncio.Future] = None + + self._wait_closed: asyncio.Event = asyncio.Event() + self._wait_closed.set() + + self._no_more_listeners: asyncio.Event = asyncio.Event() + self._no_more_listeners.set() + + self.payloads: Dict[str, Any] = {} + + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout + + self.send_ping_task: Optional[asyncio.Future] = None + + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + self.supported_subprotocols: Collection[str] = subprotocols or ( + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ) + + self.close_exception: Optional[Exception] = None + + self.client_session_args = client_session_args + self.connect_args = connect_args + + def _parse_answer_graphqlws( + self, answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. 
+ + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = answer.get("payload") + + if answer_type == "next": + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + if not isinstance(payload, list): + raise ValueError("payload is not a list") + + raise TransportQueryError( + str(payload[0]), query_id=answer_id, errors=payload + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + 
self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(answer.get("type")) + + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(answer.get("id"))) + + if answer_type == "data" or answer_type == "error": + + payload = answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + f"Server did not 
return a GraphQL result: {answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) + + return self._parse_answer_apollo(json_answer) + + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, _, _ = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + init_message = {"type": "connection_init", "payload": self.init_payload} + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + + async def _initialize(self): + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ + await self._send_init_message_and_wait_ack() + + async def _stop_listener(self, query_id: int): + """Hook to stop to listen to a specific query. + Will send a stop message in some subclasses. 
+ """ + log.debug(f"stop listener {query_id}") + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + + async def _after_connect(self): + """Hook to add custom code for subclasses after the connection + has been established. + """ + # Find the backend subprotocol returned in the response headers + response_headers = self.websocket._response.headers + log.debug(f"Response headers: {response_headers!r}") + try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + self.subprotocol = self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol""" + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(ping_message) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol""" + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(pong_message) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. + """ + + stop_message = {"id": str(query_id), "type": "stop"} + + await self._send(stop_message) + + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = {"id": str(query_id), "type": "complete"} + + await self._send(complete_message) + + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. 
+
+        Only used for the graphql-ws protocol.
+
+        Send a ping every ping_interval seconds.
+        Close the connection if a pong is not received within pong_timeout seconds.
+        """
+
+        assert self.ping_interval is not None
+
+        try:
+            while True:
+                await asyncio.sleep(self.ping_interval)
+
+                await self.send_ping()
+
+                await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout)
+
+                # Reset for the next iteration
+                self.pong_received.clear()
+
+        except asyncio.TimeoutError:
+            # No pong received in the appropriate time, close with error
+            # If the timeout happens during a close already in progress, do nothing
+            if self.close_task is None:
+                await self._fail(
+                    TransportServerError(
+                        f"No pong received after {self.pong_timeout!r} seconds"
+                    ),
+                    clean_close=False,
+                )
+
+    async def _after_initialize(self):
+        """Hook to add custom code for subclasses after the initialization
+        has been done.
+        """
+
+        # If requested, create a task to send periodic pings to the backend
+        if (
+            self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL
+            and self.ping_interval is not None
+        ):
+
+            self.send_ping_task = asyncio.ensure_future(self._send_ping_coro())
+
+    async def _close_hook(self):
+        """Hook to add custom code for subclasses for the connection close"""
+        # Properly shut down the send ping task if enabled
+        if self.send_ping_task is not None:
+            self.send_ping_task.cancel()
+            with suppress(asyncio.CancelledError):
+                await self.send_ping_task
+            self.send_ping_task = None
+
+    async def _connection_terminate(self):
+        """Hook to add custom code for subclasses after the initialization
+        has been done.
+        """
+        if self.subprotocol == self.APOLLO_SUBPROTOCOL:
+            await self._send_connection_terminate_message()
+
+    async def _send_connection_terminate_message(self) -> None:
+        """Send a connection_terminate message to the provided websocket connection.
+
+        This message indicates that the connection will disconnect.
+ """ + + connection_terminate_message = {"type": "connection_terminate"} + + await self._send(connection_terminate_message) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name + + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + + query = {"id": str(query_id), "type": query_type, "payload": payload} + + await self._send(query) + + return query_id + + async def _send(self, message: Dict[str, Any]) -> None: + """Send the provided message to the websocket connection and log the message""" + + if self.websocket is None: + raise TransportClosed("WebSocket connection is closed") + + try: + await self.websocket.send_json(message) + log.info(">>> %s", message) + except ConnectionResetError as e: + await self._fail(e, clean_close=False) + raise e + + async def _receive(self) -> str: + """Wait the next message from the websocket connection and log the answer""" + + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") + + while True: + ws_message = await self.websocket.receive() + + # Ignore low-level ping and pong received + if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): + break + + if ws_message.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + WSMsgType.ERROR, + ): + raise ConnectionResetError + elif ws_message.type is WSMsgType.BINARY: + raise 
TransportProtocolError("Binary data received in the websocket")
+
+        assert ws_message.type is WSMsgType.TEXT
+
+        answer: str = ws_message.data
+
+        log.info("<<< %s", answer)
+
+        return answer
+
+    def _remove_listener(self, query_id) -> None:
+        """After exiting from a subscription, remove the listener and
+        signal an event if this was the last listener for the client.
+        """
+        if query_id in self.listeners:
+            del self.listeners[query_id]
+
+            remaining = len(self.listeners)
+            log.debug(f"listener {query_id} deleted, {remaining} remaining")
+
+            if remaining == 0:
+                self._no_more_listeners.set()
+
+    async def _check_ws_liveness(self) -> None:
+        """Coroutine which will periodically check the liveness of the connection
+        through keep-alive messages
+        """
+
+        try:
+            while True:
+                await asyncio.wait_for(
+                    self._next_keep_alive_message.wait(), self.keep_alive_timeout
+                )
+
+                # Reset for the next iteration
+                self._next_keep_alive_message.clear()
+
+        except asyncio.TimeoutError:
+            # No keep-alive message in the appropriate interval, close with error
+            # while trying to notify the server of a proper close (in case
+            # the keep-alive interval of the client or server was not aligned
+            # the connection still remains)
+
+            # If the timeout happens during a close already in progress, do nothing
+            if self.close_task is None:
+                await self._fail(
+                    TransportServerError(
+                        "No keep-alive message has been received within "
+                        "the expected interval ('keep_alive_timeout' parameter)"
+                    ),
+                    clean_close=False,
+                )
+
+        except asyncio.CancelledError:
+            # The client is probably closing, handle it properly
+            pass
+
+    async def _handle_answer(
+        self,
+        answer_type: str,
+        answer_id: Optional[int],
+        execution_result: Optional[ExecutionResult],
+    ) -> None:
+
+        try:
+            # Put the answer in the queue
+            if answer_id is not None:
+                await self.listeners[answer_id].put((answer_type, execution_result))
+        except KeyError:
+            # Do nothing if no one is listening to this query_id.
+ pass + + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + + async def _receive_data_loop(self) -> None: + """Main asyncio task which will listen to the incoming messages and will + call the parse_answer and handle_answer methods of the subclass.""" + log.debug("Entering _receive_data_loop()") + + try: + while True: + + # Wait the next answer from the websocket server + try: + answer = await self._receive() + except (ConnectionResetError, TransportProtocolError) as e: + await self._fail(e, clean_close=False) + break + except TransportClosed as e: + await self._fail(e, clean_close=False) + raise e + + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query, + # but the transport is not closed. + assert isinstance( + e.query_id, int + ), "TransportQueryError should have a query_id defined here" + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass + + continue + + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries. 
+ await self._fail(e, clean_close=False) + break + + await self._handle_answer(answer_type, answer_id, execution_result) + + finally: + log.debug("Exiting _receive_data_loop()") + + async def connect(self) -> None: + log.debug("connect: starting") + + if self.session is None: + client_session_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + if self.client_session_args: + client_session_args.update(self.client_session_args) # type: ignore + + self.session = aiohttp.ClientSession(**client_session_args) + + if self.websocket is None and not self._connecting: + self._connecting = True + + connect_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + if self.connect_args: + connect_args.update(self.connect_args) + + try: + # Connection to the specified url + # Generate a TimeoutError if taking more than connect_timeout seconds + # Set the _connecting flag to False after in all cases + self.websocket = await asyncio.wait_for( + self.session.ws_connect( + url=self.url, + headers=self.headers, + auth=self.auth, + heartbeat=self.heartbeat, + origin=self.origin, + params=self.params, + protocols=self.supported_subprotocols, + proxy=self.proxy, + proxy_auth=self.proxy_auth, + proxy_headers=self.proxy_headers, + timeout=self.websocket_close_timeout, + receive_timeout=self.receive_timeout, + ssl=self.ssl, + **connect_args, + ), + self.connect_timeout, + ) + finally: + self._connecting = False + + self.response_headers = self.websocket._response.headers + + await self._after_connect() + + self.next_query_id = 1 + self.close_exception = None + self._wait_closed.clear() + + # Send the init message and wait for the ack from the server + # Note: This should generate a TimeoutError + # if no ACKs are received within the ack_timeout + try: + await self._initialize() + except ConnectionResetError as e: + raise e + except ( + TransportProtocolError, + TransportServerError, + asyncio.TimeoutError, + ) as e: + await self._fail(e, 
clean_close=False) + raise e + + # Run the after_init hook of the subclass + await self._after_initialize() + + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + + else: + raise TransportAlreadyConnected("Transport is already connected") + + log.debug("connect: done") + + async def _clean_close(self) -> None: + """Coroutine which will: + + - send stop messages for each active subscription to the server + - send the connection terminate message + """ + log.debug(f"Listeners: {self.listeners}") + + # Send 'stop' message for all current queries + for query_id, listener in self.listeners.items(): + print(f"Listener {query_id} send_stop: {listener.send_stop}") + + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) + except asyncio.TimeoutError: # pragma: no cover + log.debug("Timer close_timeout fired") + + # Calling the subclass hook + await self._connection_terminate() + + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: + + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners + """ + + log.debug("_close_coro: starting") + + try: + + try: + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: 
https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + except Exception as exc: # pragma: no cover + log.warning( + "_close_coro cancel keep alive task exception: " + repr(exc) + ) + + try: + # Calling the subclass close hook + await self._close_hook() + except Exception as exc: # pragma: no cover + log.warning("_close_coro close_hook exception: " + repr(exc)) + + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e + + if clean_close: + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close() + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) + + log.debug("_close_coro: sending exception to listeners") + + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): + await listener.set_exception(e) + + log.debug("_close_coro: close websocket connection") + + try: + assert self.websocket is not None + + await self.websocket.close() + self.websocket = None + except Exception as exc: + log.warning("_close_coro websocket close exception: " + repr(exc)) + + log.debug("_close_coro: close aiohttp session") + + if ( + self.client_session_args + and self.client_session_args.get("connector_owner") is False + ): + + log.debug("connector_owner is False -> not closing connector") + + else: + try: + assert self.session is not None + + closed_event = AIOHTTPTransport.create_aiohttp_closed_event( + self.session + ) + await self.session.close() + try: + await asyncio.wait_for( + closed_event.wait(), self.ssl_close_timeout + ) + except asyncio.TimeoutError: + pass + except Exception as exc: # pragma: no cover + log.warning("_close_coro session close exception: " + repr(exc)) + + self.session = None + + log.debug("_close_coro: aiohttp 
session closed") + + try: + assert self.receive_data_task is not None + + self.receive_data_task.cancel() + with suppress(asyncio.CancelledError): + await self.receive_data_task + except Exception as exc: # pragma: no cover + log.warning( + "_close_coro cancel receive data task exception: " + repr(exc) + ) + + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) + + finally: + + log.debug("_close_coro: final cleanup") + + self.websocket = None + self.close_task = None + self.check_keep_alive_task = None + self.receive_data_task = None + self._wait_closed.set() + + log.debug("_close_coro: exiting") + + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + + if self.close_task is None: + + if self._wait_closed.is_set(): + log.debug("_fail started but transport is already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + else: + log.debug( + "close_task is not None in _fail. Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) + ) + + async def close(self) -> None: + log.debug("close: starting") + + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() + + log.debug("close: done") + + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + + if not self._wait_closed.is_set(): + await self._wait_closed.wait() + + log.debug("wait_close: done") + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + + Send a query but close the async generator as soon as we have the first answer. + + The result is sent as an ExecutionResult object. 
+ """ + first_result = None + + generator = self.subscribe( + document, variable_values, operation_name, send_stop=False + ) + + async for result in generator: + first_result = result + + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() + + break + + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) + + return first_result + + async def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using a python async generator. + + The query can be a graphql query, mutation or subscription. + + The results are sent as an ExecutionResult object. + """ + + # Send the query and receive the id + query_id: int = await self._send_query( + document, variable_values, operation_name + ) + + # Create a queue to receive the answers for this query_id + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener + + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() + + try: + # Loop over the received answers + while True: + + # Wait for the answer from the queue of this query_id + # This can raise a TransportError or ConnectionClosed exception. 
+ answer_type, execution_result = await listener.get() + + # If the received answer contains data, + # Then we will yield the results back as an ExecutionResult object + if execution_result is not None: + yield execution_result + + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output without errors + elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) + break + + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug(f"Exception in subscribe: {e!r}") + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + finally: + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) diff --git a/tests/conftest.py b/tests/conftest.py index 6a37a5d3..c164c355 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,6 +119,7 @@ async def ssl_aiohttp_server(): for name in [ "websockets.legacy.server", "gql.transport.aiohttp", + "gql.transport.aiohttp_websockets", "gql.transport.appsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", @@ -210,6 +211,145 @@ async def stop(self): print("Server stopped\n\n\n") +class AIOHTTPWebsocketServer: + def __init__(self, with_ssl=False): + self.runner = None + self.site = None + self.port = None + self.hostname = "127.0.0.1" + self.with_ssl = with_ssl + self.ssl_context = None + if with_ssl: + _, self.ssl_context = get_localhost_ssl_context() + + def get_default_server_handler(answers): + async def default_server_handler(request): + + import aiohttp + import aiohttp.web + from aiohttp import WSMsgType + + ws = aiohttp.web.WebSocketResponse() + ws.headers.update({"dummy": "test1234"}) + await ws.prepare(request) + + try: + # Init and ack + msg = await ws.__anext__() + assert msg.type == WSMsgType.TEXT + result = msg.data + json_result = json.loads(result) + assert json_result["type"] == "connection_init" + await 
ws.send_str('{"type":"connection_ack"}') + query_id = 1 + + # Wait for queries and send answers + for answer in answers: + msg = await ws.__anext__() + if msg.type == WSMsgType.TEXT: + result = msg.data + + print(f"Server received: {result}", file=sys.stderr) + if isinstance(answer, str) and "{query_id}" in answer: + answer_format_params = {"query_id": query_id} + formatted_answer = answer.format(**answer_format_params) + else: + formatted_answer = answer + await ws.send_str(formatted_answer) + await ws.send_str( + f'{{"type":"complete","id":"{query_id}","payload":null}}' + ) + query_id += 1 + + elif msg.type == WSMsgType.ERROR: + print(f"WebSocket connection closed with: {ws.exception()}") + raise ws.exception() + elif msg.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + ): + print("WebSocket connection closed") + raise ConnectionResetError + + # Wait for connection_terminate + msg = await ws.__anext__() + result = msg.data + json_result = json.loads(result) + assert json_result["type"] == "connection_terminate" + + # Wait for connection close + msg = await ws.__anext__() + + except ConnectionResetError: + pass + + except Exception as e: + print(f"Server exception {e!s}", file=sys.stderr) + + await ws.close() + return ws + + return default_server_handler + + async def shutdown_server(self, app): + print("Shutting down server...") + await app.shutdown() + await app.cleanup() + + async def start(self, handler): + import aiohttp + import aiohttp.web + + app = aiohttp.web.Application() + app.router.add_get("/graphql", handler) + self.runner = aiohttp.web.AppRunner(app) + await self.runner.setup() + + # Use port 0 to bind to an available port + self.site = aiohttp.web.TCPSite( + self.runner, self.hostname, 0, ssl_context=self.ssl_context + ) + await self.site.start() + + # Retrieve the actual port the server is listening on + sockets = self.site._server.sockets + if sockets: + self.port = sockets[0].getsockname()[1] + protocol = "https" if 
self.with_ssl else "http" + print(f"Server started at {protocol}://{self.hostname}:{self.port}") + + async def stop(self): + if self.site: + await self.site.stop() + if self.runner: + await self.runner.cleanup() + + +@pytest_asyncio.fixture +async def aiohttp_ws_server(request): + """Fixture used to start a dummy server to test the client behaviour + using the aiohttp dependency. + + It can take as argument either a handler function for the websocket server for + complete control OR an array of answers to be sent by the default server handler. + """ + + server_handler = get_aiohttp_ws_server_handler(request) + + try: + test_server = AIOHTTPWebsocketServer() + + # Starting the server with the fixture param as the handler function + await test_server.start(server_handler) + + yield test_server + except Exception as e: + print("Exception received in server fixture:", e) + finally: + await test_server.stop() + + class WebSocketServerHelper: @staticmethod async def send_complete(ws, query_id): @@ -306,6 +446,23 @@ def __exit__(self, type, value, traceback): os.unlink(self.filename) +def get_aiohttp_ws_server_handler(request): + """Get the server handler for the aiohttp websocket server. + + Either get it from test or use the default server handler + if the test provides only an array of answers. + """ + + if isinstance(request.param, types.FunctionType): + server_handler = request.param + + else: + answers = request.param + server_handler = AIOHTTPWebsocketServer.get_default_server_handler(answers) + + return server_handler + + def get_server_handler(request): """Get the server handler. @@ -462,6 +619,48 @@ async def client_and_server(server): yield session, server +@pytest_asyncio.fixture +async def aiohttp_client_and_server(server): + """ + Helper fixture to start a server and a client connected to its port + with an aiohttp websockets transport. 
+ """ + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, server + + +@pytest_asyncio.fixture +async def aiohttp_client_and_aiohttp_ws_server(aiohttp_ws_server): + """ + Helper fixture to start an aiohttp websocket server and + a client connected to its port with an aiohttp websockets transport. + """ + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = aiohttp_ws_server + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield session, server + + @pytest_asyncio.fixture async def client_and_graphqlws_server(graphqlws_server): """Helper fixture to start a server with the graphql-ws prototocol @@ -483,6 +682,27 @@ async def client_and_graphqlws_server(graphqlws_server): yield session, graphqlws_server +@pytest_asyncio.fixture +async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): + """Helper fixture to start a server with the graphql-ws prototocol + and a client connected to its port.""" + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # Generate transport to connect to the server fixture + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + sample_transport = AIOHTTPWebsocketsTransport( + url=url, + subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], + ) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + 
yield session, graphqlws_server + + @pytest_asyncio.fixture async def run_sync_test(): async def run_sync_test_inner(event_loop, server, test_function): diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py new file mode 100644 index 00000000..ea48824f --- /dev/null +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -0,0 +1,406 @@ +import asyncio +import json +import types +from typing import List + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"data","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_invalid_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + 
await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_aiohttp_websocket_invalid_subscription( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +connection_error_server_answer = ( + '{"type":"connection_error","id":null,' + '"payload":{"message":"Unexpected token Q in JSON at position 0"}}' +) + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_server_does_not_send_ack( + event_loop, server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=sample_transport): + pass + + +async def server_connection_error(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(connection_error_server_answer) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_connection_error], indirect=True) 
+@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_sending_invalid_data( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + invalid_data = "QSDF" + print(f">>> {invalid_data}") + await session.transport.websocket.send_str(invalid_data) + + await asyncio.sleep(2 * MS) + + +invalid_payload_server_answer = ( + '{"type":"error","id":"1","payload":{"message":"Must provide document"}}' +) + + +async def server_invalid_payload(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_payload_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_sending_invalid_payload( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + # Monkey patching the _send_query method to send an invalid payload + + async def monkey_patch_send_query( + self, + document, + variable_values=None, + operation_name=None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + {"id": str(query_id), "type": "start", "payload": "BLAHBLAH"} + ) + + await self._send(query_str) + return query_id + + session.transport._send_query = types.MethodType( + monkey_patch_send_query, session.transport + ) + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["message"] == "Must provide document" + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "data"}'] 
+missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "data", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +payload_is_not_a_dict = ['{"type": "data", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "data", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_aiohttp_websocket_transport_protocol_errors( + event_loop, aiohttp_client_and_server +): + + session, server = aiohttp_client_and_server + + query = gql("query { hello }") + + with pytest.raises((TransportProtocolError, TransportQueryError)): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_without_ack], indirect=True) +async def test_aiohttp_websocket_server_does_not_ack(event_loop, server): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_directly], indirect=True) +async def test_aiohttp_websocket_server_closing_directly(event_loop, server): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = 
f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(ConnectionResetError): + async with Client(transport=sample_transport): + pass + + +async def server_closing_after_ack(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) +async def test_aiohttp_websocket_server_closing_after_ack( + event_loop, aiohttp_client_and_server +): + + session, server = aiohttp_client_and_server + + query = gql("query { hello }") + + with pytest.raises(TransportClosed): + await session.execute(query) + + +async def server_sending_invalid_query_errors(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + invalid_error = ( + '{"type":"error","id":"404","payload":' + '{"message":"error for no good reason on non existing query"}}' + ) + await ws.send(invalid_error) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_aiohttp_websocket_server_sending_invalid_query_errors( + event_loop, server +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + # Invalid server message is ignored + async with Client(transport=sample_transport): + await asyncio.sleep(2 * MS) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_aiohttp_websocket_non_regression_bug_105(event_loop, server): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + # This test will check a fix to a race condition which happens if the user is trying + # to connect using the same client twice at the same time + # See bug 
#105 + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + sample_transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + # Create a coroutine which start the connection with the transport but does nothing + async def client_connect(client): + async with client: + await asyncio.sleep(2 * MS) + + # Create two tasks which will try to connect using the same client (not allowed) + connect_task1 = asyncio.ensure_future(client_connect(client)) + connect_task2 = asyncio.ensure_future(client_connect(client)) + + result = await asyncio.gather(connect_task1, connect_task2, return_exceptions=True) + + assert result[0] is None + assert type(result[1]).__name__ == "TransportAlreadyConnected" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +async def test_aiohttp_websocket_using_cli_invalid_query( + event_loop, server, monkeypatch, capsys +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + import io + + from gql.cli import get_parser, main + + parser = get_parser(with_examples=True) + args = parser.parse_args([url]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(invalid_query_str)) + + # Flush captured output + captured = capsys.readouterr() + + await main(args) + + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") + + expected_error = 'Cannot query field "bloh" on type "Continent"' + + assert expected_error in captured_err diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py new file mode 100644 index 00000000..d87315c9 --- /dev/null +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -0,0 +1,276 @@ +import asyncio +from typing import List 
+ +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +from .conftest import WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"next","id":"{query_id}",' + '"payload":{{"errors":[' + '{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],' + '"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [invalid_query1_server_answer] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_graphqlws_invalid_query( + event_loop, client_and_aiohttp_websocket_graphql_server, query_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +async def server_invalid_subscription(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_invalid_subscription], indirect=True +) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) 
+async def test_aiohttp_websocket_graphqlws_invalid_subscription( + event_loop, client_and_aiohttp_websocket_graphql_server, query_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError) as exc_info: + async for result in session.subscribe(query): + pass + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" + + +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( + event_loop, graphqlws_server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=transport): + pass + + +invalid_query_server_answer = ( + '{"id":"1","type":"error","payload":[{"message":"Cannot query field ' + '\\"helo\\" on type \\"Query\\". 
Did you mean \\"hello\\"?",' + '"locations":[{"line":2,"column":3}]}]}' +) + + +async def server_invalid_query(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_query_server_answer) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) +async def test_aiohttp_websocket_graphqlws_sending_invalid_query( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql("{helo}") + + with pytest.raises(TransportQueryError) as exc_info: + await session.execute(query) + + exception = exc_info.value + + assert isinstance(exception.errors, List) + + error = exception.errors[0] + + assert ( + error["message"] + == 'Cannot query field "helo" on type "Query". Did you mean "hello"?' 
+ ) + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "next"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "next", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +error_with_payload_not_a_list = ['{"type": "error", "id":"1", "payload": "NOT A LIST"}'] +payload_is_not_a_dict = ['{"type": "next", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "next", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + error_with_payload_not_a_list, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_aiohttp_websocket_graphqlws_transport_protocol_errors( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + session, server = client_and_aiohttp_websocket_graphql_server + + query = gql("query { hello }") + + with pytest.raises((TransportProtocolError, TransportQueryError)): + await session.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) +async def test_aiohttp_websocket_graphqlws_server_does_not_ack( + event_loop, graphqlws_server +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with Client(transport=transport): + pass + + +async def 
server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) +async def test_aiohttp_websocket_graphqlws_server_closing_directly( + event_loop, graphqlws_server +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + with pytest.raises(ConnectionResetError): + async with Client(transport=transport): + pass + + +async def server_closing_after_ack(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) +async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( + event_loop, client_and_aiohttp_websocket_graphql_server +): + + session, _ = client_and_aiohttp_websocket_graphql_server + + query = gql("query { hello }") + + with pytest.raises(TransportClosed): + await session.execute(query) + + await session.transport.wait_closed() + + with pytest.raises(TransportClosed): + await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py new file mode 100644 index 00000000..e5db7ca1 --- /dev/null +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -0,0 +1,879 @@ +import asyncio +import json +import sys +import warnings +from typing import List + +import pytest +from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportServerError + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +countdown_server_answer = ( + 
'{{"type":"next","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +COUNTING_DELAY = 20 * MS +PING_SENDING_DELAY = 50 * MS +PONG_TIMEOUT = 100 * MS + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def server_countdown_factory( + keepalive=False, answer_pings=True, simulate_disconnect=False +): + async def server_countdown_template(ws, path): + import websockets + + logged_messages.clear() + + try: + await WebSocketServerHelper.send_connection_ack( + ws, payload="dummy_connection_ack_payload" + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f" Server: Countdown started from: {count}") + + if simulate_disconnect and count == 8: + await ws.close() + + pong_received: asyncio.Event = asyncio.Event() + + async def counting_coro(): + print(" Server: counting task started") + try: + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format( + query_id=query_id, number=number + ) + ) + await asyncio.sleep(COUNTING_DELAY) + finally: + print(" Server: counting task ended") + + print(" Server: starting counting task") + counting_task = asyncio.ensure_future(counting_coro()) + + async def keepalive_coro(): + print(" Server: keepalive task started") + try: + while True: + await asyncio.sleep(PING_SENDING_DELAY) + try: + # Send a ping + await WebSocketServerHelper.send_ping( + ws, payload="dummy_ping_payload" + ) + + # Wait for a pong + try: + await asyncio.wait_for( + pong_received.wait(), PONG_TIMEOUT + ) + except asyncio.TimeoutError: + print( + "\n Server: No pong received in time!\n" + ) + break + + pong_received.clear() + + except websockets.exceptions.ConnectionClosed: + break + finally: + print(" Server: 
keepalive task ended") + + if keepalive: + print(" Server: starting keepalive task") + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + print(" Server: receiving task started") + try: + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + answer_type = json_result["type"] + + if answer_type == "complete" and json_result["id"] == str( + query_id + ): + print("Cancelling counting task now") + counting_task.cancel() + if keepalive: + print("Cancelling keep alive task now") + keepalive_task.cancel() + + elif answer_type == "ping": + if answer_pings: + payload = json_result.get("payload", None) + await WebSocketServerHelper.send_pong( + ws, payload=payload + ) + + elif answer_type == "pong": + pong_received.set() + finally: + print(" Server: receiving task ended") + if keepalive: + keepalive_task.cancel() + + print(" Server: starting receiving task") + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + print(" Server: waiting for counting task to complete") + await counting_task + except asyncio.CancelledError: + print(" Server: Now counting task is cancelled") + + print(" Server: sending complete message") + await WebSocketServerHelper.send_complete(ws, query_id) + + if keepalive: + print(" Server: cancelling keepalive task") + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print(" Server: Now keepalive task is cancelled") + + print(" Server: waiting for client to close the connection") + try: + await asyncio.wait_for(receiving_task, 1000 * MS) + except asyncio.TimeoutError: + pass + + print(" Server: cancelling receiving task") + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print(" Server: Now receiving task is cancelled") + + except websockets.exceptions.ConnectionClosedOK: 
+ pass + except AssertionError as e: + print(f"\n Server: Assertion failed: {e!s}\n") + finally: + print(" Server: waiting for websocket connection to close") + await ws.wait_closed() + print(" Server: connection closed") + + return server_countdown_template + + +async def server_countdown(ws, path): + + server = server_countdown_factory() + await server(ws, path) + + +async def server_countdown_keepalive(ws, path): + + server = server_countdown_factory(keepalive=True) + await server(ws, path) + + +async def server_countdown_dont_answer_pings(ws, path): + + server = server_countdown_factory(answer_pings=False) + await server(ws, path) + + +async def server_countdown_disconnect(ws, path): + + server = server_countdown_factory(simulate_disconnect=True) + await server(ws, path) + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_break( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = 
gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_close_transport( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: 
{number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(5.5 * COUNTING_DELAY) + + await session.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "subscribe" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(COUNTING_DELAY) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + session, _ = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(ConnectionResetError): + async for result in session.subscribe(subscription): + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive( + event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str +): + + session, server = client_and_aiohttp_websocket_graphql_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + assert "ping" in session.transport.payloads + assert session.transport.payloads["ping"] == "dummy_ping_payload" + assert ( + session.transport.payloads["connection_ack"] == "dummy_connection_ack_payload" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_ok( + 
event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport( + url=url, keep_alive_timeout=(5 * COUNTING_DELAY) + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport( + url=url, keep_alive_timeout=(COUNTING_DELAY / 2) + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_ok( + 
event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport( + url=url, + ping_interval=(10 * COUNTING_DELAY), + pong_timeout=(8 * COUNTING_DELAY), + ) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_dont_answer_pings], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_nok( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, ping_interval=(5 * COUNTING_DELAY)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No pong received" in str(exc_info.value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_manual_pings_with_payload( + 
event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + payload = {"count_received": count} + + await transport.send_ping(payload=payload) + + await asyncio.wait_for(transport.pong_received.wait(), 10000 * MS) + + transport.pong_received.clear() + + assert transport.payloads["pong"] == payload + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_manual_pong_with_payload( + event_loop, graphqlws_server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, answer_pings=False) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + async def answer_ping_coro(): + while True: + await transport.ping_received.wait() + transport.ping_received.clear() + await transport.send_pong(payload={"some": "data"}) + + answer_ping_task = asyncio.ensure_future(answer_ping_coro()) + + try: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + 
finally: + answer_ping_task.cancel() + + assert count == -1 + + +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_graphqlws_subscription_sync( + graphqlws_server, subscription_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( + graphqlws_server, subscription_str +): + """Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. 
+ """ + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Check that the server received a connection_terminate message last + # assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "graphqlws_server", [server_countdown_keepalive], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_graphqlws_subscription_running_in_thread( + event_loop, graphqlws_server, subscription_str, run_sync_test +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, graphqlws_server, test_code) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 
"graphqlws_server", [server_countdown_disconnect], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +@pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) +async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( + event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe +): + + from gql.transport.exceptions import TransportClosed + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 8 + subscription_with_disconnect = gql(subscription_str.format(count=count)) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + session = await client.connect_async( + reconnecting=True, retry_connect=False, retry_execute=False + ) + + # First we make a subscription which will cause a disconnect in the backend + # (count=8) + try: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): + pass + except ConnectionResetError: + pass + + await asyncio.sleep(50 * MS) + + # Then with the same session handle, we make a subscription or an execute + # which will detect that the transport is closed so that the client could + # try to reconnect + try: + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + await session.execute(subscription) + else: + print("\nSUBSCRIPTION_2\n") + async for result in session.subscribe(subscription): + pass + except TransportClosed: + pass + + await asyncio.sleep(50 * MS) + + # And finally with the same session handle, we make a subscription + # which works correctly + print("\nSUBSCRIPTION_3\n") + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 
1 + + assert count == -1 + + await client.close_async() diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py new file mode 100644 index 00000000..f154386b --- /dev/null +++ b/tests/test_aiohttp_websocket_query.py @@ -0,0 +1,707 @@ +import asyncio +import json +import ssl +import sys +from typing import Dict, Mapping + +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportQueryError, + TransportServerError, +) + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = pytest.mark.aiohttp + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_data = ( + '{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}' +) + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + +server1_answers = [ + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_starting_client_in_context_manager( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url, websocket_close_timeout=10) + + 
async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +async def test_aiohttp_websocket_using_ssl_connection( + event_loop, ws_ssl_server, ssl_close_timeout +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = ws_ssl_server + + url = f"wss://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_verify_locations(ws_ssl_server.testcert) + + transport = AIOHTTPWebsocketsTransport( + url=url, ssl=ssl_context, ssl_close_timeout=ssl_close_timeout + ) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_simple_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, 
server = aiohttp_client_and_server + + query = gql(query_str) + + result = await session.execute(query) + + print("Client received:", result) + + +server1_two_answers_in_series = [ + query1_server_answer, + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aiohttp_ws_server", [server1_two_answers_in_series], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_two_queries_in_series( + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str +): + + session, server = aiohttp_client_and_aiohttp_ws_server + + query = gql(query_str) + + result1 = await session.execute(query) + + print("Query1 received:", result1) + + result2 = await session.execute(query) + + print("Query2 received:", result2) + + assert result1 == result2 + + +async def server1_two_queries_in_parallel(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await ws.send(query1_server_answer.format(query_id=2)) + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 2) + await WebSocketServerHelper.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_two_queries_in_parallel( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result1 = None + result2 = None + + async def task1_coro(): + nonlocal result1 + result1 = await session.execute(query) + + async def task2_coro(): + nonlocal result2 + result2 = await session.execute(query) + + task1 = 
asyncio.ensure_future(task1_coro()) + task2 = asyncio.ensure_future(task2_coro()) + + await asyncio.gather(task1, task2) + + print("Query1 received:", result1) + print("Query2 received:", result2) + + assert result1 == result2 + + +async def server_closing_while_we_are_doing_something_else(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await asyncio.sleep(1 * MS) + + # Closing server after first query + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_closing_while_we_are_doing_something_else], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_server_closing_after_first_query( + event_loop, aiohttp_client_and_server, query_str +): + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + # First query is working + await session.execute(query) + + # Then we do other things + await asyncio.sleep(1000 * MS) + + # Now the server is closed but we don't know it yet, we have to send a query + # to notice it and to receive the exception + with pytest.raises(TransportClosed): + await session.execute(query) + + +ignore_invalid_id_answers = [ + query1_server_answer, + '{"type":"complete","id": "55"}', + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aiohttp_ws_server", [ignore_invalid_id_answers], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_ignore_invalid_id( + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str +): + + session, server = aiohttp_client_and_aiohttp_ws_server + + query = gql(query_str) + + # First query is working + await session.execute(query) + + # Second query gets no answer -> raises + with 
pytest.raises(TransportQueryError): + await session.execute(query) + + # Third query is working + await session.execute(query) + + +async def assert_client_is_working(session): + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_series( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + # Check client is disconnect here + assert transport.websocket is None + + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_multiple_connections_in_parallel( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + async def task_coro(): + transport = AIOHTTPWebsocketsTransport(url=url) + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + task1 = asyncio.ensure_future(task_coro()) + task2 = asyncio.ensure_future(task_coro()) + + await asyncio.gather(task1, task2) + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_trying_to_connect_to_already_connected_transport( + event_loop, aiohttp_ws_server +): + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + async with Client(transport=transport) as session: + await assert_client_is_working(session) + + with pytest.raises(TransportAlreadyConnected): + async with Client(transport=transport): + pass + + +async def server_with_authentication_in_connection_init_payload(ws, path): + # Wait the connection_init message + init_message_str = await ws.recv() + init_message = json.loads(init_message_str) + payload = init_message["payload"] + + if "Authorization" in payload: + if payload["Authorization"] == 12345: + await ws.send('{"type":"connection_ack"}') + + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + else: + await ws.send( + '{"type":"connection_error", "payload": "Invalid Authorization token"}' + ) + else: + await ws.send( + '{"type":"connection_error", "payload": "No Authorization token"}' + ) + + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_connect_success_with_authentication_in_connection_init( + event_loop, server, query_str +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + init_payload = {"Authorization": 12345} + + transport = AIOHTTPWebsocketsTransport(url=url, 
init_payload=init_payload) + + async with Client(transport=transport) as session: + + query1 = gql(query_str) + + result = await session.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) +async def test_aiohttp_websocket_connect_failed_with_authentication_in_connection_init( + event_loop, server, query_str, init_payload +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url, init_payload=init_payload) + + for _ in range(2): + with pytest.raises(TransportServerError): + async with Client(transport=transport) as session: + query1 = gql(query_str) + + await session.execute(query1) + + assert transport.session is None + assert transport.websocket is None + + +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + query1 = gql(query1_str) + + result = client.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Execute sync a second time + result = 
client.execute(query1) + + print("Client received:", result) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_add_extra_parameters_to_connect( + event_loop, aiohttp_ws_server +): + + server = aiohttp_ws_server + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + # Increase max payload size + transport = AIOHTTPWebsocketsTransport( + url=url, + connect_args={ + "max_msg_size": 2**21, + }, + ) + + query = gql(query1_str) + + async with Client(transport=transport) as session: + await session.execute(query) + + +async def server_sending_keep_alive_before_connection_ack(ws, path): + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}", file=sys.stderr) + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServerHelper.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize( + "server", [server_sending_keep_alive_before_connection_ack], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_non_regression_bug_108( + event_loop, aiohttp_client_and_server, query_str +): + + # This test will check that we now ignore keepalive message + # arriving before the connection_ack + # See bug #108 + + session, server = aiohttp_client_and_server + + query = gql(query_str) + + result = await 
session.execute(query) + + print("Client received:", result) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("transport_arg", [[], ["--transport=aiohttp_websockets"]]) +async def test_aiohttp_websocket_using_cli( + event_loop, aiohttp_ws_server, transport_arg, monkeypatch, capsys +): + + """ + Note: depending on the transport_arg parameter, if there is no transport argument, + then we will use WebsocketsTransport if the websockets dependency is installed, + or AIOHTTPWebsocketsTransport if that is not the case. + """ + + server = aiohttp_ws_server + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + import io + import json + + from gql.cli import get_parser, main + + parser = get_parser(with_examples=True) + args = parser.parse_args([url, *transport_arg]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + # Flush captured output + captured = capsys.readouterr() + + exit_code = await main(args) + + assert exit_code == 0 + + # Check that the result has been printed on stdout + captured = capsys.readouterr() + captured_out = str(captured.out).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer + + +query1_server_answer_with_extensions = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}},' + '"extensions": {{"key1": "val1"}}}}}}' +) + 
+server1_answers_with_extensions = [ + query1_server_answer_with_extensions, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "aiohttp_ws_server", [server1_answers_with_extensions], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_aiohttp_websocket_simple_query_with_extensions( + event_loop, aiohttp_client_and_aiohttp_ws_server, query_str +): + + session, server = aiohttp_client_and_aiohttp_ws_server + + query = gql(query_str) + + execution_result = await session.execute(query, get_execution_result=True) + + assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) +async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_server): + + server = aiohttp_ws_server + + from aiohttp import TCPConnector + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + connector = TCPConnector() + transport = AIOHTTPWebsocketsTransport( + url=url, + client_session_args={ + "connector": connector, + "connector_owner": False, + }, + ) + + for _ in range(2): + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + result = await session.execute(query1) + + print("Client received:", result) + + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert transport.websocket is None + + await connector.close() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py new file mode 100644 index 00000000..3ebf4dbc --- /dev/null +++ b/tests/test_aiohttp_websocket_subscription.py @@ -0,0 +1,809 @@ +import asyncio +import json +import sys +import warnings +from typing import List + +import pytest +from graphql import ExecutionResult 
+from parse import search + +from gql import Client, gql +from gql.transport.exceptions import TransportClosed, TransportServerError + +from .conftest import MS, WebSocketServerHelper +from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef + +# Marking all tests in this file with the aiohttp AND websockets marker +pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] + +starwars_expected_one = { + "stars": 3, + "commentary": "Was expecting more stuff", + "episode": "JEDI", +} + +starwars_expected_two = { + "stars": 5, + "commentary": "This is a great movie!", + "episode": "JEDI", +} + + +async def server_starwars(ws, path): + import websockets + + await WebSocketServerHelper.send_connection_ack(ws) + + try: + await ws.recv() + + reviews = [starwars_expected_one, starwars_expected_two] + + for review in reviews: + + data = ( + '{"type":"data","id":"1","payload":{"data":{"reviewAdded": ' + + json.dumps(review) + + "}}}" + ) + await ws.send(data) + await asyncio.sleep(2 * MS) + + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.wait_connection_terminate(ws) + + except websockets.exceptions.ConnectionClosedOK: + pass + + print("Server is now closed") + + +starwars_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + reviewAdded(episode: $ep) { + stars, + commentary, + episode + } + } +""" + +starwars_invalid_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) 
{ + reviewAdded(episode: $ep) { + not_valid_field, + stars, + commentary, + episode + } + } +""" + +countdown_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + +WITH_KEEPALIVE = False + + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +async def server_countdown(ws, path): + import websockets + + logged_messages.clear() + + global WITH_KEEPALIVE + try: + await WebSocketServerHelper.send_connection_ack(ws) + if WITH_KEEPALIVE: + await WebSocketServerHelper.send_keepalive(ws) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(2 * MS) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def stopping_coro(): + nonlocal counting_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + json_result = json.loads(result) + + if json_result["type"] == "stop" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + + async def keepalive_coro(): + while True: + await asyncio.sleep(5 * MS) + try: + await WebSocketServerHelper.send_keepalive(ws) + except websockets.exceptions.ConnectionClosed: + break + + stopping_task = asyncio.ensure_future(stopping_coro()) + if WITH_KEEPALIVE: + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + 
except Exception as exc: + print(f"Exception in counting task: {exc!s}") + + stopping_task.cancel() + + try: + await stopping_task + except asyncio.CancelledError: + print("Now stopping task is cancelled") + + if WITH_KEEPALIVE: + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print("Now keepalive task is cancelled") + + await WebSocketServerHelper.send_complete(ws, query_id) + await WebSocketServerHelper.wait_connection_terminate(ws) + except websockets.exceptions.ConnectionClosedOK: + pass + finally: + await ws.wait_closed() + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_get_execution_result( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription, get_execution_result=True): + + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_break( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + if sys.version_info < (3, 7): + await session._generator.aclose() + break + + count -= 1 + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_task_cancel( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_close_transport( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, _ = aiohttp_client_and_server + + count = 10 + 
subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + await session.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await WebSocketServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(2 * MS) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_server_connection_closed( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(ConnectionResetError): + + async for result in session.subscribe(subscription): + + number = 
result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_slow_consumer( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe(subscription): + await asyncio.sleep(10 * MS) + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_operation_name( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in session.subscribe( + subscription, operation_name="CountdownSubscription" + ): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + # Check that the query contains the operationName + assert '"operationName": "CountdownSubscription"' in logged_messages[0] + + +WITH_KEEPALIVE = True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive( + event_loop, aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in 
session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( + event_loop, server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( + event_loop, server, subscription_str +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + 
+ +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync(server, subscription_str): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_user_exception(server, subscription_str): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(Exception) as exc_info: + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + raise Exception("This is an user exception") + + assert count == 5 + assert "This is an user exception" in str(exc_info.value) + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_break(server, subscription_str): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = 
f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + if count == 5: + break + + assert count == 5 + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="test failing on windows") +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_aiohttp_websocket_subscription_sync_graceful_shutdown( + server, subscription_str +): + """Note: this test will simulate a control-C happening while a sync subscription + is in progress. To do that we will throw a KeyboardInterrupt exception inside + the subscription async generator. + + The code should then do a clean close: + - send stop messages for each active query + - send a connection_terminate message + Then the KeyboardInterrupt will be reraise (to warn potential user code) + + This test does not work on Windows but the behaviour with Windows is correct. 
+ """ + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + interrupt_task = None + + with pytest.raises(KeyboardInterrupt): + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + + if count == 5: + + # Simulate a KeyboardInterrupt in the generator + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + interrupt_task = asyncio.ensure_future( + client.session._generator.athrow(KeyboardInterrupt) + ) + + count -= 1 + + assert count == 4 + + # Catch interrupt_task exception to remove warning + interrupt_task.exception() + + # Check that the server received a connection_terminate message last + assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_running_in_thread( + event_loop, server, subscription_str, run_sync_test +): + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + def test_code(): + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("server", [server_starwars], indirect=True) +@pytest.mark.parametrize("subscription_str", [starwars_subscription_str]) +@pytest.mark.parametrize( + "client_params", + [ + {"schema": StarWarsSchema}, + {"introspection": StarWarsIntrospection}, + {"schema": StarWarsTypeDef}, + ], +) +async def test_async_aiohttp_client_validation( + event_loop, server, subscription_str, client_params +): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport, **client_params) + + async with client as session: + + variable_values = {"ep": "JEDI"} + + subscription = gql(subscription_str) + + expected = [] + + async for result in session.subscribe( + subscription, variable_values=variable_values, parse_result=False + ): + + review = result["reviewAdded"] + expected.append(review) + + assert "stars" in review + assert "commentary" in review + assert "episode" in review + + assert expected[0] == starwars_expected_one + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_subscribe_on_closing_transport(event_loop, server, subscription_str): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + count = 1 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + session.transport.websocket._writer._closing = True + + with pytest.raises(ConnectionResetError) as e: + async for _ in session.subscribe(subscription): + pass + + assert e.value.args[0] == "Cannot write to closing transport" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", 
[server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_subscribe_on_null_transport(event_loop, server, subscription_str): + + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + + transport = AIOHTTPWebsocketsTransport(url=url) + + client = Client(transport=transport) + count = 1 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + + session.transport.websocket = None + with pytest.raises(TransportClosed) as e: + async for _ in session.subscribe(subscription): + pass + + assert e.value.args[0] == "WebSocket connection is closed" From 9af51464e28f4cc38a853c466f656be633c8c7f1 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 28 Jul 2024 19:33:14 +0200 Subject: [PATCH 156/239] Remove Python 3.7 support - cleaning old code (#495) --- README.md | 2 +- docs/gql-cli/intro.rst | 2 +- docs/intro.rst | 2 +- gql/client.py | 17 +++-------------- gql/transport/aiohttp_websockets.py | 17 +---------------- gql/transport/websockets_base.py | 5 ----- tests/custom_scalars/test_datetime.py | 19 ------------------- ...iohttp_websocket_graphqlws_subscription.py | 3 --- tests/test_aiohttp_websocket_subscription.py | 3 --- tests/test_graphqlws_subscription.py | 3 --- tests/test_phoenix_channel_subscription.py | 11 ----------- tests/test_websocket_subscription.py | 3 --- tox.ini | 5 ++--- 13 files changed, 9 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index a100e32d..cbc53af6 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # GQL -This is a GraphQL client for Python 3.7+. +This is a GraphQL client for Python 3.8+. Plays nicely with `graphene`, `graphql-core`, `graphql-js` and any other GraphQL implementation compatible with the spec. GQL architecture is inspired by `React-Relay` and `Apollo-Client`. 
diff --git a/docs/gql-cli/intro.rst b/docs/gql-cli/intro.rst index 925958ee..f88b60a1 100644 --- a/docs/gql-cli/intro.rst +++ b/docs/gql-cli/intro.rst @@ -3,7 +3,7 @@ gql-cli ======= -GQL provides a python 3.7+ script, called `gql-cli` which allows you to execute +GQL provides a python script, called `gql-cli` which allows you to execute GraphQL queries directly from the terminal. This script supports http(s) or websockets protocols. diff --git a/docs/intro.rst b/docs/intro.rst index 21de16bd..3151755d 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -1,7 +1,7 @@ Introduction ============ -`GQL 3`_ is a `GraphQL`_ Client for Python 3.7+ which plays nicely with other +`GQL 3`_ is a `GraphQL`_ Client for Python 3.8+ which plays nicely with other graphql implementations compatible with the spec. Under the hood, it uses `GraphQL-core`_ which is a Python port of `GraphQL.js`_, diff --git a/gql/client.py b/gql/client.py index a9a2c7e2..5fd038d0 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,6 +1,5 @@ import asyncio import logging -import sys import time import warnings from concurrent.futures import Future @@ -13,6 +12,7 @@ Dict, Generator, List, + Literal, Optional, Tuple, TypeVar, @@ -44,17 +44,6 @@ from .utilities import serialize_variable_values from .utils import str_first_element -""" -Load the appropriate instance of the Literal type -Note: we cannot use try: except ImportError because of the following mypy issue: -https://github.com/python/mypy/issues/8520 -""" -if sys.version_info[:2] >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal # pragma: no cover - - log = logging.getLogger(__name__) @@ -1368,8 +1357,8 @@ async def _subscribe( **kwargs, ) - # Keep a reference to the inner generator to allow the user to call aclose() - # before a break if python version is too old (pypy3 py 3.6.1) + # Keep a reference to the inner generator + # This is only used for the tests to 
simulate a KeyboardInterrupt event self._generator = inner_generator try: diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index ff310a82..e7fb6815 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,7 +1,6 @@ import asyncio import json import logging -import sys import warnings from contextlib import suppress from ssl import SSLContext @@ -10,6 +9,7 @@ AsyncGenerator, Collection, Dict, + Literal, Mapping, Optional, Tuple, @@ -32,16 +32,6 @@ TransportServerError, ) -""" -Load the appropriate instance of the Literal type -Note: we cannot use try: except ImportError because of the following mypy issue: -https://github.com/python/mypy/issues/8520 -""" -if sys.version_info[:2] >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal # pragma: no cover - log = logging.getLogger("gql.transport.aiohttp_websockets") ParsedAnswer = Tuple[str, Optional[ExecutionResult]] @@ -1124,11 +1114,6 @@ async def execute( async for result in generator: first_result = result - - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - await generator.aclose() - break if first_result is None: diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index 5c7713e9..accca275 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -431,11 +431,6 @@ async def execute( async for result in generator: first_result = result - - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - await generator.aclose() - break if first_result is None: diff --git a/tests/custom_scalars/test_datetime.py b/tests/custom_scalars/test_datetime.py index b3e717c5..5a36669c 100644 --- a/tests/custom_scalars/test_datetime.py +++ 
b/tests/custom_scalars/test_datetime.py @@ -1,7 +1,6 @@ from datetime import datetime, timedelta from typing import Any, Dict, Optional -import pytest from graphql.error import GraphQLError from graphql.language import ValueNode from graphql.pyutils import inspect @@ -110,9 +109,6 @@ def resolve_seconds(root, _info, interval): schema = GraphQLSchema(query=queryType) -@pytest.mark.skipif( - not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" -) def test_shift_days(): client = Client(schema=schema, parse_results=True, serialize_variables=True) @@ -132,9 +128,6 @@ def test_shift_days(): assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") -@pytest.mark.skipif( - not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" -) def test_shift_days_serialized_manually_in_query(): client = Client(schema=schema) @@ -152,9 +145,6 @@ def test_shift_days_serialized_manually_in_query(): assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") -@pytest.mark.skipif( - not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" -) def test_shift_days_serialized_manually_in_variables(): client = Client(schema=schema, parse_results=True) @@ -172,9 +162,6 @@ def test_shift_days_serialized_manually_in_variables(): assert result["shiftDays"] == datetime.fromisoformat("2021-11-17T11:58:13.461161") -@pytest.mark.skipif( - not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" -) def test_latest(): client = Client(schema=schema, parse_results=True) @@ -197,9 +184,6 @@ def test_latest(): assert result["latest"] == in_five_days -@pytest.mark.skipif( - not hasattr(datetime, "fromisoformat"), reason="fromisoformat is new in Python 3.7+" -) def test_seconds(): client = Client(schema=schema) @@ -221,9 +205,6 @@ def test_seconds(): assert result["seconds"] == 432000 -@pytest.mark.skipif( - not hasattr(datetime, "fromisoformat"), 
reason="fromisoformat is new in Python 3.7+" -) def test_seconds_omit_optional_start_argument(): client = Client(schema=schema) diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index e5db7ca1..86ff96ab 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -268,9 +268,6 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( assert number == count if count <= 5: - # Note: the following line is only necessary for pypy3 v3.6.1 - if sys.version_info < (3, 7): - await session._generator.aclose() break count -= 1 diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 3ebf4dbc..4bc6ad3c 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -258,9 +258,6 @@ async def test_aiohttp_websocket_subscription_break( assert number == count if count <= 5: - # Note: the following line is only necessary for pypy3 v3.6.1 - if sys.version_info < (3, 7): - await session._generator.aclose() break count -= 1 diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index cb705368..deeae395 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -268,9 +268,6 @@ async def test_graphqlws_subscription_break( assert number == count if count <= 5: - # Note: the following line is only necessary for pypy3 v3.6.1 - if sys.version_info < (3, 7): - await session._generator.aclose() break count -= 1 diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 6367945d..34564c6d 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -1,6 +1,5 @@ import asyncio import json -import sys import pytest from parse import search @@ -208,11 +207,6 @@ async def 
test_phoenix_channel_subscription( assert number == count if number == end_count: - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - # In more recent versions, 'break' will trigger __aexit__. - if sys.version_info < (3, 7): - await session._generator.aclose() print("break") break @@ -390,11 +384,6 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): assert heartbeat_count == i if heartbeat_count == 5: - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - # In more recent versions, 'break' will trigger __aexit__. - if sys.version_info < (3, 7): - await session._generator.aclose() break i += 1 diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 4419783b..38307349 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -189,9 +189,6 @@ async def test_websocket_subscription_break( assert number == count if count <= 5: - # Note: the following line is only necessary for pypy3 v3.6.1 - if sys.version_info < (3, 7): - await session._generator.aclose() break count -= 1 diff --git a/tox.ini b/tox.ini index e4794be5..7a639572 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,10 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{37,38,39,310,311,312,py3} + py{38,39,310,311,312,py3} [gh-actions] python = - 3.7: py37 3.8: py38 3.9: py39 3.10: py310 @@ -29,7 +28,7 @@ deps = -e.[test] commands = pip install -U setuptools ; run "tox -- tests -s" to show output for debugging - py{37,39,310,311,312,py3}: pytest {posargs:tests} + py{39,310,311,312,py3}: pytest {posargs:tests} py{38}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] From 2ee2583568fab9a5bb952e043bbacb534efeb74f Mon Sep 17 00:00:00 2001 From: Rahul Patel Date: Thu, 1 Aug 2024 16:47:52 +0530 
Subject: [PATCH 157/239] =?UTF-8?q?Update=20annotation=20for=20client.exec?= =?UTF-8?q?ute=5Fbatch,=20get=5Fexecution=5Fresult=20argu=E2=80=A6=20(#483?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hanusz Leszek --- gql/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gql/client.py b/gql/client.py index 5fd038d0..e1b168a7 100644 --- a/gql/client.py +++ b/gql/client.py @@ -251,7 +251,7 @@ def execute_batch_sync( *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - get_execution_result: Literal[False], + get_execution_result: Literal[False] = ..., **kwargs, ) -> List[Dict[str, Any]]: ... # pragma: no cover @@ -487,7 +487,7 @@ def execute_batch( *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - get_execution_result: Literal[False], + get_execution_result: Literal[False] = ..., **kwargs, ) -> List[Dict[str, Any]]: ... # pragma: no cover @@ -1096,7 +1096,7 @@ def execute_batch( *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - get_execution_result: Literal[False], + get_execution_result: Literal[False] = ..., **kwargs, ) -> List[Dict[str, Any]]: ... 
# pragma: no cover From 1c657d87013f74ff36fdd5ae042e5b0b5f75065f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 1 Aug 2024 14:06:34 +0200 Subject: [PATCH 158/239] Fix ssl=None not supported on recent versions of aiohttp (#496) --- gql/transport/aiohttp_websockets.py | 35 +++++++++++++++++------------ 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index e7fb6815..18699b5e 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -845,7 +845,27 @@ async def connect(self) -> None: if self.websocket is None and not self._connecting: self._connecting = True - connect_args: Dict[str, Any] = {} + connect_args: Dict[str, Any] = { + "url": self.url, + "headers": self.headers, + "auth": self.auth, + "heartbeat": self.heartbeat, + "origin": self.origin, + "params": self.params, + "protocols": self.supported_subprotocols, + "proxy": self.proxy, + "proxy_auth": self.proxy_auth, + "proxy_headers": self.proxy_headers, + "timeout": self.websocket_close_timeout, + "receive_timeout": self.receive_timeout, + } + + if self.ssl is not None: + connect_args.update( + { + "ssl": self.ssl, + } + ) # Adding custom parameters passed from init if self.connect_args: @@ -857,19 +877,6 @@ async def connect(self) -> None: # Set the _connecting flag to False after in all cases self.websocket = await asyncio.wait_for( self.session.ws_connect( - url=self.url, - headers=self.headers, - auth=self.auth, - heartbeat=self.heartbeat, - origin=self.origin, - params=self.params, - protocols=self.supported_subprotocols, - proxy=self.proxy, - proxy_auth=self.proxy_auth, - proxy_headers=self.proxy_headers, - timeout=self.websocket_close_timeout, - receive_timeout=self.receive_timeout, - ssl=self.ssl, **connect_args, ), self.connect_timeout, From 79dbe457bdf6a2c915e2cbfa182dc88625bc405b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 27 Oct 2024 14:39:21 +0100 Subject: 
[PATCH 159/239] Adding MIT License classifier (#498) --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 8828f8f0..a5efffa3 100644 --- a/setup.py +++ b/setup.py @@ -84,6 +84,7 @@ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", + "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.8", From dbafcd064966cec172e64a3e4b642703cad51b44 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 28 Oct 2024 12:08:44 +0100 Subject: [PATCH 160/239] Delete gql-checker folder (#502) --- MANIFEST.in | 1 - gql-checker/.gitignore | 9 - gql-checker/.travis.yml | 12 - gql-checker/LICENSE | 21 -- gql-checker/MANIFEST.in | 4 - gql-checker/README.rst | 29 -- gql-checker/gql_checker/__about__.py | 18 - gql-checker/gql_checker/__init__.py | 118 ------- gql-checker/gql_checker/flake8_linter.py | 50 --- gql-checker/gql_checker/pylama_linter.py | 36 -- gql-checker/gql_checker/stdlib_list.py | 330 ------------------- gql-checker/setup.cfg | 2 - gql-checker/setup.py | 64 ---- gql-checker/tests/__init__.py | 0 gql-checker/tests/introspection_schema.json | 1 - gql-checker/tests/test_cases/bad_query.py | 7 - gql-checker/tests/test_cases/noqa.py | 3 - gql-checker/tests/test_cases/syntax_error.py | 3 - gql-checker/tests/test_cases/validation.py | 78 ----- gql-checker/tests/test_flake8_linter.py | 56 ---- gql-checker/tests/test_pylama_linter.py | 50 --- gql-checker/tests/utils.py | 19 -- gql-checker/tox.ini | 38 --- 23 files changed, 949 deletions(-) delete mode 100644 gql-checker/.gitignore delete mode 100644 gql-checker/.travis.yml delete mode 100644 gql-checker/LICENSE delete mode 100644 gql-checker/MANIFEST.in delete mode 100644 gql-checker/README.rst delete mode 100644 gql-checker/gql_checker/__about__.py delete mode 100644 gql-checker/gql_checker/__init__.py delete mode 
100644 gql-checker/gql_checker/flake8_linter.py delete mode 100644 gql-checker/gql_checker/pylama_linter.py delete mode 100644 gql-checker/gql_checker/stdlib_list.py delete mode 100644 gql-checker/setup.cfg delete mode 100644 gql-checker/setup.py delete mode 100644 gql-checker/tests/__init__.py delete mode 100644 gql-checker/tests/introspection_schema.json delete mode 100644 gql-checker/tests/test_cases/bad_query.py delete mode 100644 gql-checker/tests/test_cases/noqa.py delete mode 100644 gql-checker/tests/test_cases/syntax_error.py delete mode 100644 gql-checker/tests/test_cases/validation.py delete mode 100644 gql-checker/tests/test_flake8_linter.py delete mode 100644 gql-checker/tests/test_pylama_linter.py delete mode 100644 gql-checker/tests/utils.py delete mode 100644 gql-checker/tox.ini diff --git a/MANIFEST.in b/MANIFEST.in index c0f653ab..3df67dcf 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -18,6 +18,5 @@ recursive-include docs *.txt *.rst conf.py Makefile make.bat *.jpg *.png *.gif recursive-include docs/code_examples *.py prune docs/_build -prune gql-checker global-exclude *.py[co] __pycache__ diff --git a/gql-checker/.gitignore b/gql-checker/.gitignore deleted file mode 100644 index fe0fe6cc..00000000 --- a/gql-checker/.gitignore +++ /dev/null @@ -1,9 +0,0 @@ -*.pyc -*.pyo -__pycache__ -*.egg-info -*~ -.coverage -.tox/ -build/ -dist/ diff --git a/gql-checker/.travis.yml b/gql-checker/.travis.yml deleted file mode 100644 index eb155b3b..00000000 --- a/gql-checker/.travis.yml +++ /dev/null @@ -1,12 +0,0 @@ -language: python -addons: - apt: - sources: - - deadsnakes - packages: - - python3.5 -install: - - pip install tox -script: - - tox -sudo: false diff --git a/gql-checker/LICENSE b/gql-checker/LICENSE deleted file mode 100644 index 141776c3..00000000 --- a/gql-checker/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2016 GraphQL Python - -Permission is hereby granted, free of charge, to any person obtaining a copy -of 
this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/gql-checker/MANIFEST.in b/gql-checker/MANIFEST.in deleted file mode 100644 index 93eaa028..00000000 --- a/gql-checker/MANIFEST.in +++ /dev/null @@ -1,4 +0,0 @@ -include LICENSE -include README.md -recursive-include tests * -recursive-exclude tests *.py[co] diff --git a/gql-checker/README.rst b/gql-checker/README.rst deleted file mode 100644 index 1fd9c42b..00000000 --- a/gql-checker/README.rst +++ /dev/null @@ -1,29 +0,0 @@ -gql-checker -=========== - -|Build Status| - -A `flake8 `__ and -`Pylama `__ plugin that checks the -all the static gql calls given a GraphQL schema. - -It will not check anything else about the gql calls. Merely that the -GraphQL syntax is correct and it validates against the provided schema. 
- -Warnings --------- - -This package adds 3 new flake8 warnings - -- ``GQL100``: The gql query is doesn't match GraphQL syntax -- ``GQL101``: The gql query have valid syntax but doesn't validate against provided schema - -Configuration -------------- - -You will want to set the ``gql-introspection-schema`` option to a -file with the json introspection of the schema. - - -.. |Build Status| image:: https://travis-ci.org/graphql-python/gql-checker.png?branch=master - :target: https://travis-ci.org/graphql-python/gql-checker diff --git a/gql-checker/gql_checker/__about__.py b/gql-checker/gql_checker/__about__.py deleted file mode 100644 index c2f25195..00000000 --- a/gql-checker/gql_checker/__about__.py +++ /dev/null @@ -1,18 +0,0 @@ -__all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", -] - -__title__ = "gql-checker" -__summary__ = ( - "Flake8 and pylama plugin that checks gql GraphQL calls." 
-) -__uri__ = "https://github.com/graphql-python/gql-checker" - -__version__ = "0.1" - -__author__ = "Syrus Akbary" -__email__ = "me@syrusakbary.com" - -__license__ = "MIT" -__copyright__ = "Copyright 2016 %s" % __author__ diff --git a/gql-checker/gql_checker/__init__.py b/gql-checker/gql_checker/__init__.py deleted file mode 100644 index 8b8f09bd..00000000 --- a/gql-checker/gql_checker/__init__.py +++ /dev/null @@ -1,118 +0,0 @@ -import ast -import json - -import pycodestyle - -from gql_checker.__about__ import ( - __author__, __copyright__, __email__, __license__, __summary__, __title__, - __uri__, __version__ -) -from gql_checker.stdlib_list import STDLIB_NAMES -from graphql import Source, validate, parse, build_client_schema - - -__all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", -] - -GQL_SYNTAX_ERROR = 'GQL100' -GQL_VALIDATION_ERROR = 'GQL101' - -class ImportVisitor(ast.NodeVisitor): - """ - This class visits all the gql calls. 
- """ - - def __init__(self, filename, options): - self.filename = filename - self.options = options or {} - self.calls = [] - - def visit_Call(self, node): # noqa - if node.func.id == 'gql': - self.calls.append(node) - - def node_query(self, node): - """ - Return the query for the gql call node - """ - - if isinstance(node, ast.Call): - assert node.args - arg = node.args[0] - if not isinstance(arg, ast.Str): - return - else: - raise TypeError(type(node)) - - return arg.s - - -class ImportOrderChecker(object): - visitor_class = ImportVisitor - options = None - - def __init__(self, filename, tree): - self.tree = tree - self.filename = filename - self.lines = None - - def load_file(self): - if self.filename in ("stdin", "-", None): - self.filename = "stdin" - self.lines = pycodestyle.stdin_get_value().splitlines(True) - else: - self.lines = pycodestyle.readlines(self.filename) - - if not self.tree: - self.tree = ast.parse("".join(self.lines)) - - def get_schema(self): - gql_introspection_schema = self.options.get('gql_introspection_schema') - if gql_introspection_schema: - try: - with open(gql_introspection_schema) as data_file: - introspection_schema = json.load(data_file) - return build_client_schema(introspection_schema) - except IOError as e: - raise Exception(f"Cannot find the provided introspection schema. 
{e}") - - schema = self.options.get('schema') - assert schema, 'Need to provide schema' - - def validation_errors(self, ast): - return validate(self.get_schema(), ast) - - def error(self, node, code, message): - raise NotImplemented() - - def check_gql(self): - if not self.tree or not self.lines: - self.load_file() - - visitor = self.visitor_class(self.filename, self.options) - visitor.visit(self.tree) - - for node in visitor.calls: - # Lines with the noqa flag are ignored entirely - if pycodestyle.noqa(self.lines[node.lineno - 1]): - continue - - query = visitor.node_query(node) - if not query: - continue - - try: - source = Source(query, 'gql query') - ast = parse(source) - except Exception as e: - message = str(e) - yield self.error(node, GQL_SYNTAX_ERROR, message) - continue - - validation_errors = self.validation_errors(ast) - if validation_errors: - for error in validation_errors: - message = str(error) - yield self.error(node, GQL_VALIDATION_ERROR, message) diff --git a/gql-checker/gql_checker/flake8_linter.py b/gql-checker/gql_checker/flake8_linter.py deleted file mode 100644 index 009421c1..00000000 --- a/gql-checker/gql_checker/flake8_linter.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import absolute_import - -import gql_checker -from gql_checker import ImportOrderChecker - - -class Linter(ImportOrderChecker): - name = "gql" - version = gql_checker.__version__ - - def __init__(self, tree, filename): - super(Linter, self).__init__(filename, tree) - - @classmethod - def add_options(cls, parser): - # List of application import names. They go last. - parser.add_option( - "--gql-introspection-schema", - metavar="FILE", - help="Import names to consider as application specific" - ) - parser.add_option( - "--gql-typedef-schema", - default='', - action="store", - type="string", - help=("Style to follow. 
Available: " - "cryptography, google, smarkets, pep8") - ) - parser.config_options.append("gql-introspection-schema") - parser.config_options.append("gql-typedef-schema") - - @classmethod - def parse_options(cls, options): - optdict = {} - - optdict = dict( - gql_introspection_schema=options.gql_introspection_schema, - gql_typedef_schema=options.gql_typedef_schema, - ) - - cls.options = optdict - - def error(self, node, code, message): - lineno, col_offset = node.lineno, node.col_offset - return lineno, col_offset, f'{code} {message}', Linter - - def run(self): - for error in self.check_gql(): - yield error diff --git a/gql-checker/gql_checker/pylama_linter.py b/gql-checker/gql_checker/pylama_linter.py deleted file mode 100644 index 994e6e19..00000000 --- a/gql-checker/gql_checker/pylama_linter.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import absolute_import - -from pylama.lint import Linter as BaseLinter - -import gql_checker -from gql_checker import ImportOrderChecker - - -class Linter(ImportOrderChecker, BaseLinter): - name = "gql" - version = gql_checker.__version__ - - def __init__(self): - super(Linter, self).__init__(None, None) - - def allow(self, path): - return path.endswith(".py") - - def error(self, node, code, message): - lineno, col_offset = node.lineno, node.col_offset - return { - "lnum": lineno, - "col": col_offset, - "text": message, - "type": code - } - - def run(self, path, **meta): - self.filename = path - self.tree = None - self.options = dict( - {'schema': ''}, - **meta) - - for error in self.check_gql(): - yield error diff --git a/gql-checker/gql_checker/stdlib_list.py b/gql-checker/gql_checker/stdlib_list.py deleted file mode 100644 index 4552d422..00000000 --- a/gql-checker/gql_checker/stdlib_list.py +++ /dev/null @@ -1,330 +0,0 @@ -STDLIB_NAMES = set(( - "AL", - "BaseHTTPServer", - "Bastion", - "Binary", - "Boolean", - "CGIHTTPServer", - "ColorPicker", - "ConfigParser", - "Cookie", - "DEVICE", - "DocXMLRPCServer", - 
"EasyDialogs", - "FL", - "FrameWork", - "GL", - "HTMLParser", - "MacOS", - "Mapping", - "MimeWriter", - "MiniAEFrame", - "Numeric", - "Queue", - "SUNAUDIODEV", - "ScrolledText", - "Sequence", - "Set", - "SimpleHTTPServer", - "SimpleXMLRPCServer", - "SocketServer", - "StringIO", - "Text", - "Tix", - "Tkinter", - "UserDict", - "UserList", - "UserString", - "__builtin__", - "__future__", - "__main__", - "_dummy_thread", - "_thread", - "abc", - "aepack", - "aetools", - "aetypes", - "aifc", - "al", - "anydbm", - "argparse", - "array", - "ast", - "asynchat", - "asyncio", - "asyncore", - "atexit", - "audioop", - "autoGIL", - "base64", - "bdb", - "binascii", - "binhex", - "bisect", - "bsddb", - "builtins", - "bz2", - "cPickle", - "cProfile", - "cStringIO", - "calendar", - "cd", - "cgi", - "cgitb", - "chunk", - "cmath", - "cmd", - "code", - "codecs", - "codeop", - "collections", - "collections.abc", - "colorsys", - "commands", - "compileall", - "concurrent.futures", - "configparser", - "contextlib", - "cookielib", - "copy", - "copy_reg", - "copyreg", - "crypt", - "csv", - "ctypes", - "curses", - "curses.ascii", - "curses.panel", - "curses.textpad", - "curses.wrapper", - "datetime", - "dbhash", - "dbm", - "decimal", - "difflib", - "dircache", - "dis", - "distutils", - "dl", - "doctest", - "dumbdbm", - "dummy_thread", - "dummy_threading", - "email", - "ensurepip", - "enum", - "errno", - "faulthandler", - "fcntl", - "filecmp", - "fileinput", - "findertools", - "fl", - "flp", - "fm", - "fnmatch", - "formatter", - "fpectl", - "fpformat", - "fractions", - "ftplib", - "functools", - "future_builtins", - "gc", - "gdbm", - "gensuitemodule", - "getopt", - "getpass", - "gettext", - "gl", - "glob", - "grp", - "gzip", - "hashlib", - "heapq", - "hmac", - "hotshot", - "html", - "html.entities", - "html.parser", - "htmlentitydefs", - "htmllib", - "http", - "http.client", - "http.cookiejar", - "http.cookies", - "http.server", - "httplib", - "ic", - "imageop", - "imaplib", - "imgfile", - 
"imghdr", - "imp", - "importlib", - "imputil", - "inspect", - "io", - "ipaddress", - "itertools", - "jpeg", - "json", - "keyword", - "linecache", - "locale", - "logging", - "logging.config", - "logging.handlers", - "lzma", - "macostools", - "macpath", - "macurl2path", - "mailbox", - "mailcap", - "marshal", - "math", - "md5", - "mhlib", - "mimetools", - "mimetypes", - "mimify", - "mmap", - "modulefinder", - "msilib", - "multifile", - "multiprocessing", - "mutex", - "netrc", - "new", - "nis", - "nntplib", - "nturl2path", - "numbers", - "operator", - "optparse", - "os", - "os.path", - "ossaudiodev", - "parser", - "pathlib", - "pdb", - "pickle", - "pickletools", - "pipes", - "pkgutil", - "platform", - "plistlib", - "popen2", - "poplib", - "posix", - "posixfile", - "posixpath", - "pprint", - "profile", - "pstats", - "pty", - "pwd", - "py_compile", - "pyclbr", - "pydoc", - "queue", - "quopri", - "random", - "re", - "readline", - "repr", - "reprlib", - "resource", - "rexec", - "rfc822", - "rlcompleter", - "robotparser", - "runpy", - "sched", - "select", - "sets", - "sgmllib", - "sha", - "shelve", - "shlex", - "shutil", - "signal", - "site", - "smtpd", - "smtplib", - "sndhdr", - "socket", - "socketserver", - "spwd", - "sqlite3", - "ssl", - "stat", - "statistics", - "statvfs", - "string", - "stringprep", - "struct", - "subprocess", - "sunau", - "sunaudiodev", - "symbol", - "symtable", - "sys", - "sysconfig", - "syslog", - "tabnanny", - "tarfile", - "telnetlib", - "tempfile", - "termios", - "test", - "test.support", - "test.test_support", - "textwrap", - "thread", - "threading", - "time", - "timeit", - "tkinter", - "tkinter.scrolledtext", - "tkinter.tix", - "tkinter.ttk", - "token", - "tokenize", - "trace", - "traceback", - "tracemalloc", - "ttk", - "tty", - "turtle", - "types", - "typing", - "unicodedata", - "unittest", - "unittest.mock", - "urllib", - "urllib.error", - "urllib.parse", - "urllib.request", - "urllib.response", - "urllib.robotparser", - "urllib2", - 
"urlparse", - "user", - "uu", - "uuid", - "venv", - "warnings", - "wave", - "weakref", - "webbrowser", - "whichdb", - "winsound", - "wsgiref", - "xdrlib", - "xml", - "xmlrpclib", - "zipfile", - "zipimport", - "zlib", -)) diff --git a/gql-checker/setup.cfg b/gql-checker/setup.cfg deleted file mode 100644 index 5e409001..00000000 --- a/gql-checker/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[wheel] -universal = 1 diff --git a/gql-checker/setup.py b/gql-checker/setup.py deleted file mode 100644 index 8ec1bf74..00000000 --- a/gql-checker/setup.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -from setuptools import setup, find_packages - - -base_dir = os.path.dirname(__file__) - -about = {} -with open(os.path.join(base_dir, "gql_checker", "__about__.py")) as f: - exec(f.read(), about) - -with open(os.path.join(base_dir, "README.rst")) as f: - long_description = f.read() - - -setup( - name=about["__title__"], - version=about["__version__"], - - description=about["__summary__"], - long_description=long_description, - license=about["__license__"], - url=about["__uri__"], - author=about["__author__"], - author_email=about["__email__"], - - packages=find_packages(exclude=["tests", "tests.*"]), - zip_safe=False, - - install_requires=[ - "pycodestyle" - ], - - tests_require=[ - "pytest", - "flake8", - "pycodestyle", - "pylama" - ], - - py_modules=['gql_checker'], - entry_points={ - 'flake8.extension': [ - 'GQL = gql_checker.flake8_linter:Linter', - ], - 'pylama.linter': [ - 'gql_checker = gql_checker.pylama_linter:Linter' - ] - }, - - classifiers=[ - "Intended Audience :: Developers", - "Development Status :: 4 - Beta", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - ( - "License :: OSI Approved :: " - "GNU Lesser General Public License v3 (LGPLv3)" - ), - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Software Development :: Quality Assurance", - "Operating System :: OS 
Independent" - ] -) diff --git a/gql-checker/tests/__init__.py b/gql-checker/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/gql-checker/tests/introspection_schema.json b/gql-checker/tests/introspection_schema.json deleted file mode 100644 index b4f5e0b4..00000000 --- a/gql-checker/tests/introspection_schema.json +++ /dev/null @@ -1 +0,0 @@ -{"__schema": {"queryType": {"name": "Query"}, "mutationType": null, "subscriptionType": null, "types": [{"kind": "OBJECT", "name": "Query", "description": null, "fields": [{"name": "droid", "description": null, "args": [{"name": "id", "description": "id of the droid", "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "defaultValue": null}], "type": {"kind": "OBJECT", "name": "Droid", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "hero", "description": null, "args": [{"name": "episode", "description": "If omitted, returns the hero of the whole saga. If provided, returns the hero of that particular episode.", "type": {"kind": "ENUM", "name": "Episode", "ofType": null}, "defaultValue": null}], "type": {"kind": "INTERFACE", "name": "Character", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "human", "description": null, "args": [{"name": "id", "description": "id of the human", "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "defaultValue": null}], "type": {"kind": "OBJECT", "name": "Human", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [], "enumValues": null, "possibleTypes": null}, {"kind": "SCALAR", "name": "String", "description": "The `String` scalar type represents textual data, represented as UTF-8 character sequences. 
The String type is most often used by GraphQL to represent free-form human-readable text.", "fields": null, "inputFields": null, "interfaces": null, "enumValues": null, "possibleTypes": null}, {"kind": "OBJECT", "name": "Droid", "description": "A mechanical creature in the Star Wars universe.", "fields": [{"name": "appearsIn", "description": "Which movies they appear in.", "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "ENUM", "name": "Episode", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "friends", "description": "The friends of the droid, or an empty list if they have none.", "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "INTERFACE", "name": "Character", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "id", "description": "The id of the droid.", "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "name", "description": "The name of the droid.", "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "primaryFunction", "description": "The primary function of the droid.", "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [{"kind": "INTERFACE", "name": "Character", "ofType": null}], "enumValues": null, "possibleTypes": null}, {"kind": "INTERFACE", "name": "Character", "description": "A character in the Star Wars Trilogy", "fields": [{"name": "appearsIn", "description": "Which movies they appear in.", "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "ENUM", "name": "Episode", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "friends", "description": "The friends of the character, or an empty list 
if they have none.", "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "INTERFACE", "name": "Character", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "id", "description": "The id of the character.", "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "name", "description": "The name of the character.", "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": null, "enumValues": null, "possibleTypes": [{"kind": "OBJECT", "name": "Droid", "ofType": null}, {"kind": "OBJECT", "name": "Human", "ofType": null}]}, {"kind": "ENUM", "name": "Episode", "description": "One of the films in the Star Wars Trilogy", "fields": null, "inputFields": null, "interfaces": null, "enumValues": [{"name": "EMPIRE", "description": "Released in 1980.", "isDeprecated": false, "deprecationReason": null}, {"name": "JEDI", "description": "Released in 1983.", "isDeprecated": false, "deprecationReason": null}, {"name": "NEWHOPE", "description": "Released in 1977.", "isDeprecated": false, "deprecationReason": null}], "possibleTypes": null}, {"kind": "OBJECT", "name": "Human", "description": "A humanoid creature in the Star Wars universe.", "fields": [{"name": "appearsIn", "description": "Which movies they appear in.", "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "ENUM", "name": "Episode", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "friends", "description": "The friends of the human, or an empty list if they have none.", "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "INTERFACE", "name": "Character", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "homePlanet", "description": "The home planet of the human, or 
null if unknown.", "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "id", "description": "The id of the human.", "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "name", "description": "The name of the human.", "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [{"kind": "INTERFACE", "name": "Character", "ofType": null}], "enumValues": null, "possibleTypes": null}, {"kind": "OBJECT", "name": "__Schema", "description": "A GraphQL Schema defines the capabilities of a GraphQL server. It exposes all available types and directives on the server, as well as the entry points for query, mutation and subscription operations.", "fields": [{"name": "types", "description": "A list of all types supported by this server.", "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Type", "ofType": null}}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "queryType", "description": "The type that query operations will be rooted at.", "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Type", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "mutationType", "description": "If this server supports mutation, the type that mutation operations will be rooted at.", "args": [], "type": {"kind": "OBJECT", "name": "__Type", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "subscriptionType", "description": "If this server support subscription, the type that subscription operations will be rooted at.", "args": [], "type": {"kind": "OBJECT", 
"name": "__Type", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "directives", "description": "A list of all directives supported by this server.", "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Directive", "ofType": null}}}}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [], "enumValues": null, "possibleTypes": null}, {"kind": "OBJECT", "name": "__Type", "description": "The fundamental unit of any GraphQL Schema is the type. There are many kinds of types in GraphQL as represented by the `__TypeKind` enum.\n\nDepending on the kind of a type, certain fields describe information about that type. Scalar types provide no information beyond a name and description, while Enum types provide their values. Object and Interface types provide the fields they describe. Abstract types, Union and Interface, provide the Object types possible at runtime. 
List and NonNull types compose other types.", "fields": [{"name": "kind", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "ENUM", "name": "__TypeKind", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "name", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "description", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "fields", "description": null, "args": [{"name": "includeDeprecated", "description": null, "type": {"kind": "SCALAR", "name": "Boolean", "ofType": null}, "defaultValue": "false"}], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Field", "ofType": null}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "interfaces", "description": null, "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Type", "ofType": null}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "possibleTypes", "description": null, "args": [], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Type", "ofType": null}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "enumValues", "description": null, "args": [{"name": "includeDeprecated", "description": null, "type": {"kind": "SCALAR", "name": "Boolean", "ofType": null}, "defaultValue": "false"}], "type": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__EnumValue", "ofType": null}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "inputFields", "description": null, "args": [], "type": {"kind": "LIST", "name": 
null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__InputValue", "ofType": null}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "ofType", "description": null, "args": [], "type": {"kind": "OBJECT", "name": "__Type", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [], "enumValues": null, "possibleTypes": null}, {"kind": "ENUM", "name": "__TypeKind", "description": "An enum describing what kind of type a given `__Type` is", "fields": null, "inputFields": null, "interfaces": null, "enumValues": [{"name": "SCALAR", "description": "Indicates this type is a scalar.", "isDeprecated": false, "deprecationReason": null}, {"name": "OBJECT", "description": "Indicates this type is an object. `fields` and `interfaces` are valid fields.", "isDeprecated": false, "deprecationReason": null}, {"name": "INTERFACE", "description": "Indicates this type is an interface. `fields` and `possibleTypes` are valid fields.", "isDeprecated": false, "deprecationReason": null}, {"name": "UNION", "description": "Indicates this type is a union. `possibleTypes` is a valid field.", "isDeprecated": false, "deprecationReason": null}, {"name": "ENUM", "description": "Indicates this type is an enum. `enumValues` is a valid field.", "isDeprecated": false, "deprecationReason": null}, {"name": "INPUT_OBJECT", "description": "Indicates this type is an input object. `inputFields` is a valid field.", "isDeprecated": false, "deprecationReason": null}, {"name": "LIST", "description": "Indicates this type is a list. `ofType` is a valid field.", "isDeprecated": false, "deprecationReason": null}, {"name": "NON_NULL", "description": "Indicates this type is a non-null. 
`ofType` is a valid field.", "isDeprecated": false, "deprecationReason": null}], "possibleTypes": null}, {"kind": "SCALAR", "name": "Boolean", "description": "The `Boolean` scalar type represents `true` or `false`.", "fields": null, "inputFields": null, "interfaces": null, "enumValues": null, "possibleTypes": null}, {"kind": "OBJECT", "name": "__Field", "description": "Object and Interface types are described by a list of Fields, each of which has a name, potentially a list of arguments, and a return type.", "fields": [{"name": "name", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "description", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "args", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__InputValue", "ofType": null}}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "type", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Type", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "isDeprecated", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "deprecationReason", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [], "enumValues": null, "possibleTypes": null}, {"kind": "OBJECT", "name": "__InputValue", "description": "Arguments provided to Fields or Directives and 
the input fields of an InputObject are represented as Input Values which describe their type and optionally a default value.", "fields": [{"name": "name", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "description", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "type", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__Type", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "defaultValue", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [], "enumValues": null, "possibleTypes": null}, {"kind": "OBJECT", "name": "__EnumValue", "description": "One possible value for a given Enum. Enum values are unique values, not a placeholder for a string or numeric value. 
However an Enum value is returned in a JSON response as a string.", "fields": [{"name": "name", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "description", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "isDeprecated", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "deprecationReason", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}], "inputFields": null, "interfaces": [], "enumValues": null, "possibleTypes": null}, {"kind": "OBJECT", "name": "__Directive", "description": "A Directive provides a way to describe alternate runtime execution and type validation behavior in a GraphQL document.\n\nIn some cases, you need to provide options to alter GraphQL's execution behavior in ways field arguments will not suffice, such as conditionally including or skipping a field. 
Directives provide this by describing additional information to the executor.", "fields": [{"name": "name", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "String", "ofType": null}}, "isDeprecated": false, "deprecationReason": null}, {"name": "description", "description": null, "args": [], "type": {"kind": "SCALAR", "name": "String", "ofType": null}, "isDeprecated": false, "deprecationReason": null}, {"name": "locations", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "ENUM", "name": "__DirectiveLocation", "ofType": null}}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "args", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "LIST", "name": null, "ofType": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "OBJECT", "name": "__InputValue", "ofType": null}}}}, "isDeprecated": false, "deprecationReason": null}, {"name": "onOperation", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": null}}, "isDeprecated": true, "deprecationReason": "Use `locations`."}, {"name": "onFragment", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": null}}, "isDeprecated": true, "deprecationReason": "Use `locations`."}, {"name": "onField", "description": null, "args": [], "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": null}}, "isDeprecated": true, "deprecationReason": "Use `locations`."}], "inputFields": null, "interfaces": [], "enumValues": null, "possibleTypes": null}, {"kind": "ENUM", "name": "__DirectiveLocation", "description": "A Directive can be adjacent to many parts of the GraphQL language, a 
__DirectiveLocation describes one such possible adjacencies.", "fields": null, "inputFields": null, "interfaces": null, "enumValues": [{"name": "QUERY", "description": "Location adjacent to a query operation.", "isDeprecated": false, "deprecationReason": null}, {"name": "MUTATION", "description": "Location adjacent to a mutation operation.", "isDeprecated": false, "deprecationReason": null}, {"name": "SUBSCRIPTION", "description": "Location adjacent to a subscription operation.", "isDeprecated": false, "deprecationReason": null}, {"name": "FIELD", "description": "Location adjacent to a field.", "isDeprecated": false, "deprecationReason": null}, {"name": "FRAGMENT_DEFINITION", "description": "Location adjacent to a fragment definition.", "isDeprecated": false, "deprecationReason": null}, {"name": "FRAGMENT_SPREAD", "description": "Location adjacent to a fragment spread.", "isDeprecated": false, "deprecationReason": null}, {"name": "INLINE_FRAGMENT", "description": "Location adjacent to an inline fragment.", "isDeprecated": false, "deprecationReason": null}], "possibleTypes": null}], "directives": [{"name": "include", "description": null, "locations": ["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"], "args": [{"name": "if", "description": "Included when true.", "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": null}}, "defaultValue": null}]}, {"name": "skip", "description": null, "locations": ["FIELD", "FRAGMENT_SPREAD", "INLINE_FRAGMENT"], "args": [{"name": "if", "description": "Skipped when true.", "type": {"kind": "NON_NULL", "name": null, "ofType": {"kind": "SCALAR", "name": "Boolean", "ofType": null}}, "defaultValue": null}]}]}} \ No newline at end of file diff --git a/gql-checker/tests/test_cases/bad_query.py b/gql-checker/tests/test_cases/bad_query.py deleted file mode 100644 index e2cf2705..00000000 --- a/gql-checker/tests/test_cases/bad_query.py +++ /dev/null @@ -1,7 +0,0 @@ -from gql import gql - -gql(''' -{ 
- id -} -''') # GQL101: Cannot query field "id" on type "Query". diff --git a/gql-checker/tests/test_cases/noqa.py b/gql-checker/tests/test_cases/noqa.py deleted file mode 100644 index d3b35d37..00000000 --- a/gql-checker/tests/test_cases/noqa.py +++ /dev/null @@ -1,3 +0,0 @@ -from gql import gql - -gql(''' wrong query ''') # noqa diff --git a/gql-checker/tests/test_cases/syntax_error.py b/gql-checker/tests/test_cases/syntax_error.py deleted file mode 100644 index 945ec000..00000000 --- a/gql-checker/tests/test_cases/syntax_error.py +++ /dev/null @@ -1,3 +0,0 @@ -from gql import gql - -gql(''' wrong query ''') # GQL100 diff --git a/gql-checker/tests/test_cases/validation.py b/gql-checker/tests/test_cases/validation.py deleted file mode 100644 index 503b2f94..00000000 --- a/gql-checker/tests/test_cases/validation.py +++ /dev/null @@ -1,78 +0,0 @@ -from gql import gql - - -gql(''' - query NestedQueryWithFragment { - hero { - ...NameAndAppearances - friends { - ...NameAndAppearances - friends { - ...NameAndAppearances - } - } - } - } - fragment NameAndAppearances on Character { - name - appearsIn - } -''') - -gql(''' - query HeroSpaceshipQuery { - hero { - favoriteSpaceship - } - } -''') # GQL101: Cannot query field "favoriteSpaceship" on type "Character". - -gql(''' - query HeroNoFieldsQuery { - hero - } -''') # GQL101: Field "hero" of type "Character" must have a sub selection. - - -gql(''' - query HeroFieldsOnScalarQuery { - hero { - name { - firstCharacterOfName - } - } - } -''') # GQL101: Field "name" of type "String" must not have a sub selection. - - -gql(''' - query DroidFieldOnCharacter { - hero { - name - primaryFunction - } - } -''') # GQL101: Cannot query field "primaryFunction" on type "Character". However, this field exists on "Droid". Perhaps you meant to use an inline fragment? 
- -gql(''' - query DroidFieldInFragment { - hero { - name - ...DroidFields - } - } - fragment DroidFields on Droid { - primaryFunction - } -''') - -gql(''' - query DroidFieldInFragment { - hero { - name - ... on Droid { - primaryFunction - } - } - } -''') diff --git a/gql-checker/tests/test_flake8_linter.py b/gql-checker/tests/test_flake8_linter.py deleted file mode 100644 index 7ed3c659..00000000 --- a/gql-checker/tests/test_flake8_linter.py +++ /dev/null @@ -1,56 +0,0 @@ -import ast -import re -import os - -import pycodestyle -import pytest - -from gql_checker.flake8_linter import Linter - -from tests.utils import extract_expected_errors - - -def load_test_cases(): - base_path = os.path.dirname(__file__) - test_case_path = os.path.join(base_path, "test_cases") - test_case_files = os.listdir(test_case_path) - - test_cases = [] - - for fname in test_case_files: - if not fname.endswith(".py"): - continue - - fullpath = os.path.join(test_case_path, fname) - data = open(fullpath).read() - tree = ast.parse(data, fullpath) - codes, messages = extract_expected_errors(data) - - test_cases.append((tree, fullpath, codes, messages)) - - return test_cases - - -@pytest.mark.parametrize( - "tree, filename, expected_codes, expected_messages", - load_test_cases() -) -def test_expected_error(tree, filename, expected_codes, expected_messages): - argv = [ - "--gql-introspection-schema=./tests/introspection_schema.json" - ] - - parser = pycodestyle.get_parser('', '') - Linter.add_options(parser) - options, args = parser.parse_args(argv) - Linter.parse_options(options) - - checker = Linter(tree, filename) - codes = [] - messages = [] - for lineno, col_offset, msg, instance in checker.run(): - code, message = msg.split(" ", 1) - codes.append(code) - messages.append(message) - assert codes == expected_codes - assert set(messages) >= set(expected_messages) diff --git a/gql-checker/tests/test_pylama_linter.py b/gql-checker/tests/test_pylama_linter.py deleted file mode 100644 index 
8e942717..00000000 --- a/gql-checker/tests/test_pylama_linter.py +++ /dev/null @@ -1,50 +0,0 @@ -import ast -import os - -import pytest - -from gql_checker import pylama_linter - -from tests.utils import extract_expected_errors - - -def load_test_cases(): - base_path = os.path.dirname(__file__) - test_case_path = os.path.join(base_path, "test_cases") - test_case_files = os.listdir(test_case_path) - - test_cases = [] - - for fname in test_case_files: - if not fname.endswith(".py"): - continue - - fullpath = os.path.join(test_case_path, fname) - data = open(fullpath).read() - codes, messages = extract_expected_errors(data) - test_cases.append((fullpath, codes, messages)) - - return test_cases - - -@pytest.mark.parametrize( - "filename, expected_codes, expected_messages", - load_test_cases() -) -def test_expected_error(filename, expected_codes, expected_messages): - checker = pylama_linter.Linter() - assert checker.allow(filename) - - codes = [] - messages = [] - - options = { - "gql_introspection_schema": "./tests/introspection_schema.json" - } - - for error in checker.run(filename, **options): - codes.append(error['type']) - messages.append(error['text']) - - assert codes == expected_codes - assert set(messages) >= set(expected_messages) diff --git a/gql-checker/tests/utils.py b/gql-checker/tests/utils.py deleted file mode 100644 index e0fd9034..00000000 --- a/gql-checker/tests/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -import re - - -ERROR_RX = re.compile("# ((GQL[0-9]+ ?)+)(: (.*))?$") - - -def extract_expected_errors(data): - lines = data.splitlines() - expected_codes = [] - expected_messages = [] - for line in lines: - match = ERROR_RX.search(line) - if match: - codes = match.group(1).split() - message = match.group(4) - expected_codes.extend(codes) - if message: - expected_messages.append(message) - return expected_codes, expected_messages diff --git a/gql-checker/tox.ini b/gql-checker/tox.ini deleted file mode 100644 index 0bbd2e81..00000000 --- 
a/gql-checker/tox.ini +++ /dev/null @@ -1,38 +0,0 @@ -[tox] -envlist = py26,py27,pypy,py33,py34,py35,pep8,py3pep8 - -[testenv] -deps = - coverage==3.7 - pytest - flake8 - pylama - pycodestyle>=2.0 -commands = - coverage run --source=gql_checker/,tests/ -m pytest --capture=no --strict {posargs} - coverage report -m - -# Temporarily disable coverage on pypy because of performance problems with -# coverage.py on pypy. -[testenv:pypy] -commands = py.test --capture=no --strict {posargs} - -[testenv:pep8] -deps = - flake8 - pep8-naming - flake8-import-order -commands = flake8 gql_checker/ - -[testenv:py3pep8] -basepython = python3.3 -deps = - flake8 - pep8-naming - flake8-import-order -commands = flake8 gql_checker/ - -[flake8] -exclude = .tox,*.egg -select = E,W,F,N,I -application-import-names = gql_checker,tests From 013eebd8427d58acbd8d9ed4d7e80057b95499c9 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 28 Oct 2024 13:02:08 +0100 Subject: [PATCH 161/239] Fix build wheel warnings (#503) * MANIFEST.in remove non-existing included files * setup.cfg Removing deprecated universal wheel option * Removing deprecated tests_require parameter --- MANIFEST.in | 3 +-- setup.cfg | 3 --- setup.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 3df67dcf..ddebd0b0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,7 +6,6 @@ include README.md include CONTRIBUTING.md include .readthedocs.yaml -include dev_requirements.txt include Makefile include tox.ini @@ -14,7 +13,7 @@ include tox.ini include gql/py.typed recursive-include tests *.py *.graphql *.cnf *.yaml *.pem -recursive-include docs *.txt *.rst conf.py Makefile make.bat *.jpg *.png *.gif +recursive-include docs *.txt *.rst conf.py Makefile make.bat recursive-include docs/code_examples *.py prune docs/_build diff --git a/setup.cfg b/setup.cfg index 50388a19..66380493 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,3 @@ -[wheel] -universal = 1 - [flake8] max-line-length = 
88 diff --git a/setup.py b/setup.py index a5efffa3..054bc93e 100644 --- a/setup.py +++ b/setup.py @@ -99,7 +99,6 @@ # PEP-561: https://www.python.org/dev/peps/pep-0561/ package_data={"gql": ["py.typed"]}, install_requires=install_requires, - tests_require=install_all_requires + tests_requires, extras_require={ "all": install_all_requires, "test": install_all_requires + tests_requires, From 57ef910b13727a63c2b15c1a5492446f892d59e6 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 28 Oct 2024 22:09:00 +0100 Subject: [PATCH 162/239] Using unittest.mock instead of mock (#504) Also Running gql-cli --version test in subprocess mode to try to fix flaky test --- setup.py | 1 - tests/test_cli.py | 1 + tests/test_client.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 054bc93e..21763a03 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,6 @@ "pytest-asyncio==0.21.1", "pytest-console-scripts==1.3.1", "pytest-cov==5.0.0", - "mock==4.0.2", "vcrpy==4.4.0", "aiofiles", ] diff --git a/tests/test_cli.py b/tests/test_cli.py index f0534957..cdbe07f9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -370,6 +370,7 @@ def test_cli_get_transport_no_protocol(parser): get_transport(args) +@pytest.mark.script_launch_mode("subprocess") def test_cli_ep_version(script_runner): ret = script_runner.run("gql-cli", "--version") diff --git a/tests/test_client.py b/tests/test_client.py index ada129c6..f7a3c947 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,7 +1,7 @@ import os from contextlib import suppress +from unittest import mock -import mock import pytest from graphql import build_ast_schema, parse From e355261e0eefa33c62ef24995af53d8cab4eb32b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 28 Oct 2024 22:43:24 +0100 Subject: [PATCH 163/239] Using pyupgrade with --py38-plus (#505) --- gql/transport/exceptions.py | 2 +- gql/transport/phoenix_channel_websockets.py | 
2 +- setup.py | 2 +- tests/fixtures/aws/fake_credentials.py | 2 +- tests/fixtures/aws/fake_request.py | 2 +- tests/fixtures/aws/fake_session.py | 2 +- tests/fixtures/aws/fake_signer.py | 2 +- tests/test_client.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 48e9d96b..7ec27a33 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -23,7 +23,7 @@ class TransportServerError(TransportError): code: Optional[int] def __init__(self, message: str, code: Optional[int] = None): - super(TransportServerError, self).__init__(message) + super().__init__(message) self.code = code diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index b8226234..d5585807 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -52,7 +52,7 @@ def __init__( self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscriptions: Dict[str, Subscription] = {} - super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) async def _initialize(self) -> None: """Join the specified channel and wait for the connection ACK. 
diff --git a/setup.py b/setup.py index 21763a03..7ca66ae3 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ # Get version from __version__.py file current_folder = os.path.abspath(os.path.dirname(__file__)) about = {} -with open(os.path.join(current_folder, "gql", "__version__.py"), "r") as f: +with open(os.path.join(current_folder, "gql", "__version__.py")) as f: exec(f.read(), about) setup( diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py index d8eac834..8df8b22b 100644 --- a/tests/fixtures/aws/fake_credentials.py +++ b/tests/fixtures/aws/fake_credentials.py @@ -1,7 +1,7 @@ import pytest -class FakeCredentials(object): +class FakeCredentials: def __init__( self, access_key=None, secret_key=None, method=None, token=None, region=None ): diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py index 615bc095..0c135d3a 100644 --- a/tests/fixtures/aws/fake_request.py +++ b/tests/fixtures/aws/fake_request.py @@ -1,7 +1,7 @@ import pytest -class FakeRequest(object): +class FakeRequest: headers = None def __init__(self, request_props=None): diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py index 78e1511a..585f5c59 100644 --- a/tests/fixtures/aws/fake_session.py +++ b/tests/fixtures/aws/fake_session.py @@ -1,7 +1,7 @@ import pytest -class FakeSession(object): +class FakeSession: def __init__(self, credentials, region_name): self._credentials = credentials self._region_name = region_name diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index ff096745..c0177a32 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -11,7 +11,7 @@ def _fake_signer_factory(request=None): yield _fake_signer_factory -class FakeSigner(object): +class FakeSigner: def __init__(self, request=None) -> None: self.request = request diff --git a/tests/test_client.py b/tests/test_client.py index f7a3c947..1e794558 100644 --- 
a/tests/test_client.py +++ b/tests/test_client.py @@ -30,7 +30,7 @@ def http_transport_query(): def test_request_transport_not_implemented(http_transport_query): class RandomTransport(Transport): def execute(self): - super(RandomTransport, self).execute(http_transport_query) + super().execute(http_transport_query) with pytest.raises(NotImplementedError) as exc_info: RandomTransport().execute() From c3b722eb6609cda0248147b78755288a7ceccb53 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 29 Oct 2024 00:23:08 +0100 Subject: [PATCH 164/239] chore: upgrade GitHub workflows (#506) --- .github/workflows/deploy.yml | 10 +++++----- .github/workflows/lint.yml | 10 +++++----- .github/workflows/tests.yml | 35 +++++++++++++++++++---------------- setup.py | 4 ++-- tests/test_aiohttp.py | 3 ++- tests/test_cli.py | 2 +- tests/test_httpx_async.py | 3 ++- tox.ini | 12 ++++++------ 8 files changed, 42 insertions(+), 37 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index da129836..1147ecf5 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -7,14 +7,14 @@ on: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.12 - name: Build wheel and source tarball run: | pip install wheel diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 39f5cf0c..86f2468b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,14 +4,14 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 with: - python-version: 3.8 + 
python-version: 3.12 - name: Install dependencies run: | python -m pip install --upgrade pip wheel diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7588a997..f67d0b6f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,8 +8,8 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.8"] - os: [ubuntu-20.04, windows-latest] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.10"] + os: [ubuntu-24.04, windows-latest] exclude: - os: windows-latest python-version: "3.9" @@ -20,12 +20,12 @@ jobs: - os: windows-latest python-version: "3.12" - os: windows-latest - python-version: "pypy3.8" + python-version: "pypy3.10" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -38,18 +38,18 @@ jobs: TOXENV: ${{ matrix.toxenv }} single_extra: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: fail-fast: false matrix: dependency: ["aiohttp", "requests", "httpx", "websockets"] steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.12 - name: Install dependencies with only ${{ matrix.dependency }} extra dependency run: | python -m pip install --upgrade pip wheel @@ -58,14 +58,14 @@ jobs: run: pytest tests --${{ matrix.dependency }}-only coverage: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.12 - name: Install test dependencies run: | python 
-m pip install --upgrade pip wheel @@ -73,4 +73,7 @@ jobs: - name: Test with coverage run: pytest --cov=gql --cov-report=xml --cov-report=term-missing tests - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v4 + with: + fail_ci_if_error: false + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/setup.py b/setup.py index 7ca66ae3..132f6ead 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ "parse==1.15.0", "pytest==7.4.2", "pytest-asyncio==0.21.1", - "pytest-console-scripts==1.3.1", + "pytest-console-scripts==1.4.1", "pytest-cov==5.0.0", "vcrpy==4.4.0", "aiofiles", @@ -26,7 +26,7 @@ dev_requires = [ "black==22.3.0", "check-manifest>=0.42,<1", - "flake8==3.8.1", + "flake8==7.1.1", "isort==4.3.21", "mypy==1.10", "sphinx>=5.3.0,<6", diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index b16964d0..55b08260 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1162,7 +1162,8 @@ def test_code(): monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) ret = script_runner.run( - "gql-cli", url, "--verbose", stdin=io.StringIO(query1_str) + ["gql-cli", url, "--verbose"], + stdin=io.StringIO(query1_str), ) assert ret.success diff --git a/tests/test_cli.py b/tests/test_cli.py index cdbe07f9..88d1f533 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -372,7 +372,7 @@ def test_cli_get_transport_no_protocol(parser): @pytest.mark.script_launch_mode("subprocess") def test_cli_ep_version(script_runner): - ret = script_runner.run("gql-cli", "--version") + ret = script_runner.run(["gql-cli", "--version"]) assert ret.success diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 3665f5d8..17be0db5 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1036,7 +1036,8 @@ def test_code(): monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) ret = script_runner.run( - "gql-cli", url, "--verbose", stdin=io.StringIO(query1_str) + ["gql-cli", url, "--verbose"], + 
stdin=io.StringIO(query1_str), ) assert ret.success diff --git a/tox.ini b/tox.ini index 7a639572..308aba00 100644 --- a/tox.ini +++ b/tox.ini @@ -32,37 +32,37 @@ commands = py{38}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] -basepython = python3.8 +basepython = python deps = -e.[dev] commands = black --check gql tests [testenv:flake8] -basepython = python3.8 +basepython = python deps = -e.[dev] commands = flake8 gql tests [testenv:import-order] -basepython = python3.8 +basepython = python deps = -e.[dev] commands = isort --recursive --check-only --diff gql tests [testenv:mypy] -basepython = python3.8 +basepython = python deps = -e.[dev] commands = mypy gql tests [testenv:docs] -basepython = python3.8 +basepython = python deps = -e.[dev] commands = sphinx-build -b html -nEW docs docs/_build/html [testenv:manifest] -basepython = python3.8 +basepython = python deps = -e.[dev] commands = check-manifest -v From ff4a20a955659b7e377c2f48e091b33c8bddd06b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 29 Oct 2024 21:34:24 +0100 Subject: [PATCH 165/239] Restrict permissions to GitHub actions (#509) --- .github/workflows/lint.yml | 3 +++ .github/workflows/tests.yml | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 86f2468b..6f1daaf7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,5 +1,8 @@ name: Lint +permissions: + contents: read + on: [push, pull_request] jobs: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f67d0b6f..e53820c0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,5 +1,8 @@ name: Tests +permissions: + contents: read + on: [push, pull_request] jobs: @@ -60,6 +63,10 @@ jobs: coverage: runs-on: ubuntu-24.04 + permissions: + contents: read + checks: write + steps: - uses: actions/checkout@v4 - name: Set up Python 3.12 From d07461136ee175c599bdcade8dea98ded5e0ae8a 
Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 16 Nov 2024 04:21:30 +0100 Subject: [PATCH 166/239] Fix python 3.11 test coverage issues (#512) --- gql/transport/aiohttp.py | 2 +- gql/transport/phoenix_channel_websockets.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index be22ce9c..6455e2d8 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -147,7 +147,7 @@ def connection_lost(exc, orig_lost): all_is_lost.set() def eof_received(orig_eof_received): - try: + try: # pragma: no cover orig_eof_received() except AttributeError: # pragma: no cover # It may happen that eof_received() is called after diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index d5585807..08cde8cc 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -342,9 +342,8 @@ def _validate_data_response(d: Any, label: str) -> dict: elif status == "timeout": raise TransportQueryError("reply timeout", query_id=answer_id) - else: - # missing or unrecognized status, just continue - pass + + # In case of missing or unrecognized status, just continue elif event == "phx_error": # Sent if the channel has crashed From b2f2a68104f009ea4e05edfbde3ffac897323cf1 Mon Sep 17 00:00:00 2001 From: Alexandre Detiste Date: Sat, 16 Nov 2024 04:37:27 +0100 Subject: [PATCH 167/239] types-mock was only usefull when using old standalone "mock" module (#511) --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 132f6ead..bda6f570 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,6 @@ "sphinx_rtd_theme>=0.4,<1", "sphinx-argparse==0.2.5", "types-aiofiles", - "types-mock", "types-requests", ] + tests_requires From b7c58656ae49645aecdf848e7906fd00ee8fd050 Mon Sep 17 00:00:00 2001 From: "Malte S. 
Stretz" Date: Wed, 11 Dec 2024 15:04:44 +0100 Subject: [PATCH 168/239] Wire httpx transport in gql-cli (#513) --- gql/cli.py | 6 ++++++ tests/test_cli.py | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/gql/cli.py b/gql/cli.py index a7d129e2..06781c2b 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -157,6 +157,7 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: choices=[ "auto", "aiohttp", + "httpx", "phoenix", "websockets", "aiohttp_websockets", @@ -330,6 +331,11 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: return AIOHTTPTransport(url=args.server, **transport_args) + elif transport_name == "httpx": + from gql.transport.httpx import HTTPXAsyncTransport + + return HTTPXAsyncTransport(url=args.server, **transport_args) + elif transport_name == "phoenix": from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, diff --git a/tests/test_cli.py b/tests/test_cli.py index 88d1f533..dccfcb5a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -190,6 +190,22 @@ def test_cli_get_transport_aiohttp(parser, url): assert isinstance(transport, AIOHTTPTransport) +@pytest.mark.httpx +@pytest.mark.parametrize( + "url", + ["http://your_server.com", "https://your_server.com"], +) +def test_cli_get_transport_httpx(parser, url): + + from gql.transport.httpx import HTTPXAsyncTransport + + args = parser.parse_args([url, "--transport", "httpx"]) + + transport = get_transport(args) + + assert isinstance(transport, HTTPXAsyncTransport) + + @pytest.mark.websockets @pytest.mark.parametrize( "url", From 5879e23484b6cbac8e43aa1fd3de5248cae1df8a Mon Sep 17 00:00:00 2001 From: "Malte S. 
Stretz" Date: Wed, 11 Dec 2024 15:56:10 +0100 Subject: [PATCH 169/239] Add minimal pyproject.toml (#514) --- pyproject.toml | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..9b631e08 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ +[project] +name = "gql" +readme = "README.md" +requires-python = ">=3.8.1" +dynamic = ["authors", "classifiers", "dependencies", "description", "entry-points", "keywords", "license", "optional-dependencies", "scripts", "version"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" From 483053febe2851e179dbd9641999658788786fa7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 12 Dec 2024 11:58:10 +0100 Subject: [PATCH 170/239] chore Fix docs sphinx nitpick warnings - adding intersphinx_mapping (#515) --- docs/Makefile | 2 +- docs/code_examples/fastapi_async.py | 1 + docs/conf.py | 50 +++++++++++++++++++++++++++++ gql/dsl.py | 2 +- gql/gql.py | 2 +- gql/transport/appsync_auth.py | 4 +-- gql/transport/requests.py | 8 ++--- 7 files changed, 60 insertions(+), 9 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index d4bb2cbb..747126b3 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +SPHINXOPTS ?= -n SPHINXBUILD ?= sphinx-build SOURCEDIR = . 
BUILDDIR = _build diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 80920252..3bedd187 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -10,6 +10,7 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse + from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/conf.py b/docs/conf.py index db6e7c5f..94daf942 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,6 +34,7 @@ extensions = [ 'sphinxarg.ext', 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', 'sphinx_rtd_theme' ] @@ -77,3 +78,52 @@ 'show-inheritance': True } autosummary_generate = True + +# -- Intersphinx configuration --------------------------------------------- +intersphinx_mapping = { + 'aiohttp': ('https://docs.aiohttp.org/en/stable/', None), + 'graphql': ('https://graphql-core-3.readthedocs.io/en/latest/', None), + 'multidict': ('https://multidict.readthedocs.io/en/stable/', None), + 'python': ('https://docs.python.org/3/', None), + 'requests': ('https://requests.readthedocs.io/en/latest/', None), + 'websockets': ('https://websockets.readthedocs.io/en/11.0.3/', None), + 'yarl': ('https://yarl.readthedocs.io/en/stable/', None), +} + +nitpick_ignore = [ + # graphql-core: should be fixed + ('py:class', 'graphql.execution.execute.ExecutionResult'), + ('py:class', 'Source'), + ('py:class', 'GraphQLSchema'), + + # asyncio: should be fixed + ('py:class', 'asyncio.locks.Event'), + + # aiohttp: should be fixed + ('py:class', 'aiohttp.client_reqrep.Fingerprint'), + ('py:class', 'aiohttp.helpers.BasicAuth'), + + # multidict: should be fixed + ('py:class', 
'multidict._multidict.CIMultiDictProxy'), + ('py:class', 'multidict._multidict.CIMultiDict'), + ('py:class', 'multidict._multidict.istr'), + + # websockets: first bump websockets version + ('py:class', 'websockets.datastructures.SupportsKeysAndGetItem'), + ('py:class', 'websockets.typing.Subprotocol'), + + # httpx: no sphinx docs yet https://github.com/encode/httpx/discussions/3091 + ('py:class', 'httpx.AsyncClient'), + ('py:class', 'httpx.Client'), + ('py:class', 'httpx.Headers'), + + # botocore: no sphinx docs + ('py:class', 'botocore.auth.BaseSigner'), + ('py:class', 'botocore.awsrequest.AWSRequest'), + ('py:class', 'botocore.credentials.Credentials'), + ('py:class', 'botocore.session.Session'), + + # gql: ignore private classes + ('py:class', 'gql.transport.httpx._HTTPXTransport'), + ('py:class', 'gql.client._CallableT'), +] diff --git a/gql/dsl.py b/gql/dsl.py index 536a8b8b..be2b5a7e 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -347,7 +347,7 @@ def select( :type \**fields_with_alias: DSLSelectable :raises TypeError: if an argument is not an instance of :class:`DSLSelectable` - :raises GraphQLError: if an argument is not a valid field + :raises graphql.error.GraphQLError: if an argument is not a valid field """ # Concatenate fields without and with alias added_fields: Tuple["DSLSelectable", ...] = DSLField.get_aliased_fields( diff --git a/gql/gql.py b/gql/gql.py index e35c8045..e9705947 100644 --- a/gql/gql.py +++ b/gql/gql.py @@ -13,7 +13,7 @@ def gql(request_string: str | Source) -> DocumentNode: :class:`async session ` or by a :class:`sync session ` - :raises GraphQLError: if a syntax error is encountered. + :raises graphql.error.GraphQLError: if a syntax error is encountered. 
""" if isinstance(request_string, Source): source = request_string diff --git a/gql/transport/appsync_auth.py b/gql/transport/appsync_auth.py index 5ce93d4e..1eb51b4e 100644 --- a/gql/transport/appsync_auth.py +++ b/gql/transport/appsync_auth.py @@ -18,7 +18,7 @@ class AppSyncAuthentication(ABC): """AWS authentication abstract base class All AWS authentication class should have a - :meth:`get_headers ` + :meth:`get_headers ` method which defines the headers used in the authentication process.""" def get_auth_url(self, url: str) -> str: @@ -91,7 +91,7 @@ class AppSyncIAMAuthentication(AppSyncAuthentication): .. note:: There is no need for you to use this class directly, you could instead - intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport` + intantiate :class:`gql.transport.appsync_websockets.AppSyncWebsocketsTransport` without an auth argument. During initialization, this class will use botocore to attempt to diff --git a/gql/transport/requests.py b/gql/transport/requests.py index fd9759ed..bd370908 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -54,10 +54,10 @@ def __init__( """Initialize the transport with the given request parameters. :param url: The GraphQL server URL. - :param headers: Dictionary of HTTP Headers to send with the :class:`Request` - (Default: None). - :param cookies: Dict or CookieJar object to send with the :class:`Request` - (Default: None). + :param headers: Dictionary of HTTP Headers to send with + :meth:`requests.Session.request` (Default: None). + :param cookies: Dict or CookieJar object to send with + :meth:`requests.Session.request` (Default: None). :param auth: Auth tuple or callable to enable Basic/Digest/Custom HTTP Auth (Default: None). 
:param use_json: Send request body as JSON instead of form-urlencoded From 26b28d7d6cf630ba819d9db904fa5aff4df9ac73 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 20 Jan 2025 17:40:58 +0100 Subject: [PATCH 171/239] Chore fix tests failing vcrpy urllib3 dep (#518) * Bump vcrpy to 7.0.0 --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bda6f570..1a9918c4 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,8 @@ "pytest-asyncio==0.21.1", "pytest-console-scripts==1.4.1", "pytest-cov==5.0.0", - "vcrpy==4.4.0", + "vcrpy==4.4.0;python_version<='3.8'", + "vcrpy==7.0.0;python_version>'3.8'", "aiofiles", ] From 212e2e1097da7bca845c748c7232281042bdd5f3 Mon Sep 17 00:00:00 2001 From: Taylor Braun-Jones Date: Mon, 20 Jan 2025 12:01:15 -0500 Subject: [PATCH 172/239] Websockets transport Fix long hang under certain network failures (#517) --- gql/transport/websockets.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index c385d3d7..02abb61f 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -7,6 +7,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast from websockets.datastructures import HeadersLike +from websockets.exceptions import ConnectionClosed from websockets.typing import Subprotocol from .exceptions import ( @@ -505,10 +506,15 @@ async def _after_initialize(self): self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) async def _close_hook(self): + log.debug("_close_hook: start") # Properly shut down the send ping task if enabled if self.send_ping_task is not None: + log.debug("_close_hook: cancelling send_ping_task") self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): + with suppress(asyncio.CancelledError, ConnectionClosed): + log.debug("_close_hook: awaiting send_ping_task") await self.send_ping_task self.send_ping_task = None + + log.debug("_close_hook: 
end") From cd071da9836a27392db43ca166c92a8b5f38ef5c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 20 Jan 2025 19:00:33 +0100 Subject: [PATCH 173/239] Support Python 3.13 (#519) --- .github/workflows/tests.yml | 4 +++- docs/code_examples/fastapi_async.py | 1 - docs/code_examples/httpx_async_trio.py | 1 - setup.py | 1 + tox.ini | 5 +++-- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e53820c0..a7f6e732 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "pypy3.10"] os: [ubuntu-24.04, windows-latest] exclude: - os: windows-latest @@ -22,6 +22,8 @@ jobs: python-version: "3.11" - os: windows-latest python-version: "3.12" + - os: windows-latest + python-version: "3.13" - os: windows-latest python-version: "pypy3.10" diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 3bedd187..80920252 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -10,7 +10,6 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse - from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/httpx_async_trio.py b/docs/code_examples/httpx_async_trio.py index 058b952b..b76dab42 100644 --- a/docs/code_examples/httpx_async_trio.py +++ b/docs/code_examples/httpx_async_trio.py @@ -1,5 +1,4 @@ import trio - from gql import Client, gql from gql.transport.httpx import HTTPXAsyncTransport diff --git a/setup.py b/setup.py index 1a9918c4..6c12c95c 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language 
:: Python :: 3.13", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay gql client", diff --git a/tox.ini b/tox.ini index 308aba00..4d6d4d2f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{38,39,310,311,312,py3} + py{38,39,310,311,312,313,py3} [gh-actions] python = @@ -10,6 +10,7 @@ python = 3.10: py310 3.11: py311 3.12: py312 + 3.13: py313 pypy-3: pypy3 [testenv] @@ -28,7 +29,7 @@ deps = -e.[test] commands = pip install -U setuptools ; run "tox -- tests -s" to show output for debugging - py{39,310,311,312,py3}: pytest {posargs:tests} + py{39,310,311,312,313,py3}: pytest {posargs:tests} py{38}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] From 25bb56d56e95f9d28106c6f16dba09f3a2d6805e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 20 Jan 2025 20:24:07 +0100 Subject: [PATCH 174/239] Bump version number to 3.6.0b3 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index dc9e18d0..6361d12f 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.6.0b2" +__version__ = "3.6.0b3" From 35cddc8343bc95eea365f74bf6cacf519f680bef Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 20 Jan 2025 20:51:38 +0100 Subject: [PATCH 175/239] Chore fix deploy GitHub action --- .github/workflows/deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 1147ecf5..69c11d2a 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -17,7 +17,7 @@ jobs: python-version: 3.12 - name: Build wheel and source tarball run: | - pip install wheel + pip install wheel setuptools python setup.py sdist bdist_wheel - name: Publish a Python distribution to PyPI uses: pypa/gh-action-pypi-publish@v1.1.0 From 
8dd458c76c35a22dd96166ec0f337ddf90ca7814 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 27 Jan 2025 17:38:01 +0100 Subject: [PATCH 176/239] Temporarily restrict graphql-core<3.3.0a7 for tests --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6c12c95c..3833502b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.3.0a3,<3.4", + "graphql-core>=3.3.0a3,<3.3.0a7", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", From 38e64b2d2f7f1aad775302eddff95b933b27294f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 27 Jan 2025 18:27:17 +0100 Subject: [PATCH 177/239] Adding the input_value_deprecation argument to get_introspection_query_ast (#524) --- gql/utilities/get_introspection_query_ast.py | 20 +++++++++++++++++--- tests/starwars/test_dsl.py | 2 ++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index d35a2a75..975ccc83 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -10,6 +10,7 @@ def get_introspection_query_ast( specified_by_url: bool = False, directive_is_repeatable: bool = False, schema_description: bool = False, + input_value_deprecation: bool = False, type_recursion_level: int = 7, ) -> DocumentNode: """Get a query for introspection as a document using the DSL module. 
@@ -43,13 +44,20 @@ def get_introspection_query_ast( directives = ds.__Schema.directives.select(ds.__Directive.name) + deprecated_expand = {} + + if input_value_deprecation: + deprecated_expand = { + "includeDeprecated": True, + } + if descriptions: directives.select(ds.__Directive.description) if directive_is_repeatable: directives.select(ds.__Directive.isRepeatable) directives.select( ds.__Directive.locations, - ds.__Directive.args.select(fragment_InputValue), + ds.__Directive.args(**deprecated_expand).select(fragment_InputValue), ) schema.select(directives) @@ -69,7 +77,7 @@ def get_introspection_query_ast( fields.select(ds.__Field.description) fields.select( - ds.__Field.args.select(fragment_InputValue), + ds.__Field.args(**deprecated_expand).select(fragment_InputValue), ds.__Field.type.select(fragment_TypeRef), ds.__Field.isDeprecated, ds.__Field.deprecationReason, @@ -89,7 +97,7 @@ def get_introspection_query_ast( fragment_FullType.select( fields, - ds.__Type.inputFields.select(fragment_InputValue), + ds.__Type.inputFields(**deprecated_expand).select(fragment_InputValue), ds.__Type.interfaces.select(fragment_TypeRef), enum_values, ds.__Type.possibleTypes.select(fragment_TypeRef), @@ -105,6 +113,12 @@ def get_introspection_query_ast( ds.__InputValue.defaultValue, ) + if input_value_deprecation: + fragment_InputValue.select( + ds.__InputValue.isDeprecated, + ds.__InputValue.deprecationReason, + ) + fragment_TypeRef.select( ds.__Type.kind, ds.__Type.name, diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 2aadf92f..1aa1efa2 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -984,12 +984,14 @@ def test_get_introspection_query_ast(option): specified_by_url=option, directive_is_repeatable=option, schema_description=option, + input_value_deprecation=option, ) dsl_introspection_query = get_introspection_query_ast( descriptions=option, specified_by_url=option, directive_is_repeatable=option, schema_description=option, + 
input_value_deprecation=option, ) assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) From 4a366021d8b1123d6bd883125b72b36b09463d65 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 27 Jan 2025 18:32:32 +0100 Subject: [PATCH 178/239] Using gql version of the get_introspection_query method (#523) This would reset the change of graphql-core to increase the type recursion level from 7 to 9 --- gql/client.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/gql/client.py b/gql/client.py index e1b168a7..c52a00b2 100644 --- a/gql/client.py +++ b/gql/client.py @@ -29,7 +29,6 @@ GraphQLSchema, IntrospectionQuery, build_ast_schema, - get_introspection_query, parse, validate, ) @@ -39,7 +38,7 @@ from .transport.exceptions import TransportClosed, TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport -from .utilities import build_client_schema +from .utilities import build_client_schema, get_introspection_query_ast from .utilities import parse_result as parse_result_fn from .utilities import serialize_variable_values from .utils import str_first_element @@ -87,8 +86,8 @@ def __init__( :param transport: The provided :ref:`transport `. :param fetch_schema_from_transport: Boolean to indicate that if we want to fetch the schema from the transport using an introspection query. - :param introspection_args: arguments passed to the get_introspection_query - method of graphql-core. + :param introspection_args: arguments passed to the + :meth:`gql.utilities.get_introspection_query_ast` method. :param execute_timeout: The maximum time in seconds for the execution of a request before a TimeoutError is raised. Only used for async transports. Passing None results in waiting forever for a response. 
@@ -1282,8 +1281,10 @@ def fetch_schema(self) -> None: Don't use this function and instead set the fetch_schema_from_transport attribute to True""" - introspection_query = get_introspection_query(**self.client.introspection_args) - execution_result = self.transport.execute(parse(introspection_query)) + introspection_query = get_introspection_query_ast( + **self.client.introspection_args + ) + execution_result = self.transport.execute(introspection_query) self.client._build_schema_from_introspection(execution_result) @@ -1650,8 +1651,10 @@ async def fetch_schema(self) -> None: Don't use this function and instead set the fetch_schema_from_transport attribute to True""" - introspection_query = get_introspection_query(**self.client.introspection_args) - execution_result = await self.transport.execute(parse(introspection_query)) + introspection_query = get_introspection_query_ast( + **self.client.introspection_args + ) + execution_result = await self.transport.execute(introspection_query) self.client._build_schema_from_introspection(execution_result) From f07fb2b92a3d5d1a1625f37fa9a64dea774cacd6 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 27 Jan 2025 18:36:34 +0100 Subject: [PATCH 179/239] Sort elements in node_tree method (#520) This should resolve some non-important difference from graphql-core v3.3.0a7 --- gql/utilities/node_tree.py | 5 ++- tests/starwars/test_dsl.py | 86 +++++++++++++++++++------------------- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py index c307d937..4313188e 100644 --- a/gql/utilities/node_tree.py +++ b/gql/utilities/node_tree.py @@ -19,7 +19,7 @@ def _node_tree_recursive( results.append(" " * indent + f"{type(obj).__name__}") try: - keys = obj.keys + keys = sorted(obj.keys) except AttributeError: # If the object has no keys attribute, print its repr and return. 
results.append(" " * (indent + 1) + repr(obj)) @@ -70,6 +70,9 @@ def node_tree( Useful to debug deep DocumentNode instances created by gql or dsl_gql. + NOTE: from gql version 3.6.0b4 the elements of each node are sorted to ignore + small changes in graphql-core + WARNING: the output of this method is not guaranteed and may change without notice. """ diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 1aa1efa2..098c8aca 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1030,11 +1030,10 @@ def test_node_tree_with_loc(ds): node_tree_result = """ DocumentNode - loc: - Location - definitions: OperationDefinitionNode + directives: + empty tuple loc: Location @@ -1045,10 +1044,8 @@ def test_node_tree_with_loc(ds): value: 'GetHeroName' - directives: - empty tuple - variable_definitions: - empty tuple + operation: + selection_set: SelectionSetNode loc: @@ -1056,13 +1053,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1070,8 +1069,6 @@ def test_node_tree_with_loc(ds): value: 'hero' - arguments: - empty tuple nullability_assertion: None selection_set: @@ -1081,13 +1078,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1095,23 +1094,23 @@ def test_node_tree_with_loc(ds): value: 'name' - arguments: - empty tuple nullability_assertion: None selection_set: None - operation: - + variable_definitions: + empty tuple + loc: + Location + """.strip() node_tree_result_stable = """ DocumentNode - loc: - Location - definitions: OperationDefinitionNode + directives: + empty tuple loc: Location @@ -1122,10 +1121,8 @@ def test_node_tree_with_loc(ds): value: 'GetHeroName' - directives: - empty tuple - 
variable_definitions: - empty tuple + operation: + selection_set: SelectionSetNode loc: @@ -1133,13 +1130,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1147,8 +1146,6 @@ def test_node_tree_with_loc(ds): value: 'hero' - arguments: - empty tuple selection_set: SelectionSetNode loc: @@ -1156,13 +1153,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1170,14 +1169,17 @@ def test_node_tree_with_loc(ds): value: 'name' - arguments: - empty tuple selection_set: None - operation: - + variable_definitions: + empty tuple + loc: + Location + """.strip() + print(node_tree(document, ignore_loc=False)) + try: assert node_tree(document, ignore_loc=False) == node_tree_result except AssertionError: From 88b80c32c49b38059e915c06cc84aa57d9b18ec7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 27 Jan 2025 18:55:36 +0100 Subject: [PATCH 180/239] Fix test for introspection type recursion level change in graphql-core v3.3.0a7 (#521) --- tests/starwars/test_dsl.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 098c8aca..5cd051ba 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -994,10 +994,26 @@ def test_get_introspection_query_ast(option): input_value_deprecation=option, ) - assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) - assert node_tree(dsl_introspection_query) == node_tree( - gql(print_ast(dsl_introspection_query)) - ) + try: + assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert node_tree(dsl_introspection_query) == node_tree( + 
gql(print_ast(dsl_introspection_query)) + ) + except AssertionError: + + # From graphql-core version 3.3.0a7, there is two more type recursion levels + dsl_introspection_query = get_introspection_query_ast( + descriptions=option, + specified_by_url=option, + directive_is_repeatable=option, + schema_description=option, + input_value_deprecation=option, + type_recursion_level=9, + ) + assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert node_tree(dsl_introspection_query) == node_tree( + gql(print_ast(dsl_introspection_query)) + ) def test_typename_aliased(ds): From ba61920e2cf907af5af6367d2e11ca09d34363f3 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 27 Jan 2025 18:56:47 +0100 Subject: [PATCH 181/239] Revert "Temporarily restrict graphql-core<3.3.0a7 for tests" This reverts commit 8dd458c76c35a22dd96166ec0f337ddf90ca7814. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3833502b..6c12c95c 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.3.0a3,<3.3.0a7", + "graphql-core>=3.3.0a3,<3.4", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", From 996a2aded28fea2c28a854249cb6e268f6013277 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 27 Jan 2025 20:31:03 +0100 Subject: [PATCH 182/239] Remove Python 3.8 support (#525) --- .github/workflows/tests.yml | 4 +--- setup.py | 4 +--- tests/test_websocket_online.py | 3 --- tox.ini | 5 ++--- 4 files changed, 4 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a7f6e732..8463ac00 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "pypy3.10"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "pypy3.10"] os: [ubuntu-24.04, 
windows-latest] exclude: - os: windows-latest @@ -20,8 +20,6 @@ jobs: python-version: "3.10" - os: windows-latest python-version: "3.11" - - os: windows-latest - python-version: "3.12" - os: windows-latest python-version: "3.13" - os: windows-latest diff --git a/setup.py b/setup.py index 6c12c95c..1d985a1d 100644 --- a/setup.py +++ b/setup.py @@ -19,8 +19,7 @@ "pytest-asyncio==0.21.1", "pytest-console-scripts==1.4.1", "pytest-cov==5.0.0", - "vcrpy==4.4.0;python_version<='3.8'", - "vcrpy==7.0.0;python_version>'3.8'", + "vcrpy==7.0.0", "aiofiles", ] @@ -86,7 +85,6 @@ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/test_websocket_online.py b/tests/test_websocket_online.py index b5fca837..fa288b6d 100644 --- a/tests/test_websocket_online.py +++ b/tests/test_websocket_online.py @@ -1,6 +1,5 @@ import asyncio import logging -import sys from typing import Dict import pytest @@ -151,7 +150,6 @@ async def test_websocket_sending_invalid_payload(): @pytest.mark.online -@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def test_websocket_sending_invalid_data_while_other_query_is_running(): @@ -203,7 +201,6 @@ async def query_task2(): @pytest.mark.online -@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.skip(reason=skip_reason) @pytest.mark.asyncio async def test_websocket_two_queries_in_parallel_using_two_tasks(): diff --git a/tox.ini b/tox.ini index 4d6d4d2f..8796357b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,10 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{38,39,310,311,312,313,py3} + py{39,310,311,312,313,py3} [gh-actions] python = - 3.8: 
py38 3.9: py39 3.10: py310 3.11: py311 @@ -30,7 +29,7 @@ commands = pip install -U setuptools ; run "tox -- tests -s" to show output for debugging py{39,310,311,312,313,py3}: pytest {posargs:tests} - py{38}: pytest {posargs:tests --cov-report=term-missing --cov=gql} + py{312}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] basepython = python From ed6373445baf44e6ff014c5af8d7a563d4878e1f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 27 Jan 2025 20:56:03 +0100 Subject: [PATCH 183/239] Bump sphinx dev dependencies (#526) --- setup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 1d985a1d..e41836a7 100644 --- a/setup.py +++ b/setup.py @@ -29,9 +29,10 @@ "flake8==7.1.1", "isort==4.3.21", "mypy==1.10", - "sphinx>=5.3.0,<6", - "sphinx_rtd_theme>=0.4,<1", - "sphinx-argparse==0.2.5", + "sphinx>=7.0.0,<8;python_version<='3.9'", + "sphinx>=8.1.0,<9;python_version>'3.9'", + "sphinx_rtd_theme>=3.0.2,<4", + "sphinx-argparse==0.4.0", "types-aiofiles", "types-requests", ] + tests_requires From b066e8944b0da0a4bbac6c31f43e5c3c7772cd51 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 31 Jan 2025 00:57:52 +0100 Subject: [PATCH 184/239] Use httpx with gql-cli if aiohttp is not available on auto (#528) --- gql/cli.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/gql/cli.py b/gql/cli.py index 06781c2b..91c67873 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -297,7 +297,23 @@ def autodetect_transport(url: URL) -> str: else: assert url.scheme in ["http", "https"] - transport_name = "aiohttp" + + try: + from gql.transport.aiohttp import AIOHTTPTransport # noqa: F401 + + transport_name = "aiohttp" + except ModuleNotFoundError: # pragma: no cover + try: + from gql.transport.httpx import HTTPXAsyncTransport # noqa: F401 + + transport_name = "httpx" + except ModuleNotFoundError: + raise ModuleNotFoundError( + "\n\nNo suitable dependencies has been found for 
an http(s) backend" + " (aiohttp or httpx).\n\n" + "Please check the install documentation at:\n" + "https://gql.readthedocs.io/en/stable/intro.html#installation\n" + ) return transport_name @@ -462,6 +478,9 @@ async def main(args: Namespace) -> int: except ValueError as e: print(f"Error: {e}", file=sys.stderr) return 1 + except ModuleNotFoundError as e: # pragma: no cover + print(f"Error: {e}", file=sys.stderr) + return 2 # By default, the exit_code is 0 (everything is ok) exit_code = 0 From 68ae2e683e54b3f97fc33ca3f7dd394217bbf81d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 18 Feb 2025 13:29:37 +0100 Subject: [PATCH 185/239] AIOHTTPTransport default ssl cert validation add warning (#530) --- MANIFEST.in | 2 +- gql/transport/aiohttp.py | 30 ++++++- tests/conftest.py | 23 +++++ tests/test_aiohttp.py | 84 +++++++++++++++++- tests/test_aiohttp_websocket_query.py | 63 ++++++++++++-- tests/test_httpx.py | 114 +++++++++++++++++++++++- tests/test_httpx_async.py | 61 ++++++++++++- tests/test_localhost_client.crt | 20 +++++ tests/test_phoenix_channel_query.py | 88 +++++++++++++++++-- tests/test_requests.py | 120 +++++++++++++++++++++++++- tests/test_websocket_query.py | 55 ++++++++++-- 11 files changed, 628 insertions(+), 32 deletions(-) create mode 100644 tests/test_localhost_client.crt diff --git a/MANIFEST.in b/MANIFEST.in index ddebd0b0..ca670908 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -12,7 +12,7 @@ include tox.ini include gql/py.typed -recursive-include tests *.py *.graphql *.cnf *.yaml *.pem +recursive-include tests *.py *.graphql *.cnf *.yaml *.pem *.crt recursive-include docs *.txt *.rst conf.py Makefile make.bat recursive-include docs/code_examples *.py diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 6455e2d8..0c332205 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -3,8 +3,19 @@ import io import json import logging +import warnings from ssl import 
SSLContext -from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + Optional, + Tuple, + Type, + Union, + cast, +) import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -46,7 +57,7 @@ def __init__( headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, - ssl: Union[SSLContext, bool, Fingerprint] = False, + ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning", timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, json_serialize: Callable = json.dumps, @@ -77,7 +88,20 @@ def __init__( self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth - self.ssl: Union[SSLContext, bool, Fingerprint] = ssl + + if ssl == "ssl_warning": + ssl = False + if str(url).startswith("https"): + warnings.warn( + "WARNING: By default, AIOHTTPTransport does not verify" + " ssl certificates. This will be fixed in the next major version." + " You can set ssl=True to force the ssl certificate verification" + " or ssl=False to disable this warning" + ) + + self.ssl: Union[SSLContext, bool, Fingerprint] = cast( + Union[SSLContext, bool, Fingerprint], ssl + ) self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args diff --git a/tests/conftest.py b/tests/conftest.py index c164c355..c0b2037f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,6 +156,29 @@ def get_localhost_ssl_context(): return (testcert, ssl_context) +def get_localhost_ssl_context_client(): + """ + Create a client-side SSL context that verifies the specific self-signed certificate + used for our test. 
+ """ + # Get the certificate from the server setup + cert_path = bytes(pathlib.Path(__file__).with_name("test_localhost_client.crt")) + + # Create client SSL context + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + # Load just the certificate part as a trusted CA + ssl_context.load_verify_locations(cafile=cert_path) + + # Require certificate verification + ssl_context.verify_mode = ssl.CERT_REQUIRED + + # Enable hostname checking for localhost + ssl_context.check_hostname = True + + return cert_path, ssl_context + + class WebSocketServer: """Websocket server on localhost on a free port. diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 55b08260..81af20ff 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -14,7 +14,11 @@ TransportServerError, ) -from .conftest import TemporaryFile, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + strip_braces_spaces, +) query1_str = """ query getContinents { @@ -1285,7 +1289,10 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) -async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout): +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_aiohttp_query_https( + event_loop, ssl_aiohttp_server, ssl_close_timeout, verify_https +): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1300,8 +1307,20 @@ async def handler(request): assert str(url).startswith("https://") + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False + transport = AIOHTTPTransport( - url=url, timeout=10, ssl_close_timeout=ssl_close_timeout + url=url, + timeout=10, + ssl_close_timeout=ssl_close_timeout, + **extra_args, ) async with 
Client(transport=transport) as session: @@ -1318,6 +1337,65 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.skip(reason="We will change the default to fix this in a future version") +@pytest.mark.asyncio +async def test_aiohttp_query_https_self_cert_fail(event_loop, ssl_aiohttp_server): + """By default, we should verify the ssl certificate""" + from aiohttp.client_exceptions import ClientConnectorCertificateError + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + assert str(url).startswith("https://") + + transport = AIOHTTPTransport(url=url, timeout=10) + + with pytest.raises(ClientConnectorCertificateError) as exc_info: + async with Client(transport=transport) as session: + query = gql(query1_str) + + # Execute query asynchronously + await session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + assert transport.session is None + + +@pytest.mark.asyncio +async def test_aiohttp_query_https_self_cert_warn(event_loop, ssl_aiohttp_server): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + assert str(url).startswith("https://") + + expected_warning = ( + "WARNING: By default, AIOHTTPTransport does not verify ssl certificates." + " This will be fixed in the next major version." 
+ ) + + with pytest.warns(Warning, match=expected_warning): + AIOHTTPTransport(url=url, timeout=10) + + @pytest.mark.asyncio async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server): from aiohttp import web diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index f154386b..ff2bcf02 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -1,6 +1,5 @@ import asyncio import json -import ssl import sys from typing import Dict, Mapping @@ -14,7 +13,7 @@ TransportServerError, ) -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper, get_localhost_ssl_context_client # Marking all tests in this file with the aiohttp AND websockets marker pytestmark = pytest.mark.aiohttp @@ -92,8 +91,9 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( @pytest.mark.websockets @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_aiohttp_websocket_using_ssl_connection( - event_loop, ws_ssl_server, ssl_close_timeout + event_loop, ws_ssl_server, ssl_close_timeout, verify_https ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -103,11 +103,19 @@ async def test_aiohttp_websocket_using_ssl_connection( url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(ws_ssl_server.testcert) + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False transport = AIOHTTPWebsocketsTransport( - url=url, ssl=ssl_context, ssl_close_timeout=ssl_close_timeout + url=url, + ssl_close_timeout=ssl_close_timeout, 
+ **extra_args, ) async with Client(transport=transport) as session: @@ -130,6 +138,49 @@ async def test_aiohttp_websocket_using_ssl_connection( assert transport.websocket is None +@pytest.mark.asyncio +@pytest.mark.websockets +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("ssl_close_timeout", [10]) +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( + event_loop, ws_ssl_server, ssl_close_timeout, verify_https +): + + from aiohttp.client_exceptions import ClientConnectorCertificateError + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + + server = ws_ssl_server + + url = f"wss://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["ssl"] = True + + transport = AIOHTTPWebsocketsTransport( + url=url, + ssl_close_timeout=ssl_close_timeout, + **extra_args, + ) + + with pytest.raises(ClientConnectorCertificateError) as exc_info: + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + await session.execute(query1) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + # Check client is disconnect here + assert transport.websocket is None + + @pytest.mark.asyncio @pytest.mark.websockets @pytest.mark.parametrize("server", [server1_answers], indirect=True) diff --git a/tests/test_httpx.py b/tests/test_httpx.py index af12f717..8ef57a84 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -11,7 +11,7 @@ TransportServerError, ) -from .conftest import TemporaryFile, strip_braces_spaces +from .conftest import TemporaryFile, get_localhost_ssl_context, strip_braces_spaces # Marking all tests in this file with the httpx marker pytestmark = pytest.mark.httpx @@ -77,6 +77,118 @@ def test_code(): await 
run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_httpx_query_https( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + from aiohttp import web + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = str(server.make_url("/")) + + assert str(url).startswith("https://") + + def test_code(): + extra_args = {} + + if verify_https == "cert_provided": + cert, _ = get_localhost_ssl_context() + + extra_args["verify"] = cert.decode() + elif verify_https == "disabled": + extra_args["verify"] = False + + transport = HTTPXTransport( + url=url, + **extra_args, + ) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_httpx_query_https_self_cert_fail( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + """By default, we should verify the ssl certificate""" + from aiohttp import web + from httpx import ConnectError + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return 
web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = str(server.make_url("/")) + + assert str(url).startswith("https://") + + def test_code(): + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["verify"] = True + + transport = HTTPXTransport( + url=url, + **extra_args, + ) + + with pytest.raises(ConnectError) as exc_info: + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_cookies(event_loop, aiohttp_server, run_sync_test): diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 17be0db5..47744538 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -14,7 +14,11 @@ TransportServerError, ) -from .conftest import TemporaryFile, get_localhost_ssl_context, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + strip_braces_spaces, +) query1_str = """ query getContinents { @@ -1162,7 +1166,8 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_https(event_loop, ssl_aiohttp_server): +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_httpx_query_https(event_loop, ssl_aiohttp_server, verify_https): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1177,9 +1182,16 @@ async def handler(request): assert url.startswith("https://") - cert, _ = 
get_localhost_ssl_context() + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() - transport = HTTPXAsyncTransport(url=url, timeout=10, verify=cert.decode()) + extra_args["verify"] = ssl_context + elif verify_https == "disabled": + extra_args["verify"] = False + + transport = HTTPXAsyncTransport(url=url, timeout=10, **extra_args) async with Client(transport=transport) as session: @@ -1195,6 +1207,47 @@ async def handler(request): assert africa["code"] == "AF" +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_httpx_query_https_self_cert_fail( + event_loop, ssl_aiohttp_server, verify_https +): + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + from httpx import ConnectError + + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = str(server.make_url("/")) + + assert url.startswith("https://") + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["verify"] = True + + transport = HTTPXAsyncTransport(url=url, timeout=10, **extra_args) + + with pytest.raises(ConnectError) as exc_info: + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + await session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_error_fetching_schema(event_loop, aiohttp_server): diff --git a/tests/test_localhost_client.crt b/tests/test_localhost_client.crt new file mode 100644 index 00000000..0bbed2f5 --- /dev/null +++ b/tests/test_localhost_client.crt @@ -0,0 +1,20 
@@ +-----BEGIN CERTIFICATE----- +MIIDTTCCAjWgAwIBAgIJAJ6VG2cQlsepMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV +BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp +bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTE4MDUwNTE2NTc1NloYDzIwNjAwNTA0 +MTY1NzU2WjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM +EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAJSCtBWQ1sBZGWjNlSPXhR/PtgSnYxea+aF2 +V84YvCPL7E873xolG/n+dgXZ5YzeWVyYt7wVsFIr5AVOjiy7tlWdzqohM4epxINT +DTpZqtBQyz3huEdS9CnW7z5vaE2Ix4bDr5CIEjo4lE6IaktFuQ3pSPcArCLxJhWg +vIyLO27Bs3IZ/x8XcMOkdm0GK0a0xIEIyxCx8HjrmmXZSjIGtZraWxsu3dW8Flm8 +ep8S4+OmOMo3lRIhedp/Q2LNpHqmzcTJ9+1bLiLvMhA3m5MTG9o8PI+f2cfer92R +P32ZIxJTUC9NOlfw83sOWoTrBkxtCwE9EZbsYSVD47Egp0o4uTkCAwEAAaMwMC4w +LAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMA0G +CSqGSIb3DQEBCwUAA4IBAQA0imKp/rflfbDCCx78NdsR5rt0jKem2t3YPGT6tbeU ++FQz62SEdeD2OHWxpvfPf+6h3iTXJbkakr2R4lP3z7GHUe61lt3So9VHAvgbtPTH +aB1gOdThA83o0fzQtnIv67jCvE9gwPQInViZLEcm2iQEZLj6AuSvBKmluTR7vNRj +8/f2R4LsDfCWGrzk2W+deGRvSow7irS88NQ8BW8S8otgMiBx4D2UlOmQwqr6X+/r +jYIDuMb6GDKRXtBUGDokfE94hjj9u2mrNRwt8y4tqu8ZNa//yLEQ0Ow2kP3QJPLY +941VZpwRi2v/+JvI7OBYlvbOTFwM8nAk79k+Dgviygd9 +-----END CERTIFICATE----- diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index b13a8c55..666fec34 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -2,6 +2,8 @@ from gql import Client, gql +from .conftest import get_localhost_ssl_context_client + # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -56,17 +58,91 @@ async def test_phoenix_channel_query(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) + + query = gql(query_str) + 
async with Client(transport=transport) as session: + result = await session.execute(query) + + print("Client received:", result) + + +@pytest.mark.skip(reason="ssl=False is not working for now") +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_phoenix_channel_query_ssl( + event_loop, ws_ssl_server, query_str, verify_https +): + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + + path = "/graphql" + server = ws_ssl_server + url = f"wss://{server.hostname}:{server.port}{path}" + + extra_args = {} + + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False + + transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", + url=url, + **extra_args, ) query = gql(query_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: result = await session.execute(query) print("Client received:", result) +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_phoenix_channel_query_ssl_self_cert_fail( + event_loop, ws_ssl_server, query_str, verify_https +): + from gql.transport.phoenix_channel_websockets import ( + PhoenixChannelWebsocketsTransport, + ) + from ssl import SSLCertVerificationError + + path = "/graphql" + server = ws_ssl_server + url = f"wss://{server.hostname}:{server.port}{path}" + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["ssl"] = True + + transport = PhoenixChannelWebsocketsTransport( + 
channel_name="test_channel", + url=url, + **extra_args, + ) + + query = gql(query_str) + + with pytest.raises(SSLCertVerificationError) as exc_info: + async with Client(transport=transport) as session: + await session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + query2_str = """ subscription getContinents { continents { @@ -133,13 +209,11 @@ async def test_phoenix_channel_subscription(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) first_result = None query = gql(query_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for result in session.subscribe(query): first_result = result break diff --git a/tests/test_requests.py b/tests/test_requests.py index ba666243..95db0b3f 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -11,7 +11,11 @@ TransportServerError, ) -from .conftest import TemporaryFile, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + strip_braces_spaces, +) # Marking all tests in this file with the requests marker pytestmark = pytest.mark.requests @@ -77,6 +81,120 @@ def test_code(): await run_sync_test(event_loop, server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_requests_query_https( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + import warnings + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + 
headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + with warnings.catch_warnings(): + + extra_args = {} + + if verify_https == "cert_provided": + cert_path, _ = get_localhost_ssl_context_client() + + extra_args["verify"] = cert_path + elif verify_https == "disabled": + extra_args["verify"] = False + + # Ignoring Insecure Request warning + warnings.filterwarnings("ignore") + + transport = RequestsHTTPTransport( + url=url, + **extra_args, + ) + + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_requests_query_https_self_cert_fail( + event_loop, ssl_aiohttp_server, run_sync_test, verify_https +): + """By default, we should verify the ssl certificate""" + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport + from requests.exceptions import SSLError + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await ssl_aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["verify"] = True + + transport = 
RequestsHTTPTransport( + url=url, + **extra_args, + ) + + with pytest.raises(SSLError) as exc_info: + with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query synchronously + session.execute(query) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + await run_sync_test(event_loop, server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 9e6fd4ab..56dd150f 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,6 +1,5 @@ import asyncio import json -import ssl import sys from typing import Dict, Mapping @@ -14,7 +13,7 @@ TransportServerError, ) -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper, get_localhost_ssl_context_client # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -89,9 +88,11 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport.websocket is None +@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): +@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) +async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_https): import websockets from gql.transport.websockets import WebsocketsTransport @@ -100,10 +101,16 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ssl_context.load_verify_locations(ws_ssl_server.testcert) + extra_args = {} - 
transport = WebsocketsTransport(url=url, ssl=ssl_context) + if verify_https == "cert_provided": + _, ssl_context = get_localhost_ssl_context_client() + + extra_args["ssl"] = ssl_context + elif verify_https == "disabled": + extra_args["ssl"] = False + + transport = WebsocketsTransport(url=url, **extra_args) async with Client(transport=transport) as session: @@ -129,6 +136,42 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): assert transport.websocket is None +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) +@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) +async def test_websocket_using_ssl_connection_self_cert_fail( + event_loop, ws_ssl_server, verify_https +): + from gql.transport.websockets import WebsocketsTransport + from ssl import SSLCertVerificationError + + server = ws_ssl_server + + url = f"wss://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + extra_args = {} + + if verify_https == "explicitely_enabled": + extra_args["ssl"] = True + + transport = WebsocketsTransport(url=url, **extra_args) + + with pytest.raises(SSLCertVerificationError) as exc_info: + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + await session.execute(query1) + + expected_error = "certificate verify failed: self-signed certificate" + + assert expected_error in str(exc_info.value) + + # Check client is disconnect here + assert transport.websocket is None + + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) From 0e678b9ffad14cdb64a7f5dc70459885b06215de Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 18 Feb 2025 14:57:56 +0100 Subject: [PATCH 186/239] Bump httpx to min 0.27 (#531) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e41836a7..fa8cc2f9 100644 --- a/setup.py +++ 
b/setup.py @@ -48,7 +48,7 @@ ] install_httpx_requires = [ - "httpx>=0.23.1,<1", + "httpx>=0.27.0,<1", ] install_websockets_requires = [ From 163723f2dfaf3bdfb40e3117cf4003f6486858da Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 18 Feb 2025 14:58:27 +0100 Subject: [PATCH 187/239] Chore bump websockets to 13.x (#532) --- setup.py | 2 +- tests/conftest.py | 2 +- tests/test_aiohttp_websocket_exceptions.py | 16 +++++++-------- ..._aiohttp_websocket_graphqlws_exceptions.py | 12 +++++------ ...iohttp_websocket_graphqlws_subscription.py | 20 +++++++++---------- tests/test_aiohttp_websocket_query.py | 8 ++++---- tests/test_aiohttp_websocket_subscription.py | 6 +++--- tests/test_appsync_websockets.py | 20 ++++++++++--------- tests/test_async_client_validation.py | 2 +- tests/test_graphqlws_exceptions.py | 12 +++++------ tests/test_graphqlws_subscription.py | 20 +++++++++---------- tests/test_phoenix_channel_exceptions.py | 6 +++--- tests/test_phoenix_channel_query.py | 4 ++-- tests/test_phoenix_channel_subscription.py | 4 ++-- tests/test_websocket_exceptions.py | 16 +++++++-------- tests/test_websocket_query.py | 8 ++++---- tests/test_websocket_subscription.py | 4 ++-- 17 files changed, 82 insertions(+), 80 deletions(-) diff --git a/setup.py b/setup.py index fa8cc2f9..a44c2e01 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ ] install_websockets_requires = [ - "websockets>=10,<12", + "websockets>=10.1,<14", ] install_botocore_requires = [ diff --git a/tests/conftest.py b/tests/conftest.py index c0b2037f..b0103a99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -501,7 +501,7 @@ def get_server_handler(request): else: answers = request.param - async def default_server_handler(ws, path): + async def default_server_handler(ws): try: await WebSocketServerHelper.send_connection_ack(ws) diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index ea48824f..8ee44d2c 100644 --- 
a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -70,7 +70,7 @@ async def test_aiohttp_websocket_invalid_query( """ -async def server_invalid_subscription(ws, path): +async def server_invalid_subscription(ws): await WebSocketServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) @@ -108,7 +108,7 @@ async def test_aiohttp_websocket_invalid_subscription( ) -async def server_no_ack(ws, path): +async def server_no_ack(ws): await ws.wait_closed() @@ -129,7 +129,7 @@ async def test_aiohttp_websocket_server_does_not_send_ack( pass -async def server_connection_error(ws, path): +async def server_connection_error(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") @@ -158,7 +158,7 @@ async def test_aiohttp_websocket_sending_invalid_data( ) -async def server_invalid_payload(ws, path): +async def server_invalid_payload(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") @@ -253,7 +253,7 @@ async def test_aiohttp_websocket_transport_protocol_errors( await session.execute(query) -async def server_without_ack(ws, path): +async def server_without_ack(ws): # Sending something else than an ack await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -274,7 +274,7 @@ async def test_aiohttp_websocket_server_does_not_ack(event_loop, server): pass -async def server_closing_directly(ws, path): +async def server_closing_directly(ws): await ws.close() @@ -294,7 +294,7 @@ async def test_aiohttp_websocket_server_closing_directly(event_loop, server): pass -async def server_closing_after_ack(ws, path): +async def server_closing_after_ack(ws): await WebSocketServerHelper.send_connection_ack(ws) await ws.close() @@ -313,7 +313,7 @@ async def test_aiohttp_websocket_server_closing_after_ack( await session.execute(query) -async def 
server_sending_invalid_query_errors(ws, path): +async def server_sending_invalid_query_errors(ws): await WebSocketServerHelper.send_connection_ack(ws) invalid_error = ( '{"type":"error","id":"404","payload":' diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index d87315c9..b234d296 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -68,7 +68,7 @@ async def test_aiohttp_websocket_graphqlws_invalid_query( """ -async def server_invalid_subscription(ws, path): +async def server_invalid_subscription(ws): await WebSocketServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) @@ -102,7 +102,7 @@ async def test_aiohttp_websocket_graphqlws_invalid_subscription( assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" -async def server_no_ack(ws, path): +async def server_no_ack(ws): await ws.wait_closed() @@ -130,7 +130,7 @@ async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( ) -async def server_invalid_query(ws, path): +async def server_invalid_query(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") @@ -207,7 +207,7 @@ async def test_aiohttp_websocket_graphqlws_transport_protocol_errors( await session.execute(query) -async def server_without_ack(ws, path): +async def server_without_ack(ws): # Sending something else than an ack await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -230,7 +230,7 @@ async def test_aiohttp_websocket_graphqlws_server_does_not_ack( pass -async def server_closing_directly(ws, path): +async def server_closing_directly(ws): await ws.close() @@ -252,7 +252,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_directly( pass -async def server_closing_after_ack(ws, path): +async def server_closing_after_ack(ws): await 
WebSocketServerHelper.send_connection_ack(ws) await ws.close() diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 86ff96ab..d40d15ce 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -30,7 +30,7 @@ def server_countdown_factory( keepalive=False, answer_pings=True, simulate_disconnect=False ): - async def server_countdown_template(ws, path): + async def server_countdown_template(ws): import websockets logged_messages.clear() @@ -192,28 +192,28 @@ async def receiving_coro(): return server_countdown_template -async def server_countdown(ws, path): +async def server_countdown(ws): server = server_countdown_factory() - await server(ws, path) + await server(ws) -async def server_countdown_keepalive(ws, path): +async def server_countdown_keepalive(ws): server = server_countdown_factory(keepalive=True) - await server(ws, path) + await server(ws) -async def server_countdown_dont_answer_pings(ws, path): +async def server_countdown_dont_answer_pings(ws): server = server_countdown_factory(answer_pings=False) - await server(ws, path) + await server(ws) -async def server_countdown_disconnect(ws, path): +async def server_countdown_disconnect(ws): server = server_countdown_factory(simulate_disconnect=True) - await server(ws, path) + await server(ws) countdown_subscription_str = """ @@ -353,7 +353,7 @@ async def close_transport_task_coro(): assert count > 0 -async def server_countdown_close_connection_in_middle(ws, path): +async def server_countdown_close_connection_in_middle(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index ff2bcf02..d76d646f 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -228,7 +228,7 @@ async def 
test_aiohttp_websocket_two_queries_in_series( assert result1 == result2 -async def server1_two_queries_in_parallel(ws, path): +async def server1_two_queries_in_parallel(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}", file=sys.stderr) @@ -276,7 +276,7 @@ async def task2_coro(): assert result1 == result2 -async def server_closing_while_we_are_doing_something_else(ws, path): +async def server_closing_while_we_are_doing_something_else(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}", file=sys.stderr) @@ -434,7 +434,7 @@ async def test_aiohttp_websocket_trying_to_connect_to_already_connected_transpor pass -async def server_with_authentication_in_connection_init_payload(ws, path): +async def server_with_authentication_in_connection_init_payload(ws): # Wait the connection_init message init_message_str = await ws.recv() init_message = json.loads(init_message_str) @@ -593,7 +593,7 @@ async def test_aiohttp_websocket_add_extra_parameters_to_connect( await session.execute(query) -async def server_sending_keep_alive_before_connection_ack(ws, path): +async def server_sending_keep_alive_before_connection_ack(ws): await WebSocketServerHelper.send_keepalive(ws) await WebSocketServerHelper.send_keepalive(ws) await WebSocketServerHelper.send_keepalive(ws) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 4bc6ad3c..9d2d652b 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -30,7 +30,7 @@ } -async def server_starwars(ws, path): +async def server_starwars(ws): import websockets await WebSocketServerHelper.send_connection_ack(ws) @@ -91,7 +91,7 @@ async def server_starwars(ws, path): logged_messages: List[str] = [] -async def server_countdown(ws, path): +async def server_countdown(ws): import websockets logged_messages.clear() @@ 
-343,7 +343,7 @@ async def close_transport_task_coro(): assert count > 0 -async def server_countdown_close_connection_in_middle(ws, path): +async def server_countdown_close_connection_in_middle(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 14c40e75..88bae8b6 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -107,7 +107,7 @@ def verify_headers(headers, in_query=False): "errorCode": 400, } - async def realtime_appsync_server_template(ws, path): + async def realtime_appsync_server_template(ws): import websockets logged_messages.clear() @@ -139,6 +139,8 @@ async def realtime_appsync_server_template(ws, path): ) return + path = ws.path + print(f"path = {path}") path_base, parameters_str = path.split("?") @@ -348,28 +350,28 @@ async def receiving_coro(): return realtime_appsync_server_template -async def realtime_appsync_server(ws, path): +async def realtime_appsync_server(ws): server = realtime_appsync_server_factory() - await server(ws, path) + await server(ws) -async def realtime_appsync_server_keepalive(ws, path): +async def realtime_appsync_server_keepalive(ws): server = realtime_appsync_server_factory(keepalive=True) - await server(ws, path) + await server(ws) -async def realtime_appsync_server_not_json_answer(ws, path): +async def realtime_appsync_server_not_json_answer(ws): server = realtime_appsync_server_factory(not_json_answer=True) - await server(ws, path) + await server(ws) -async def realtime_appsync_server_error_without_id(ws, path): +async def realtime_appsync_server_error_without_id(ws): server = realtime_appsync_server_factory(error_without_id=True) - await server(ws, path) + await server(ws) on_create_message_subscription_str = """ diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index d39019e8..acfabe0e 100644 --- 
a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -22,7 +22,7 @@ } -async def server_starwars(ws, path): +async def server_starwars(ws): import websockets await WebSocketServerHelper.send_connection_ack(ws) diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index ca689c47..befeeb4e 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -68,7 +68,7 @@ async def test_graphqlws_invalid_query( """ -async def server_invalid_subscription(ws, path): +async def server_invalid_subscription(ws): await WebSocketServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) @@ -102,7 +102,7 @@ async def test_graphqlws_invalid_subscription( assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR" -async def server_no_ack(ws, path): +async def server_no_ack(ws): await ws.wait_closed() @@ -130,7 +130,7 @@ async def test_graphqlws_server_does_not_send_ack( ) -async def server_invalid_query(ws, path): +async def server_invalid_query(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") @@ -205,7 +205,7 @@ async def test_graphqlws_transport_protocol_errors( await session.execute(query) -async def server_without_ack(ws, path): +async def server_without_ack(ws): # Sending something else than an ack await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -226,7 +226,7 @@ async def test_graphqlws_server_does_not_ack(event_loop, graphqlws_server): pass -async def server_closing_directly(ws, path): +async def server_closing_directly(ws): await ws.close() @@ -246,7 +246,7 @@ async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): pass -async def server_closing_after_ack(ws, path): +async def server_closing_after_ack(ws): await WebSocketServerHelper.send_connection_ack(ws) await ws.close() diff --git 
a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index deeae395..683da43a 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -30,7 +30,7 @@ def server_countdown_factory( keepalive=False, answer_pings=True, simulate_disconnect=False ): - async def server_countdown_template(ws, path): + async def server_countdown_template(ws): import websockets logged_messages.clear() @@ -192,28 +192,28 @@ async def receiving_coro(): return server_countdown_template -async def server_countdown(ws, path): +async def server_countdown(ws): server = server_countdown_factory() - await server(ws, path) + await server(ws) -async def server_countdown_keepalive(ws, path): +async def server_countdown_keepalive(ws): server = server_countdown_factory(keepalive=True) - await server(ws, path) + await server(ws) -async def server_countdown_dont_answer_pings(ws, path): +async def server_countdown_dont_answer_pings(ws): server = server_countdown_factory(answer_pings=False) - await server(ws, path) + await server(ws) -async def server_countdown_disconnect(ws, path): +async def server_countdown_disconnect(ws): server = server_countdown_factory(simulate_disconnect=True) - await server(ws, path) + await server(ws) countdown_subscription_str = """ @@ -353,7 +353,7 @@ async def close_transport_task_coro(): assert count > 0 -async def server_countdown_close_connection_in_middle(ws, path): +async def server_countdown_close_connection_in_middle(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index e2bf0091..c042ce01 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -127,7 +127,7 @@ def ensure_list(s): def query_server(server_answers=default_query_server_answer): from .conftest import PhoenixChannelServerHelper - async def phoenix_server(ws, path): + async 
def phoenix_server(ws): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() for server_answer in ensure_list(server_answers): @@ -138,7 +138,7 @@ async def phoenix_server(ws, path): return phoenix_server -async def no_connection_ack_phoenix_server(ws, path): +async def no_connection_ack_phoenix_server(ws): from .conftest import PhoenixChannelServerHelper await ws.recv() @@ -363,7 +363,7 @@ def subscription_server( from .conftest import PhoenixChannelServerHelper import json - async def phoenix_server(ws, path): + async def phoenix_server(ws): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() if server_answers is not None: diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 666fec34..f39edacb 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -38,7 +38,7 @@ def ws_server_helper(request): yield PhoenixChannelServerHelper -async def query_server(ws, path): +async def query_server(ws): from .conftest import PhoenixChannelServerHelper await PhoenixChannelServerHelper.send_connection_ack(ws) @@ -185,7 +185,7 @@ async def test_phoenix_channel_query_ssl_self_cert_fail( ) -async def subscription_server(ws, path): +async def subscription_server(ws): from .conftest import PhoenixChannelServerHelper await PhoenixChannelServerHelper.send_connection_ack(ws) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 34564c6d..6193c658 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -77,7 +77,7 @@ ) -async def server_countdown(ws, path): +async def server_countdown(ws): import websockets from .conftest import MS, PhoenixChannelServerHelper @@ -295,7 +295,7 @@ async def testing_stopping_without_break(): ) -async def phoenix_heartbeat_server(ws, path): +async def phoenix_heartbeat_server(ws): import websockets from .conftest import 
PhoenixChannelServerHelper diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 72db8a87..cb9e7274 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -69,7 +69,7 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str) """ -async def server_invalid_subscription(ws, path): +async def server_invalid_subscription(ws): await WebSocketServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) @@ -105,7 +105,7 @@ async def test_websocket_invalid_subscription(event_loop, client_and_server, que ) -async def server_no_ack(ws, path): +async def server_no_ack(ws): await ws.wait_closed() @@ -124,7 +124,7 @@ async def test_websocket_server_does_not_send_ack(event_loop, server, query_str) pass -async def server_connection_error(ws, path): +async def server_connection_error(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") @@ -151,7 +151,7 @@ async def test_websocket_sending_invalid_data(event_loop, client_and_server, que ) -async def server_invalid_payload(ws, path): +async def server_invalid_payload(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") @@ -244,7 +244,7 @@ async def test_websocket_transport_protocol_errors(event_loop, client_and_server await session.execute(query) -async def server_without_ack(ws, path): +async def server_without_ack(ws): # Sending something else than an ack await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -265,7 +265,7 @@ async def test_websocket_server_does_not_ack(event_loop, server): pass -async def server_closing_directly(ws, path): +async def server_closing_directly(ws): await ws.close() @@ -285,7 +285,7 @@ async def test_websocket_server_closing_directly(event_loop, server): pass -async def server_closing_after_ack(ws, 
path): +async def server_closing_after_ack(ws): await WebSocketServerHelper.send_connection_ack(ws) await ws.close() @@ -309,7 +309,7 @@ async def test_websocket_server_closing_after_ack(event_loop, client_and_server) await session.execute(query) -async def server_sending_invalid_query_errors(ws, path): +async def server_sending_invalid_query_errors(ws): await WebSocketServerHelper.send_connection_ack(ws) invalid_error = ( '{"type":"error","id":"404","payload":' diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 56dd150f..2c723b3f 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -214,7 +214,7 @@ async def test_websocket_two_queries_in_series( assert result1 == result2 -async def server1_two_queries_in_parallel(ws, path): +async def server1_two_queries_in_parallel(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}", file=sys.stderr) @@ -261,7 +261,7 @@ async def task2_coro(): assert result1 == result2 -async def server_closing_while_we_are_doing_something_else(ws, path): +async def server_closing_while_we_are_doing_something_else(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}", file=sys.stderr) @@ -402,7 +402,7 @@ async def test_websocket_trying_to_connect_to_already_connected_transport( pass -async def server_with_authentication_in_connection_init_payload(ws, path): +async def server_with_authentication_in_connection_init_payload(ws): # Wait the connection_init message init_message_str = await ws.recv() init_message = json.loads(init_message_str) @@ -545,7 +545,7 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server): await session.execute(query) -async def server_sending_keep_alive_before_connection_ack(ws, path): +async def server_sending_keep_alive_before_connection_ack(ws): await WebSocketServerHelper.send_keepalive(ws) await 
WebSocketServerHelper.send_keepalive(ws) await WebSocketServerHelper.send_keepalive(ws) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 38307349..5af44d59 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -27,7 +27,7 @@ logged_messages: List[str] = [] -async def server_countdown(ws, path): +async def server_countdown(ws): import websockets logged_messages.clear() @@ -274,7 +274,7 @@ async def close_transport_task_coro(): assert count > 0 -async def server_countdown_close_connection_in_middle(ws, path): +async def server_countdown_close_connection_in_middle(ws): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() From 7cada5109e9d7834364b58225cec2bf45e81c91f Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 18 Feb 2025 15:02:57 +0100 Subject: [PATCH 188/239] Bump version number to 3.6.0b4 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 6361d12f..cfe6b54e 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.6.0b3" +__version__ = "3.6.0b4" From 740e98848e815d237ea7bfd9a168fd6e9b3a6e95 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 19 Feb 2025 15:42:36 +0100 Subject: [PATCH 189/239] Add the warning from PR #530 in the stable branch (#533) * Restrict graphql-core to <3.2.4 to fix tests --- gql/transport/aiohttp.py | 30 +++++++++++++++++++++++++++--- setup.py | 5 +++-- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 60f42c94..65c15997 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -3,8 +3,19 @@ import io import json import logging +import warnings from ssl import SSLContext -from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + 
Optional, + Tuple, + Type, + Union, + cast, +) import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -46,7 +57,7 @@ def __init__( headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, - ssl: Union[SSLContext, bool, Fingerprint] = False, + ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning", timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, json_serialize: Callable = json.dumps, @@ -74,7 +85,20 @@ def __init__( self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth - self.ssl: Union[SSLContext, bool, Fingerprint] = ssl + + if ssl == "ssl_warning": + ssl = False + if str(url).startswith("https"): + warnings.warn( + "WARNING: By default, AIOHTTPTransport does not verify" + " ssl certificates. This will be fixed in the next major version." 
+ " You can set ssl=True to force the ssl certificate verification" + " or ssl=False to disable this warning" + ) + + self.ssl: Union[SSLContext, bool, Fingerprint] = cast( + Union[SSLContext, bool, Fingerprint], ssl + ) self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args diff --git a/setup.py b/setup.py index 233900d2..66cf2d01 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.2,<3.3", + "graphql-core>=3.2,<3.2.4", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", @@ -20,7 +20,8 @@ "pytest-console-scripts==1.3.1", "pytest-cov==3.0.0", "mock==4.0.2", - "vcrpy==4.4.0", + "vcrpy==4.4.0;python_version<='3.8'", + "vcrpy==7.0.0;python_version>'3.8'", "aiofiles", ] From 46e188cbd2755347496b87cdcef99cd90a34fa1e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 19 Feb 2025 15:44:13 +0100 Subject: [PATCH 190/239] Bump version number to 3.5.1 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index dcbfb52f..0c11babd 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.0" +__version__ = "3.5.1" From ba70d2d79161b3c2417e3821cd9252728ed07ddb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 6 Mar 2025 11:35:37 +0100 Subject: [PATCH 191/239] Allow graphql-core 3.2.4 by retrofitting introspection commits (#535) Using gql version of the get_introspection_query method (#523) Adding the input_value_deprecation argument to get_introspection_query_ast (#524) Fix test for introspection type recursion level change in graphql-core v3.3.0a7 (#521) Bump graphql-core to <3.2.5 --- gql/client.py | 19 ++-- gql/utilities/get_introspection_query_ast.py | 20 +++- gql/utilities/node_tree.py | 5 +- setup.py | 2 +- tests/starwars/test_dsl.py | 112 +++++++++++-------- 5 files 
changed, 99 insertions(+), 59 deletions(-) diff --git a/gql/client.py b/gql/client.py index a79d4b72..4800fb2d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -29,7 +29,6 @@ GraphQLSchema, IntrospectionQuery, build_ast_schema, - get_introspection_query, parse, validate, ) @@ -39,7 +38,7 @@ from .transport.exceptions import TransportClosed, TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport -from .utilities import build_client_schema +from .utilities import build_client_schema, get_introspection_query_ast from .utilities import parse_result as parse_result_fn from .utilities import serialize_variable_values from .utils import str_first_element @@ -98,8 +97,8 @@ def __init__( :param transport: The provided :ref:`transport `. :param fetch_schema_from_transport: Boolean to indicate that if we want to fetch the schema from the transport using an introspection query. - :param introspection_args: arguments passed to the get_introspection_query - method of graphql-core. + :param introspection_args: arguments passed to the + :meth:`gql.utilities.get_introspection_query_ast` method. :param execute_timeout: The maximum time in seconds for the execution of a request before a TimeoutError is raised. Only used for async transports. Passing None results in waiting forever for a response. 
@@ -1289,8 +1288,10 @@ def fetch_schema(self) -> None: Don't use this function and instead set the fetch_schema_from_transport attribute to True""" - introspection_query = get_introspection_query(**self.client.introspection_args) - execution_result = self.transport.execute(parse(introspection_query)) + introspection_query = get_introspection_query_ast( + **self.client.introspection_args + ) + execution_result = self.transport.execute(introspection_query) self.client._build_schema_from_introspection(execution_result) @@ -1657,8 +1658,10 @@ async def fetch_schema(self) -> None: Don't use this function and instead set the fetch_schema_from_transport attribute to True""" - introspection_query = get_introspection_query(**self.client.introspection_args) - execution_result = await self.transport.execute(parse(introspection_query)) + introspection_query = get_introspection_query_ast( + **self.client.introspection_args + ) + execution_result = await self.transport.execute(introspection_query) self.client._build_schema_from_introspection(execution_result) diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index d35a2a75..975ccc83 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -10,6 +10,7 @@ def get_introspection_query_ast( specified_by_url: bool = False, directive_is_repeatable: bool = False, schema_description: bool = False, + input_value_deprecation: bool = False, type_recursion_level: int = 7, ) -> DocumentNode: """Get a query for introspection as a document using the DSL module. 
@@ -43,13 +44,20 @@ def get_introspection_query_ast( directives = ds.__Schema.directives.select(ds.__Directive.name) + deprecated_expand = {} + + if input_value_deprecation: + deprecated_expand = { + "includeDeprecated": True, + } + if descriptions: directives.select(ds.__Directive.description) if directive_is_repeatable: directives.select(ds.__Directive.isRepeatable) directives.select( ds.__Directive.locations, - ds.__Directive.args.select(fragment_InputValue), + ds.__Directive.args(**deprecated_expand).select(fragment_InputValue), ) schema.select(directives) @@ -69,7 +77,7 @@ def get_introspection_query_ast( fields.select(ds.__Field.description) fields.select( - ds.__Field.args.select(fragment_InputValue), + ds.__Field.args(**deprecated_expand).select(fragment_InputValue), ds.__Field.type.select(fragment_TypeRef), ds.__Field.isDeprecated, ds.__Field.deprecationReason, @@ -89,7 +97,7 @@ def get_introspection_query_ast( fragment_FullType.select( fields, - ds.__Type.inputFields.select(fragment_InputValue), + ds.__Type.inputFields(**deprecated_expand).select(fragment_InputValue), ds.__Type.interfaces.select(fragment_TypeRef), enum_values, ds.__Type.possibleTypes.select(fragment_TypeRef), @@ -105,6 +113,12 @@ def get_introspection_query_ast( ds.__InputValue.defaultValue, ) + if input_value_deprecation: + fragment_InputValue.select( + ds.__InputValue.isDeprecated, + ds.__InputValue.deprecationReason, + ) + fragment_TypeRef.select( ds.__Type.kind, ds.__Type.name, diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py index c307d937..4313188e 100644 --- a/gql/utilities/node_tree.py +++ b/gql/utilities/node_tree.py @@ -19,7 +19,7 @@ def _node_tree_recursive( results.append(" " * indent + f"{type(obj).__name__}") try: - keys = obj.keys + keys = sorted(obj.keys) except AttributeError: # If the object has no keys attribute, print its repr and return. 
results.append(" " * (indent + 1) + repr(obj)) @@ -70,6 +70,9 @@ def node_tree( Useful to debug deep DocumentNode instances created by gql or dsl_gql. + NOTE: from gql version 3.6.0b4 the elements of each node are sorted to ignore + small changes in graphql-core + WARNING: the output of this method is not guaranteed and may change without notice. """ diff --git a/setup.py b/setup.py index 66cf2d01..f34b2e35 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.2,<3.2.4", + "graphql-core>=3.2,<3.2.5", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 2aadf92f..5cd051ba 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -984,18 +984,36 @@ def test_get_introspection_query_ast(option): specified_by_url=option, directive_is_repeatable=option, schema_description=option, + input_value_deprecation=option, ) dsl_introspection_query = get_introspection_query_ast( descriptions=option, specified_by_url=option, directive_is_repeatable=option, schema_description=option, + input_value_deprecation=option, ) - assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) - assert node_tree(dsl_introspection_query) == node_tree( - gql(print_ast(dsl_introspection_query)) - ) + try: + assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert node_tree(dsl_introspection_query) == node_tree( + gql(print_ast(dsl_introspection_query)) + ) + except AssertionError: + + # From graphql-core version 3.3.0a7, there is two more type recursion levels + dsl_introspection_query = get_introspection_query_ast( + descriptions=option, + specified_by_url=option, + directive_is_repeatable=option, + schema_description=option, + input_value_deprecation=option, + type_recursion_level=9, + ) + assert print_ast(gql(introspection_query)) == 
print_ast(dsl_introspection_query) + assert node_tree(dsl_introspection_query) == node_tree( + gql(print_ast(dsl_introspection_query)) + ) def test_typename_aliased(ds): @@ -1028,11 +1046,10 @@ def test_node_tree_with_loc(ds): node_tree_result = """ DocumentNode - loc: - Location - definitions: OperationDefinitionNode + directives: + empty tuple loc: Location @@ -1043,10 +1060,8 @@ def test_node_tree_with_loc(ds): value: 'GetHeroName' - directives: - empty tuple - variable_definitions: - empty tuple + operation: + selection_set: SelectionSetNode loc: @@ -1054,13 +1069,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1068,8 +1085,6 @@ def test_node_tree_with_loc(ds): value: 'hero' - arguments: - empty tuple nullability_assertion: None selection_set: @@ -1079,13 +1094,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1093,23 +1110,23 @@ def test_node_tree_with_loc(ds): value: 'name' - arguments: - empty tuple nullability_assertion: None selection_set: None - operation: - + variable_definitions: + empty tuple + loc: + Location + """.strip() node_tree_result_stable = """ DocumentNode - loc: - Location - definitions: OperationDefinitionNode + directives: + empty tuple loc: Location @@ -1120,10 +1137,8 @@ def test_node_tree_with_loc(ds): value: 'GetHeroName' - directives: - empty tuple - variable_definitions: - empty tuple + operation: + selection_set: SelectionSetNode loc: @@ -1131,13 +1146,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1145,8 +1162,6 @@ def 
test_node_tree_with_loc(ds): value: 'hero' - arguments: - empty tuple selection_set: SelectionSetNode loc: @@ -1154,13 +1169,15 @@ def test_node_tree_with_loc(ds): selections: FieldNode + alias: + None + arguments: + empty tuple + directives: + empty tuple loc: Location - directives: - empty tuple - alias: - None name: NameNode loc: @@ -1168,14 +1185,17 @@ def test_node_tree_with_loc(ds): value: 'name' - arguments: - empty tuple selection_set: None - operation: - + variable_definitions: + empty tuple + loc: + Location + """.strip() + print(node_tree(document, ignore_loc=False)) + try: assert node_tree(document, ignore_loc=False) == node_tree_result except AssertionError: From 7881a9b3e594c4522927cc8a0f8f6bc2bb5d5989 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 6 Mar 2025 11:37:00 +0100 Subject: [PATCH 192/239] Bump version number to 3.5.2 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 0c11babd..dae42b1b 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.1" +__version__ = "3.5.2" From 19fd8107576ad026903d0256377708c7a6272f41 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 6 Mar 2025 18:34:28 +0100 Subject: [PATCH 193/239] Put ListenerQueue in separate file --- gql/transport/aiohttp_websockets.py | 61 ++----------------- gql/transport/websockets_base.py | 55 +---------------- gql/transport/websockets_common/__init__.py | 3 + .../websockets_common/listener_queue.py | 58 ++++++++++++++++++ 4 files changed, 66 insertions(+), 111 deletions(-) create mode 100644 gql/transport/websockets_common/__init__.py create mode 100644 gql/transport/websockets_common/listener_queue.py diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 18699b5e..9b84bd9b 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -22,72 +22,19 @@ from graphql import DocumentNode, 
ExecutionResult, print_ast from multidict import CIMultiDictProxy -from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.async_transport import AsyncTransport -from gql.transport.exceptions import ( +from .aiohttp import AIOHTTPTransport +from .async_transport import AsyncTransport +from .exceptions import ( TransportAlreadyConnected, TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .websockets_common import ListenerQueue log = logging.getLogger("gql.transport.aiohttp_websockets") -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - 
self.send_stop = False - self._closed = True - class AIOHTTPWebsocketsTransport(AsyncTransport): diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py index accca275..f8694c16 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_base.py @@ -21,63 +21,10 @@ TransportQueryError, TransportServerError, ) +from .websockets_common import ListenerQueue log = logging.getLogger("gql.transport.websockets") -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - class WebsocketsTransportBase(AsyncTransport): """abstract 
:ref:`Async Transport ` used to implement diff --git a/gql/transport/websockets_common/__init__.py b/gql/transport/websockets_common/__init__.py new file mode 100644 index 00000000..7661cf87 --- /dev/null +++ b/gql/transport/websockets_common/__init__.py @@ -0,0 +1,3 @@ +from .listener_queue import ListenerQueue, ParsedAnswer + +__all__ = ["ListenerQueue", "ParsedAnswer"] diff --git a/gql/transport/websockets_common/listener_queue.py b/gql/transport/websockets_common/listener_queue.py new file mode 100644 index 00000000..54aa650f --- /dev/null +++ b/gql/transport/websockets_common/listener_queue.py @@ -0,0 +1,58 @@ +import asyncio +from typing import Optional, Tuple + +from graphql import ExecutionResult + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, 
exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True From 5cb5b9a8878b0bd09e9a55c6f836bf904077f986 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 6 Mar 2025 18:41:04 +0100 Subject: [PATCH 194/239] Moving websockets_base.py into websockets_common folder --- gql/transport/phoenix_channel_websockets.py | 2 +- gql/transport/websockets.py | 2 +- .../{websockets_base.py => websockets_common/base.py} | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) rename gql/transport/{websockets_base.py => websockets_common/base.py} (99%) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 08cde8cc..a7b256eb 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -11,7 +11,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase +from .websockets_common.base import WebsocketsTransportBase log = logging.getLogger(__name__) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 02abb61f..adebf249 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -15,7 +15,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase +from .websockets_common.base import WebsocketsTransportBase log = logging.getLogger(__name__) diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_common/base.py similarity index 99% rename from gql/transport/websockets_base.py rename to gql/transport/websockets_common/base.py index f8694c16..4a07a10d 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/websockets_common/base.py @@ -13,15 +13,15 @@ from websockets.exceptions import ConnectionClosed from websockets.typing import Data, Subprotocol -from .async_transport import 
AsyncTransport -from .exceptions import ( +from ..async_transport import AsyncTransport +from ..exceptions import ( TransportAlreadyConnected, TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) -from .websockets_common import ListenerQueue +from .listener_queue import ListenerQueue log = logging.getLogger("gql.transport.websockets") From c369d2a1b67bb485b38202f18926a25c6b54bc0b Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 7 Mar 2025 23:37:43 +0100 Subject: [PATCH 195/239] Refactor WebSockets Transport with Dependency Injection Architecture This major architectural improvement implements dependency injection patterns across the WebSockets transport layer, creating a more modular, testable, and extensible system: - Created abstract AdapterConnection interface in common/adapters/connection.py - Implemented concrete WebSocketsAdapter to wrap the websockets library - Moved websockets_base.py to common/base.py maintaining better structure which is independant of the websockets library used - Added new TransportConnectionClosed exception for clearer error handling - Reorganized code with proper separation of concerns: * Moved common functionality into dedicated adapters folder * Isolated connection handling from transport business logic * Separated ListenerQueue into its own file for better modularity Potential Breaking changes: * New TransportConnectionClosed Exception replacing ConnectionClosed Exception * websocket attribute removed from transport, now using _connected to check if the transport is connected --- gql/transport/aiohttp_websockets.py | 2 +- gql/transport/appsync_websockets.py | 2 +- gql/transport/common/__init__.py | 10 ++ gql/transport/common/adapters/__init__.py | 3 + gql/transport/common/adapters/connection.py | 54 +++++++ gql/transport/common/adapters/websockets.py | 142 +++++++++++++++++ .../{websockets_common => common}/base.py | 148 ++++++------------ .../listener_queue.py | 0 gql/transport/exceptions.py 
| 7 + gql/transport/phoenix_channel_websockets.py | 4 +- gql/transport/websockets.py | 38 ++--- gql/transport/websockets_base.py | 93 +++++++++++ gql/transport/websockets_common/__init__.py | 3 - tests/conftest.py | 3 +- tests/test_graphqlws_exceptions.py | 8 +- tests/test_graphqlws_subscription.py | 9 +- tests/test_phoenix_channel_query.py | 4 + tests/test_websocket_exceptions.py | 10 +- tests/test_websocket_query.py | 73 +++++++-- tests/test_websocket_subscription.py | 6 +- tests/test_websockets_adapter.py | 98 ++++++++++++ 21 files changed, 556 insertions(+), 161 deletions(-) create mode 100644 gql/transport/common/__init__.py create mode 100644 gql/transport/common/adapters/__init__.py create mode 100644 gql/transport/common/adapters/connection.py create mode 100644 gql/transport/common/adapters/websockets.py rename gql/transport/{websockets_common => common}/base.py (78%) rename gql/transport/{websockets_common => common}/listener_queue.py (100%) create mode 100644 gql/transport/websockets_base.py delete mode 100644 gql/transport/websockets_common/__init__.py create mode 100644 tests/test_websockets_adapter.py diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 9b84bd9b..f97fbba8 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -24,6 +24,7 @@ from .aiohttp import AIOHTTPTransport from .async_transport import AsyncTransport +from .common import ListenerQueue from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -31,7 +32,6 @@ TransportQueryError, TransportServerError, ) -from .websockets_common import ListenerQueue log = logging.getLogger("gql.transport.aiohttp_websockets") diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 66091747..0d5139c3 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -181,7 +181,7 @@ async def _send_query( return query_id - subscribe = 
WebsocketsTransportBase.subscribe + subscribe = WebsocketsTransportBase.subscribe # type: ignore[assignment] """Send a subscription query and receive the results using a python async generator. diff --git a/gql/transport/common/__init__.py b/gql/transport/common/__init__.py new file mode 100644 index 00000000..a60ce0b0 --- /dev/null +++ b/gql/transport/common/__init__.py @@ -0,0 +1,10 @@ +from .adapters import AdapterConnection +from .base import SubscriptionTransportBase +from .listener_queue import ListenerQueue, ParsedAnswer + +__all__ = [ + "AdapterConnection", + "ListenerQueue", + "ParsedAnswer", + "SubscriptionTransportBase", +] diff --git a/gql/transport/common/adapters/__init__.py b/gql/transport/common/adapters/__init__.py new file mode 100644 index 00000000..593c46b6 --- /dev/null +++ b/gql/transport/common/adapters/__init__.py @@ -0,0 +1,3 @@ +from .connection import AdapterConnection + +__all__ = ["AdapterConnection"] diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py new file mode 100644 index 00000000..fbe38e3b --- /dev/null +++ b/gql/transport/common/adapters/connection.py @@ -0,0 +1,54 @@ +import abc +from typing import Dict + + +class AdapterConnection(abc.ABC): + """Abstract interface for subscription connections. + + This allows different WebSocket implementations to be used interchangeably. + """ + + @abc.abstractmethod + async def connect(self) -> None: + """Connect to the server.""" + pass # pragma: no cover + + @abc.abstractmethod + async def send(self, message: str) -> None: + """Send message to the server. + + Args: + message: String message to send + + Raises: + TransportConnectionClosed: If connection closed + """ + pass # pragma: no cover + + @abc.abstractmethod + async def receive(self) -> str: + """Receive message from the server. 
+ + Returns: + String message received + + Raises: + TransportConnectionClosed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + pass # pragma: no cover + + @abc.abstractmethod + async def close(self) -> None: + """Close the connection.""" + pass # pragma: no cover + + @property + @abc.abstractmethod + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the connection. + + Returns: + Dictionary of response headers + """ + pass # pragma: no cover diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py new file mode 100644 index 00000000..95fbaf39 --- /dev/null +++ b/gql/transport/common/adapters/websockets.py @@ -0,0 +1,142 @@ +from ssl import SSLContext +from typing import Any, Dict, Optional, Union + +import websockets +from websockets.client import WebSocketClientProtocol +from websockets.datastructures import Headers, HeadersLike +from websockets.exceptions import WebSocketException + +from ...exceptions import TransportConnectionClosed, TransportProtocolError +from .connection import AdapterConnection + + +class WebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the websockets library.""" + + def __init__( + self, + url: str, + *, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. 
Use ssl=False to disable encryption + :param connect_args: Other parameters forwarded to websockets.connect + """ + self.url: str = url + self._headers: Optional[HeadersLike] = headers + self.ssl: Union[SSLContext, bool] = ssl + self.connect_args = connect_args + + self.websocket: Optional[WebSocketClientProtocol] = None + self._response_headers: Optional[Headers] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + } + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + # Connection to the specified url + try: + self.websocket = await websockets.client.connect(self.url, **connect_args) + except WebSocketException as e: + raise TransportConnectionClosed("Connection was closed") from e + + self._response_headers = self.websocket.response_headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionClosed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + try: + await self.websocket.send(message) + except WebSocketException as e: + raise TransportConnectionClosed("Connection was closed") from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. 
+ + Returns: + String message received + + Raises: + TransportConnectionClosed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + # Wait for the next websocket frame. Can raise ConnectionClosed + try: + data = await self.websocket.recv() + except WebSocketException as e: + # When the connection is closed, make sure to clean up resources + self.websocket = None + raise TransportConnectionClosed("Connection was closed") from e + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + return answer + + async def close(self) -> None: + """Close the WebSocket connection.""" + if self.websocket: + websocket = self.websocket + self.websocket = None + await websocket.close() + + @property + def headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._headers: + return dict(self._headers) + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. 
+ + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers.raw_items()) + return {} diff --git a/gql/transport/websockets_common/base.py b/gql/transport/common/base.py similarity index 78% rename from gql/transport/websockets_common/base.py rename to gql/transport/common/base.py index 4a07a10d..9ee07dd8 100644 --- a/gql/transport/websockets_common/base.py +++ b/gql/transport/common/base.py @@ -3,79 +3,54 @@ import warnings from abc import abstractmethod from contextlib import suppress -from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast +from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union -import websockets from graphql import DocumentNode, ExecutionResult -from websockets.client import WebSocketClientProtocol -from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Data, Subprotocol from ..async_transport import AsyncTransport from ..exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .adapters import AdapterConnection from .listener_queue import ListenerQueue -log = logging.getLogger("gql.transport.websockets") +log = logging.getLogger("gql.transport.common.base") -class WebsocketsTransportBase(AsyncTransport): +class SubscriptionTransportBase(AsyncTransport): """abstract :ref:`Async Transport ` used to implement - different websockets protocols. - - This transport uses asyncio and the websockets library in order to send requests - on a websocket connection. + different subscription protocols (mainly websockets). 
""" def __init__( self, - url: str, - headers: Optional[HeadersLike] = None, - ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + *, + adapter: AdapterConnection, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, - ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, - connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. - :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. - :param headers: Dict of HTTP Headers. - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param init_payload: Dict of the payload sent in the connection_init message. + :param adapter: The connection dependency adapter :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. If None is provided this will wait forever. + of the connection. If None is provided this will wait forever. :param close_timeout: Timeout in seconds for the close. If None is provided this will wait forever. - :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. 
- :param connect_args: Other parameters forwarded to websockets.connect """ - self.url: str = url - self.headers: Optional[HeadersLike] = headers - self.ssl: Union[SSLContext, bool] = ssl - self.init_payload: Dict[str, Any] = init_payload - self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self.adapter: AdapterConnection = adapter - self.connect_args = connect_args - - self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} @@ -105,18 +80,14 @@ def __init__( self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" - self._connecting: bool = False + self._connected: bool = False self.close_exception: Optional[Exception] = None - # The list of supported subprotocols should be defined in the subclass - self.supported_subprotocols: List[Subprotocol] = [] - - self.response_headers: Optional[Headers] = None + @property + def response_headers(self) -> Dict[str, str]: + return self.adapter.response_headers async def _initialize(self): """Hook to send the initialization messages after the connection @@ -153,36 +124,30 @@ async def _connection_terminate(self): pass # pragma: no cover async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message""" + """Send the provided message to the adapter connection and log the message""" - if not self.websocket: + if not self._connected: raise TransportClosed( "Transport is not connected" ) from self.close_exception try: - await self.websocket.send(message) + await 
self.adapter.send(message) log.info(">>> %s", message) - except ConnectionClosed as e: + except TransportConnectionClosed as e: await self._fail(e, clean_close=False) raise e async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" + """Wait the next message from the connection and log the answer""" - # It is possible that the websocket has been already closed in another task - if self.websocket is None: + # It is possible that the connection has been already closed in another task + if not self._connected: raise TransportClosed("Transport is already closed") - # Wait for the next websocket frame. Can raise ConnectionClosed - data: Data = await self.websocket.recv() - - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise TransportProtocolError("Binary data received in the websocket") - - answer: str = data + # Wait for the next frame. + # Can raise TransportConnectionClosed or TransportProtocolError + answer: str = await self.adapter.receive() log.info("<<< %s", answer) @@ -243,10 +208,10 @@ async def _receive_data_loop(self) -> None: try: while True: - # Wait the next answer from the websocket server + # Wait the next answer from the server try: answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: + except (TransportConnectionClosed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break except TransportClosed: @@ -331,7 +296,7 @@ async def subscribe( while True: # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. 
+ # This can raise TransportError or TransportConnectionClosed answer_type, execution_result = await listener.get() # If the received answer contains data, @@ -394,52 +359,30 @@ async def connect(self) -> None: - send the init message - wait for the connection acknowledge from the server - create an asyncio task which will be used to receive - and parse the websocket answers + and parse the answers Should be cleaned with a call to the close coroutine """ log.debug("connect: starting") - if self.websocket is None and not self._connecting: + if not self._connected and not self._connecting: # Set connecting to True to avoid a race condition if user is trying # to connect twice using the same client at the same time self._connecting = True - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None - - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } - - # Adding custom parameters passed from init - connect_args.update(self.connect_args) - - # Connection to the specified url # Generate a TimeoutError if taking more than connect_timeout seconds # Set the _connecting flag to False after in all cases try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), + await asyncio.wait_for( + self.adapter.connect(), self.connect_timeout, ) + self._connected = True finally: self._connecting = False - self.websocket = cast(WebSocketClientProtocol, self.websocket) - - self.response_headers = self.websocket.response_headers - # Run the after_connect hook of the subclass await self._after_connect() @@ -452,7 +395,7 @@ async def connect(self) -> None: # if no ACKs are received within the ack_timeout try: await self._initialize() - except 
ConnectionClosed as e: + except TransportConnectionClosed as e: raise e except ( TransportProtocolError, @@ -531,7 +474,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: try: # We should always have an active websocket connection here - assert self.websocket is not None + assert self._connected # Properly shut down liveness checker if enabled if self.check_keep_alive_task is not None: @@ -560,11 +503,11 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: for query_id, listener in self.listeners.items(): await listener.set_exception(e) - log.debug("_close_coro: close websocket connection") + log.debug("_close_coro: close connection") - await self.websocket.close() + await self.adapter.close() - log.debug("_close_coro: websocket connection closed") + log.debug("_close_coro: connection closed") except Exception as exc: # pragma: no cover log.warning("Exception catched in _close_coro: " + repr(exc)) @@ -573,7 +516,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: start cleanup") - self.websocket = None + self._connected = False self.close_task = None self.check_keep_alive_task = None self._wait_closed.set() @@ -585,12 +528,12 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: if self.close_task is None: - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: + if self._connected: self.close_task = asyncio.shield( asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) ) + else: + log.debug("_fail started with self._connected:False -> already closed") else: log.debug( "close_task is not None in _fail. 
Previous exception is: " @@ -602,7 +545,7 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: async def close(self) -> None: log.debug("close: starting") - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self._fail(TransportClosed("Transport closed by user")) await self.wait_closed() log.debug("close: done") @@ -610,6 +553,9 @@ async def close(self) -> None: async def wait_closed(self) -> None: log.debug("wait_close: starting") - await self._wait_closed.wait() + try: + await asyncio.wait_for(self._wait_closed.wait(), self.close_timeout) + except asyncio.TimeoutError: + log.debug("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") diff --git a/gql/transport/websockets_common/listener_queue.py b/gql/transport/common/listener_queue.py similarity index 100% rename from gql/transport/websockets_common/listener_queue.py rename to gql/transport/common/listener_queue.py diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 7ec27a33..27cefe2f 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -61,6 +61,13 @@ class TransportClosed(TransportError): """ +class TransportConnectionClosed(TransportError): + """Transport adapter connection closed. + + This exception is by the connection adapter code when a connection closed. + """ + + class TransportAlreadyConnected(TransportError): """Transport is already connected. 
diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index a7b256eb..382e9014 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -11,7 +11,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_common.base import WebsocketsTransportBase +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -370,7 +370,7 @@ async def _handle_answer( execution_result: Optional[ExecutionResult], ) -> None: if answer_type == "close": - await self.close() + pass else: await super()._handle_answer(answer_type, answer_id, execution_result) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index adebf249..929761e6 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -15,7 +15,7 @@ TransportQueryError, TransportServerError, ) -from .websockets_common.base import WebsocketsTransportBase +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -36,6 +36,7 @@ class WebsocketsTransport(WebsocketsTransportBase): def __init__( self, url: str, + *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, init_payload: Dict[str, Any] = {}, @@ -83,16 +84,24 @@ def __init__( By default: both apollo and graphql-ws subprotocols. 
""" + if subprotocols is None: + subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + + # Initiliaze WebsocketsTransportBase parent class super().__init__( url, - headers, - ssl, - init_payload, - connect_timeout, - close_timeout, - ack_timeout, - keep_alive_timeout, - connect_args, + headers=headers, + ssl=ssl, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + connect_args=connect_args, + subprotocols=subprotocols, ) self.ping_interval: Optional[Union[int, float]] = ping_interval @@ -115,14 +124,6 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - if subprotocols is None: - self.supported_subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] - else: - self.supported_subprotocols = subprotocols - async def _wait_ack(self) -> None: """Wait for the connection_ack message. 
Keep alive messages are ignored""" @@ -485,9 +486,8 @@ async def _handle_answer( async def _after_connect(self): # Find the backend subprotocol returned in the response headers - response_headers = self.websocket.response_headers try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] except KeyError: # If the server does not send the subprotocol header, using # the apollo subprotocol by default diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py new file mode 100644 index 00000000..95e54b3f --- /dev/null +++ b/gql/transport/websockets_base.py @@ -0,0 +1,93 @@ +from ssl import SSLContext +from typing import Any, Dict, List, Optional, Union + +from websockets.datastructures import HeadersLike +from websockets.typing import Subprotocol + +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase + + +class WebsocketsTransportBase(SubscriptionTransportBase): + """abstract :ref:`Async Transport ` used to implement + different websockets protocols. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. + """ + + def __init__( + self, + url: str, + *, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + init_payload: Dict[str, Any] = {}, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + connect_args: Dict[str, Any] = {}, + subprotocols: Optional[List[Subprotocol]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. 
Use ssl=False to disable encryption + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param connect_args: Other parameters forwarded to websockets.connect + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + """ + + if subprotocols is not None: + connect_args.update({"subprotocols": subprotocols}) + + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url, + headers=headers, + ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + keep_alive_timeout=keep_alive_timeout, + ) + + self.init_payload: Dict[str, Any] = init_payload + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + @property + def url(self) -> str: + return self.adapter.url + + @property + def headers(self) -> Dict[str, str]: + return self.adapter.headers + + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args diff --git 
a/gql/transport/websockets_common/__init__.py b/gql/transport/websockets_common/__init__.py deleted file mode 100644 index 7661cf87..00000000 --- a/gql/transport/websockets_common/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .listener_queue import ListenerQueue, ParsedAnswer - -__all__ = ["ListenerQueue", "ParsedAnswer"] diff --git a/tests/conftest.py b/tests/conftest.py index b0103a99..664fe8c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,9 +121,10 @@ async def ssl_aiohttp_server(): "gql.transport.aiohttp", "gql.transport.aiohttp_websockets", "gql.transport.appsync", + "gql.transport.common.base", + "gql.transport.httpx", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", - "gql.transport.httpx", "gql.transport.websockets", "gql.dsl", "gql.utilities.parse_result", diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index befeeb4e..cce31d59 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -6,6 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -233,7 +234,6 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -241,7 +241,7 @@ async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): async with Client(transport=sample_transport): pass @@ -257,13 +257,11 @@ async def test_graphqlws_server_closing_after_ack( event_loop, 
client_and_graphqlws_server ): - import websockets - session, server = client_and_graphqlws_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 683da43a..1b8f7ccb 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -385,14 +385,12 @@ async def server_countdown_close_connection_in_middle(ws): async def test_graphqlws_subscription_server_connection_closed( event_loop, client_and_graphqlws_server, subscription_str ): - import websockets - session, server = client_and_graphqlws_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): @@ -812,7 +810,6 @@ async def test_graphqlws_subscription_reconnecting_session( event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe ): - import websockets from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import TransportClosed @@ -838,7 +835,7 @@ async def test_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except websockets.exceptions.ConnectionClosedOK: + except TransportConnectionClosed: pass await asyncio.sleep(50 * MS) diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 
f39edacb..320d1da3 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -65,6 +65,10 @@ async def test_phoenix_channel_query(event_loop, server, query_str): result = await session.execute(query) print("Client received:", result) + continents = result["continents"] + print("Continents received:", continents) + africa = continents[0] + assert africa["code"] == "AF" @pytest.mark.skip(reason="ssl=False is not working for now") diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index cb9e7274..f9f1f8db 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -141,7 +142,7 @@ async def test_websocket_sending_invalid_data(event_loop, client_and_server, que invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send(invalid_data) + await session.transport.adapter.websocket.send(invalid_data) await asyncio.sleep(2 * MS) @@ -272,7 +273,6 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) async def test_websocket_server_closing_directly(event_loop, server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -280,7 +280,7 @@ async def test_websocket_server_closing_directly(event_loop, server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): async with Client(transport=sample_transport): pass @@ -294,13 +294,11 @@ async def server_closing_after_ack(ws): @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) async def 
test_websocket_server_closing_after_ack(event_loop, client_and_server): - import websockets - session, server = client_and_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 2c723b3f..f509f676 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -51,19 +51,19 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_starting_client_in_context_manager(event_loop, server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + assert transport.response_headers == {} + assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: - assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol - ) + assert transport._connected is True query1 = gql(query1_str) @@ -85,7 +85,7 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.skip(reason="ssl=False is not working for now") @@ -133,7 +133,7 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -169,7 +169,7 @@ async def test_websocket_using_ssl_connection_self_cert_fail( assert expected_error in str(exc_info.value) # Check client 
is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -355,13 +355,13 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -484,7 +484,7 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( await session.execute(query1) - assert transport.websocket is None + assert transport._connected is False @pytest.mark.parametrize("server", [server1_answers], indirect=True) @@ -526,7 +526,7 @@ def test_websocket_execute_sync(server): assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -649,3 +649,52 @@ async def test_websocket_simple_query_with_extensions( execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_adapter_connection_closed(event_loop, server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + await transport.adapter.close() + + with pytest.raises(TransportClosed): + await session.execute(query1) + + # Check client is disconnect here + assert transport._connected is False + 
+ +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_transport_closed_in_receive(event_loop, server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport( + url=url, + close_timeout=0.1, + ) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + # await transport.adapter.close() + transport._connected = False + + with pytest.raises(TransportClosed): + await session.execute(query1) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 5af44d59..3efe63a6 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -306,14 +306,12 @@ async def server_countdown_close_connection_in_middle(ws): async def test_websocket_subscription_server_connection_closed( event_loop, client_and_server, subscription_str ): - import websockets - session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py new file mode 100644 index 00000000..f266ce29 --- /dev/null +++ b/tests/test_websockets_adapter.py @@ -0,0 +1,98 @@ +import json + +import pytest +from graphql import print_ast + +from gql import gql +from gql.transport.exceptions import TransportConnectionClosed + +# Marking all tests in 
this file with the websockets marker +pytestmark = pytest.mark.websockets + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + +server1_answers = [ + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websockets_adapter_simple_query(event_loop, server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str)) + print("query=", query) + + adapter = WebSocketsAdapter(url) + + await adapter.connect() + + init_message = json.dumps({"type": "connection_init", "payload": {}}) + + await adapter.send(init_message) + + result = await adapter.receive() + print(f"result={result}") + + payload = json.dumps({"query": query}) + query_message = json.dumps({"id": 1, "type": "start", "payload": payload}) + + await adapter.send(query_message) + + result = await adapter.receive() + print(f"result={result}") + + await adapter.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websockets_adapter_edge_cases(event_loop, server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str)) + print("query=", query) + + adapter = WebSocketsAdapter(url, headers={"a": 1}, ssl=False, connect_args={}) + + await adapter.connect() + + assert adapter.headers["a"] == 1 + assert adapter.ssl is False + assert adapter.connect_args 
== {} + assert adapter.response_headers["dummy"] == "test1234" + + # Connect twice causes AssertionError + with pytest.raises(AssertionError): + await adapter.connect() + + await adapter.close() + + # Second close call is ignored + await adapter.close() + + with pytest.raises(TransportConnectionClosed): + await adapter.send("Blah") + + with pytest.raises(TransportConnectionClosed): + await adapter.receive() From 4a8493b22b26110a9c130fb1e732713f94db79d2 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 8 Mar 2025 23:03:20 +0100 Subject: [PATCH 196/239] Using SubscriptionTransportBase instead of WebsocketsTransportBase for Phoenix transport --- gql/transport/phoenix_channel_websockets.py | 40 +++++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 382e9014..0c1bd62b 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,17 +1,18 @@ import asyncio import json import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union from graphql import DocumentNode, ExecutionResult, print_ast -from websockets.exceptions import ConnectionClosed +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import ( + TransportConnectionClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -24,7 +25,7 @@ def __init__(self, query_id: int) -> None: self.unsubscribe_id: Optional[int] = None -class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): +class PhoenixChannelWebsocketsTransport(SubscriptionTransportBase): """The PhoenixChannelWebsocketsTransport is an async transport which allows you to execute queries and subscriptions against an `Absinthe`_ backend 
using the `Phoenix`_ framework `channels`_. @@ -36,23 +37,48 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): def __init__( self, + url: str, + *, channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, - *args, + ack_timeout: Optional[Union[int, float]] = 10, **kwargs, ) -> None: """Initialize the transport with the given parameters. + :param url: The server URL.'. :param channel_name: Channel on the server this transport will join. The default for Absinthe servers is "__absinthe__:control" :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client + :param ack_timeout: Timeout in seconds to wait for the reply message + from the server. """ self.channel_name: str = channel_name self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscriptions: Dict[str, Subscription] = {} - super().__init__(*args, **kwargs) + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + ws_adapter_args = {} + for ws_arg in ["headers", "ssl", "connect_args"]: + try: + ws_adapter_args[ws_arg] = kwargs.pop(ws_arg) + except KeyError: + pass + + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, + **ws_adapter_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + **kwargs, + ) async def _initialize(self) -> None: """Join the specified channel and wait for the connection ACK. 
@@ -101,7 +127,7 @@ async def heartbeat_coro(): } ) ) - except ConnectionClosed: # pragma: no cover + except TransportConnectionClosed: # pragma: no cover return self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) From fe6712b383f256eafdb976e1a1d985c375b6236e Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 8 Mar 2025 23:21:24 +0100 Subject: [PATCH 197/239] Using SubscriptionTransportBase instead of WebsocketsTransportBase for AppSync transport --- gql/transport/appsync_websockets.py | 41 ++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 0d5139c3..c339e0b8 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -7,8 +7,10 @@ from graphql import DocumentNode, ExecutionResult, print_ast from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import TransportProtocolError, TransportServerError -from .websockets import WebsocketsTransport, WebsocketsTransportBase +from .websockets import WebsocketsTransport log = logging.getLogger("gql.transport.appsync") @@ -19,7 +21,7 @@ pass -class AppSyncWebsocketsTransport(WebsocketsTransportBase): +class AppSyncWebsocketsTransport(SubscriptionTransportBase): """:ref:`Async Transport ` used to execute GraphQL subscription on AWS appsync realtime endpoint. 
@@ -32,6 +34,7 @@ class AppSyncWebsocketsTransport(WebsocketsTransportBase): def __init__( self, url: str, + *, auth: Optional[AppSyncAuthentication] = None, session: Optional["botocore.session.Session"] = None, ssl: Union[SSLContext, bool] = False, @@ -70,17 +73,25 @@ def __init__( auth = AppSyncIAMAuthentication(host=host, session=session) self.auth = auth + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.init_payload: Dict[str, Any] = {} url = self.auth.get_auth_url(url) - super().__init__( - url, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, connect_timeout=connect_timeout, close_timeout=close_timeout, - ack_timeout=ack_timeout, keep_alive_timeout=keep_alive_timeout, - connect_args=connect_args, ) # Using the same 'graphql-ws' protocol as the apollo protocol @@ -181,7 +192,7 @@ async def _send_query( return query_id - subscribe = WebsocketsTransportBase.subscribe # type: ignore[assignment] + subscribe = SubscriptionTransportBase.subscribe # type: ignore[assignment] """Send a subscription query and receive the results using a python async generator. 
@@ -212,3 +223,19 @@ async def execute( WebsocketsTransport._send_init_message_and_wait_ack ) _wait_ack = WebsocketsTransport._wait_ack + + @property + def url(self) -> str: + return self.adapter.url + + @property + def headers(self) -> Dict[str, str]: + return self.adapter.headers + + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args From 352af37ec8c9fa2eb89540155243bedf6d16d887 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 9 Mar 2025 17:02:13 +0100 Subject: [PATCH 198/239] Put dependency-free websockets protocol in websockets_protocol.py --- gql/transport/appsync_websockets.py | 12 - gql/transport/common/adapters/connection.py | 5 +- gql/transport/common/adapters/websockets.py | 6 +- gql/transport/common/base.py | 8 + gql/transport/websockets.py | 474 +----------------- gql/transport/websockets_base.py | 93 ---- gql/transport/websockets_protocol.py | 516 ++++++++++++++++++++ tests/test_phoenix_channel_subscription.py | 4 +- tests/test_websocket_query.py | 3 + 9 files changed, 564 insertions(+), 557 deletions(-) delete mode 100644 gql/transport/websockets_base.py create mode 100644 gql/transport/websockets_protocol.py diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index c339e0b8..e0f5c031 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -224,18 +224,6 @@ async def execute( ) _wait_ack = WebsocketsTransport._wait_ack - @property - def url(self) -> str: - return self.adapter.url - - @property - def headers(self) -> Dict[str, str]: - return self.adapter.headers - @property def ssl(self) -> Union[SSLContext, bool]: return self.adapter.ssl - - @property - def connect_args(self) -> Dict[str, Any]: - return self.adapter.connect_args diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py index 
fbe38e3b..cf361b8d 100644 --- a/gql/transport/common/adapters/connection.py +++ b/gql/transport/common/adapters/connection.py @@ -1,5 +1,5 @@ import abc -from typing import Dict +from typing import Any, Dict class AdapterConnection(abc.ABC): @@ -8,6 +8,9 @@ class AdapterConnection(abc.ABC): This allows different WebSocket implementations to be used interchangeably. """ + url: str + connect_args: Dict[str, Any] + @abc.abstractmethod async def connect(self) -> None: """Connect to the server.""" diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index 95fbaf39..4494e256 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -19,7 +19,7 @@ def __init__( *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, - connect_args: Dict[str, Any] = {}, + connect_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -31,6 +31,10 @@ def __init__( self.url: str = url self._headers: Optional[HeadersLike] = headers self.ssl: Union[SSLContext, bool] = ssl + + if connect_args is None: + connect_args = {} + self.connect_args = connect_args self.websocket: Optional[WebSocketClientProtocol] = None diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 9ee07dd8..40d0b4cb 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -559,3 +559,11 @@ async def wait_closed(self) -> None: log.debug("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") + + @property + def url(self) -> str: + return self.adapter.url + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 929761e6..7a0ce10a 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,26 +1,13 @@ -import asyncio -import json -import logging -from 
contextlib import suppress from ssl import SSLContext -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Union -from graphql import DocumentNode, ExecutionResult, print_ast from websockets.datastructures import HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Subprotocol -from .exceptions import ( - TransportProtocolError, - TransportQueryError, - TransportServerError, -) -from .websockets_base import WebsocketsTransportBase +from .common.adapters.websockets import WebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger(__name__) - -class WebsocketsTransport(WebsocketsTransportBase): +class WebsocketsTransport(WebsocketsProtocolTransportBase): """:ref:`Async Transport ` used to execute GraphQL queries on remote servers with websocket connection. @@ -28,18 +15,13 @@ class WebsocketsTransport(WebsocketsTransportBase): on a websocket connection. 
""" - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") - GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") - def __init__( self, url: str, *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + init_payload: Optional[Dict[str, Any]] = None, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, @@ -47,8 +29,8 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, - connect_args: Dict[str, Any] = {}, - subprotocols: Optional[List[Subprotocol]] = None, + connect_args: Optional[Dict[str, Any]] = None, + subprotocols: Optional[List[str]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -84,437 +66,33 @@ def __init__( By default: both apollo and graphql-ws subprotocols. 
""" - if subprotocols is None: - subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] - - # Initiliaze WebsocketsTransportBase parent class - super().__init__( - url, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, headers=headers, ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, init_payload=init_payload, connect_timeout=connect_timeout, close_timeout=close_timeout, ack_timeout=ack_timeout, keep_alive_timeout=keep_alive_timeout, - connect_args=connect_args, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, subprotocols=subprotocols, ) - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. 
Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, answer_id, execution_result = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. - """ - - init_message = json.dumps( - {"type": "connection_init", "payload": self.init_payload} - ) - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - await self._send_init_message_and_wait_ack() - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(json.dumps(ping_message)) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(json.dumps(pong_message)) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = json.dumps({"id": str(query_id), "type": "stop"}) - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. 
- """ - - complete_message = json.dumps({"id": str(query_id), "type": "complete"}) - - await self._send(complete_message) - - async def _stop_listener(self, query_id: int): - """Stop the listener corresponding to the query_id depending on the - detected backend protocol. - - For apollo: send a "stop" message - (a "complete" message will be sent from the backend) - - For graphql-ws: send a "complete" message and simulate the reception - of a "complete" message from the backend - """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. - """ - - connection_terminate_message = json.dumps({"type": "connection_terminate"}) - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. 
- """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query_str = json.dumps( - {"id": str(query_id), "type": query_type, "payload": payload} - ) - - await self._send(query_str) - - return query_id - - async def _connection_terminate(self): - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - - def _parse_answer_graphqlws( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = json_answer.get("payload") - - if 
answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = json_answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. 
- - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = json_answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = json_answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. 
- """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. - - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - # Put the answer in the queue - await super()._handle_answer(answer_type, answer_id, execution_result) - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - self.pong_received.set() - - async def _after_connect(self): - - # Find the backend subprotocol returned in the response headers - try: - self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = 
self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def _after_initialize(self): - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - log.debug("_close_hook: start") - - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - log.debug("_close_hook: cancelling send_ping_task") - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError, ConnectionClosed): - log.debug("_close_hook: awaiting send_ping_task") - await self.send_ping_task - self.send_ping_task = None + @property + def headers(self) -> Optional[HeadersLike]: + return self.adapter.headers - log.debug("_close_hook: end") + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py deleted file mode 100644 index 95e54b3f..00000000 --- a/gql/transport/websockets_base.py +++ /dev/null @@ -1,93 +0,0 @@ -from ssl import SSLContext -from typing import Any, Dict, List, Optional, Union - -from websockets.datastructures import HeadersLike -from websockets.typing import Subprotocol - -from .common.adapters.websockets import WebSocketsAdapter -from .common.base import SubscriptionTransportBase - - -class WebsocketsTransportBase(SubscriptionTransportBase): - """abstract :ref:`Async Transport ` used to implement - different websockets protocols. - - This transport uses asyncio and the websockets library in order to send requests - on a websocket connection. 
- """ - - def __init__( - self, - url: str, - *, - headers: Optional[HeadersLike] = None, - ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, - connect_timeout: Optional[Union[int, float]] = 10, - close_timeout: Optional[Union[int, float]] = 10, - ack_timeout: Optional[Union[int, float]] = 10, - keep_alive_timeout: Optional[Union[int, float]] = None, - connect_args: Dict[str, Any] = {}, - subprotocols: Optional[List[Subprotocol]] = None, - ) -> None: - """Initialize the transport with the given parameters. - - :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. - :param headers: Dict of HTTP Headers. - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param init_payload: Dict of the payload sent in the connection_init message. - :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. If None is provided this will wait forever. - :param close_timeout: Timeout in seconds for the close. If None is provided - this will wait forever. - :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. If None is provided this will wait forever. - :param keep_alive_timeout: Optional Timeout in seconds to receive - a sign of liveness from the server. - :param connect_args: Other parameters forwarded to websockets.connect - :param subprotocols: list of subprotocols sent to the - backend in the 'subprotocols' http header. 
- """ - - if subprotocols is not None: - connect_args.update({"subprotocols": subprotocols}) - - # Instanciate a WebSocketAdapter to indicate the use - # of the websockets dependency for this transport - self.adapter: WebSocketsAdapter = WebSocketsAdapter( - url, - headers=headers, - ssl=ssl, - connect_args=connect_args, - ) - - # Initialize the generic SubscriptionTransportBase parent class - super().__init__( - adapter=self.adapter, - connect_timeout=connect_timeout, - close_timeout=close_timeout, - keep_alive_timeout=keep_alive_timeout, - ) - - self.init_payload: Dict[str, Any] = init_payload - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" - - @property - def url(self) -> str: - return self.adapter.url - - @property - def headers(self) -> Dict[str, str]: - return self.adapter.headers - - @property - def ssl(self) -> Union[SSLContext, bool]: - return self.adapter.ssl - - @property - def connect_args(self) -> Dict[str, Any]: - return self.adapter.connect_args diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py new file mode 100644 index 00000000..84ba7656 --- /dev/null +++ b/gql/transport/websockets_protocol.py @@ -0,0 +1,516 @@ +import asyncio +import json +import logging +from contextlib import suppress +from typing import Any, Dict, List, Optional, Tuple, Union + +from graphql import DocumentNode, ExecutionResult, print_ast + +from .common.adapters.websockets import AdapterConnection +from .common.base import SubscriptionTransportBase +from .exceptions import ( + TransportConnectionClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +log = logging.getLogger("gql.transport.websockets") + + +class WebsocketsProtocolTransportBase(SubscriptionTransportBase): + """:ref:`Async Transport ` used to 
execute GraphQL queries on + remote servers with websocket connection. + + This transport uses asyncio and the provided websockets adapter library + in order to send requests on a websocket connection. + """ + + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL = "graphql-ws" + GRAPHQLWS_SUBPROTOCOL = "graphql-transport-ws" + + def __init__( + self, + *, + adapter: AdapterConnection, + init_payload: Optional[Dict[str, Any]] = None, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, + subprotocols: Optional[List[str]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param adapter: The connection dependency adapter + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. 
+ :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. + """ + + if subprotocols is None: + subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + + self.adapter.connect_args.update({"subprotocols": subprotocols}) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + keep_alive_timeout=keep_alive_timeout, + ) + + if init_payload is None: + init_payload = {} + + self.init_payload: Dict[str, Any] = init_payload + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout + + self.send_ping_task: Optional[asyncio.Future] = None + + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + async def _wait_ack(self) -> None: + """Wait for the 
connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + init_message = json.dumps( + {"type": "connection_init", "payload": self.init_payload} + ) + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + + async def _initialize(self): + await self._send_init_message_and_wait_ack() + + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol""" + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(json.dumps(ping_message)) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol""" + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(json.dumps(pong_message)) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. + """ + + stop_message = json.dumps({"id": str(query_id), "type": "stop"}) + + await self._send(stop_message) + + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. 
+ """ + + complete_message = json.dumps({"id": str(query_id), "type": "complete"}) + + await self._send(complete_message) + + async def _stop_listener(self, query_id: int): + """Stop the listener corresponding to the query_id depending on the + detected backend protocol. + + For apollo: send a "stop" message + (a "complete" message will be sent from the backend) + + For graphql-ws: send a "complete" message and simulate the reception + of a "complete" message from the backend + """ + log.debug(f"stop listener {query_id}") + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + + async def _send_connection_terminate_message(self) -> None: + """Send a connection_terminate message to the provided websocket connection. + + This message indicates that the connection will disconnect. + """ + + connection_terminate_message = json.dumps({"type": "connection_terminate"}) + + await self._send(connection_terminate_message) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. 
+ """ + + query_id = self.next_query_id + self.next_query_id += 1 + + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name + + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + + query_str = json.dumps( + {"id": str(query_id), "type": query_type, "payload": payload} + ) + + await self._send(query_str) + + return query_id + + async def _connection_terminate(self): + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + + def _parse_answer_graphqlws( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = json_answer.get("payload") + + if 
answer_type == "next": + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + if not isinstance(payload, list): + raise ValueError("payload is not a list") + + raise TransportQueryError( + str(payload[0]), query_id=answer_id, errors=payload + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. 
+ + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "data" or answer_type == "error": + + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. 
+ """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) + + return self._parse_answer_apollo(json_answer) + + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. + """ + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + # Put the answer in the queue + await super()._handle_answer(answer_type, answer_id, execution_result) + + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + + async def _after_connect(self): + + # Find the backend subprotocol returned in the response headers + try: + self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = 
self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + + async def _after_initialize(self): + + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + + async def _close_hook(self): + log.debug("_close_hook: start") + + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + log.debug("_close_hook: cancelling send_ping_task") + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError, TransportConnectionClosed): + log.debug("_close_hook: awaiting send_ping_task") + await self.send_ping_task + self.send_ping_task = None + + log.debug("_close_hook: end") diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 6193c658..3be4b07d 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -186,7 +186,7 @@ async def test_phoenix_channel_subscription( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger websockets_logger.setLevel(logging.DEBUG) phoenix_logger.setLevel(logging.DEBUG) @@ -227,7 +227,7 @@ async def test_phoenix_channel_subscription_no_break( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger from .conftest import MS diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index f509f676..7aa853bf 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -157,6 +157,9 
@@ async def test_websocket_using_ssl_connection_self_cert_fail( transport = WebsocketsTransport(url=url, **extra_args) + if verify_https == "explicitely_enabled": + assert transport.ssl is True + with pytest.raises(SSLCertVerificationError) as exc_info: async with Client(transport=transport) as session: From 496add12c35fcc59bb580d15770b8ae8c633179d Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 9 Mar 2025 23:55:35 +0100 Subject: [PATCH 199/239] Use new connection adapter for aiohttp websockets --- gql/transport/aiohttp.py | 57 +- gql/transport/aiohttp_websockets.py | 1067 +---------------- gql/transport/appsync_websockets.py | 2 +- gql/transport/common/adapters/aiohttp.py | 269 +++++ gql/transport/common/adapters/connection.py | 13 +- gql/transport/common/adapters/websockets.py | 38 +- gql/transport/common/aiohttp_closed_event.py | 59 + gql/transport/websockets_protocol.py | 4 +- tests/test_aiohttp_websocket_exceptions.py | 8 +- ..._aiohttp_websocket_graphqlws_exceptions.py | 5 +- ...iohttp_websocket_graphqlws_subscription.py | 6 +- tests/test_aiohttp_websocket_query.py | 39 +- tests/test_aiohttp_websocket_subscription.py | 16 +- tests/test_phoenix_channel_query.py | 22 +- tests/test_websocket_query.py | 22 +- 15 files changed, 481 insertions(+), 1146 deletions(-) create mode 100644 gql/transport/common/adapters/aiohttp.py create mode 100644 gql/transport/common/aiohttp_closed_event.py diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 0c332205..c1302794 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,5 +1,4 @@ import asyncio -import functools import io import json import logging @@ -28,6 +27,7 @@ from ..utils import extract_files from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport +from .common.aiohttp_closed_event import create_aiohttp_closed_event from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -147,59 +147,6 @@ async def connect(self) -> 
None: else: raise TransportAlreadyConnected("Transport is already connected") - @staticmethod - def create_aiohttp_closed_event(session) -> asyncio.Event: - """Work around aiohttp issue that doesn't properly close transports on exit. - - See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 - - Returns: - An event that will be set once all transports have been properly closed. - """ - - ssl_transports = 0 - all_is_lost = asyncio.Event() - - def connection_lost(exc, orig_lost): - nonlocal ssl_transports - - try: - orig_lost(exc) - finally: - ssl_transports -= 1 - if ssl_transports == 0: - all_is_lost.set() - - def eof_received(orig_eof_received): - try: # pragma: no cover - orig_eof_received() - except AttributeError: # pragma: no cover - # It may happen that eof_received() is called after - # _app_protocol and _transport are set to None. - pass - - for conn in session.connector._conns.values(): - for handler, _ in conn: - proto = getattr(handler.transport, "_ssl_protocol", None) - if proto is None: - continue - - ssl_transports += 1 - orig_lost = proto.connection_lost - orig_eof_received = proto.eof_received - - proto.connection_lost = functools.partial( - connection_lost, orig_lost=orig_lost - ) - proto.eof_received = functools.partial( - eof_received, orig_eof_received=orig_eof_received - ) - - if ssl_transports == 0: - all_is_lost.set() - - return all_is_lost - async def close(self) -> None: """Coroutine which will close the aiohttp session. 
@@ -219,7 +166,7 @@ async def close(self) -> None: log.debug("connector_owner is False -> not closing connector") else: - closed_event = self.create_aiohttp_closed_event(self.session) + closed_event = create_aiohttp_closed_event(self.session) await self.session.close() try: await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index f97fbba8..59d870f6 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,53 +1,26 @@ -import asyncio -import json -import logging -import warnings -from contextlib import suppress from ssl import SSLContext -from typing import ( - Any, - AsyncGenerator, - Collection, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Union, -) +from typing import Any, Dict, List, Literal, Mapping, Optional, Union -import aiohttp -from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp import BasicAuth, ClientSession, Fingerprint from aiohttp.typedefs import LooseHeaders, StrOrURL -from graphql import DocumentNode, ExecutionResult, print_ast -from multidict import CIMultiDictProxy -from .aiohttp import AIOHTTPTransport -from .async_transport import AsyncTransport -from .common import ListenerQueue -from .exceptions import ( - TransportAlreadyConnected, - TransportClosed, - TransportProtocolError, - TransportQueryError, - TransportServerError, -) +from .common.adapters.aiohttp import AIOHTTPWebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger("gql.transport.aiohttp_websockets") +class AIOHTTPWebsocketsTransport(WebsocketsProtocolTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. 
-class AIOHTTPWebsocketsTransport(AsyncTransport): - - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL: str = "graphql-ws" - GRAPHQLWS_SUBPROTOCOL: str = "graphql-transport-ws" + This transport uses asyncio and the provided aiohttp adapter library + in order to send requests on a websocket connection. + """ def __init__( self, url: StrOrURL, *, - subprotocols: Optional[Collection[str]] = None, + subprotocols: Optional[List[str]] = None, heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, @@ -68,8 +41,9 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, + session: Optional[ClientSession] = None, client_session_args: Optional[Dict[str, Any]] = None, - connect_args: Dict[str, Any] = {}, + connect_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -140,6 +114,7 @@ def __init__( :param answer_pings: Whether the client answers the pings from the backend (for the graphql-ws protocol). By default: True + :param session: Optional aiohttp.ClientSession instance. :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ :param connect_args: Dict of extra args passed to @@ -150,986 +125,46 @@ def __init__( .. 
_aiohttp.ClientSession: https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession """ - self.url: StrOrURL = url - self.heartbeat: Optional[float] = heartbeat - self.auth: Optional[BasicAuth] = auth - self.origin: Optional[str] = origin - self.params: Optional[Mapping[str, str]] = params - self.headers: Optional[LooseHeaders] = headers - - self.proxy: Optional[StrOrURL] = proxy - self.proxy_auth: Optional[BasicAuth] = proxy_auth - self.proxy_headers: Optional[LooseHeaders] = proxy_headers - - self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl - - self.websocket_close_timeout: float = websocket_close_timeout - self.receive_timeout: Optional[float] = receive_timeout - - self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - - self.init_payload: Dict[str, Any] = init_payload - - # We need to set an event loop here if there is none - # Or else we will not be able to create an asyncio.Event() - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - self._loop = asyncio.get_event_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - self._next_keep_alive_message.set() - - self.session: Optional[aiohttp.ClientSession] = None - self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None - self.next_query_id: int = 1 - self.listeners: Dict[int, ListenerQueue] = {} - self._connecting: bool = False - self.response_headers: Optional[CIMultiDictProxy[str]] = None - - self.receive_data_task: 
Optional[asyncio.Future] = None - self.check_keep_alive_task: Optional[asyncio.Future] = None - self.close_task: Optional[asyncio.Future] = None - - self._wait_closed: asyncio.Event = asyncio.Event() - self._wait_closed.set() - - self._no_more_listeners: asyncio.Event = asyncio.Event() - self._no_more_listeners.set() - - self.payloads: Dict[str, Any] = {} - - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - self.supported_subprotocols: Collection[str] = subprotocols or ( - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ) - - self.close_exception: Optional[Exception] = None - - self.client_session_args = client_session_args - self.connect_args = connect_args - - def _parse_answer_graphqlws( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. 
- - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = answer.get("payload") - - if answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - 
self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise TransportProtocolError( - f"Server did not 
return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, _, _ = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. - """ - - init_message = {"type": "connection_init", "payload": self.init_payload} - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - """Hook to send the initialization messages after the connection - and potentially wait for the backend ack. - """ - await self._send_init_message_and_wait_ack() - - async def _stop_listener(self, query_id: int): - """Hook to stop to listen to a specific query. - Will send a stop message in some subclasses. 
- """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _after_connect(self): - """Hook to add custom code for subclasses after the connection - has been established. - """ - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket._response.headers - log.debug(f"Response headers: {response_headers!r}") - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(ping_message) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(pong_message) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = {"id": str(query_id), "type": "stop"} - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. - """ - - complete_message = {"id": str(query_id), "type": "complete"} - - await self._send(complete_message) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. 
- - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _after_initialize(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - """Hook to add custom code for subclasses for the connection close""" - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): - await self.send_ping_task - self.send_ping_task = None - - async def _connection_terminate(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. 
- """ - - connection_terminate_message = {"type": "connection_terminate"} - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. - """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query = {"id": str(query_id), "type": query_type, "payload": payload} - - await self._send(query) - - return query_id - - async def _send(self, message: Dict[str, Any]) -> None: - """Send the provided message to the websocket connection and log the message""" - - if self.websocket is None: - raise TransportClosed("WebSocket connection is closed") - - try: - await self.websocket.send_json(message) - log.info(">>> %s", message) - except ConnectionResetError as e: - await self._fail(e, clean_close=False) - raise e - - async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - while True: - ws_message = await self.websocket.receive() - - # Ignore low-level ping and pong received - if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): - break - - if ws_message.type in ( - WSMsgType.CLOSE, - WSMsgType.CLOSED, - WSMsgType.CLOSING, - WSMsgType.ERROR, - ): - raise ConnectionResetError - elif ws_message.type is WSMsgType.BINARY: - raise 
TransportProtocolError("Binary data received in the websocket") - - assert ws_message.type is WSMsgType.TEXT - - answer: str = ws_message.data - - log.info("<<< %s", answer) - - return answer - - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. - """ - if query_id in self.listeners: - del self.listeners[query_id] - - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") - - if remaining == 0: - self._no_more_listeners.set() - - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ - - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) - - # Reset for the next iteration - self._next_keep_alive_message.clear() - - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) - - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, - ) - - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. 
- pass - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - self.pong_received.set() - - async def _receive_data_loop(self) -> None: - """Main asyncio task which will listen to the incoming messages and will - call the parse_answer and handle_answer methods of the subclass.""" - log.debug("Entering _receive_data_loop()") - - try: - while True: - - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionResetError, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed as e: - await self._fail(e, clean_close=False) - raise e - - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass - - continue - - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. 
- await self._fail(e, clean_close=False) - break - - await self._handle_answer(answer_type, answer_id, execution_result) - - finally: - log.debug("Exiting _receive_data_loop()") - - async def connect(self) -> None: - log.debug("connect: starting") - - if self.session is None: - client_session_args: Dict[str, Any] = {} - - # Adding custom parameters passed from init - if self.client_session_args: - client_session_args.update(self.client_session_args) # type: ignore - - self.session = aiohttp.ClientSession(**client_session_args) - - if self.websocket is None and not self._connecting: - self._connecting = True - - connect_args: Dict[str, Any] = { - "url": self.url, - "headers": self.headers, - "auth": self.auth, - "heartbeat": self.heartbeat, - "origin": self.origin, - "params": self.params, - "protocols": self.supported_subprotocols, - "proxy": self.proxy, - "proxy_auth": self.proxy_auth, - "proxy_headers": self.proxy_headers, - "timeout": self.websocket_close_timeout, - "receive_timeout": self.receive_timeout, - } - - if self.ssl is not None: - connect_args.update( - { - "ssl": self.ssl, - } - ) - - # Adding custom parameters passed from init - if self.connect_args: - connect_args.update(self.connect_args) - - try: - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - self.websocket = await asyncio.wait_for( - self.session.ws_connect( - **connect_args, - ), - self.connect_timeout, - ) - finally: - self._connecting = False - - self.response_headers = self.websocket._response.headers - - await self._after_connect() - - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() - - # Send the init message and wait for the ack from the server - # Note: This should generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._initialize() - except ConnectionResetError as e: - raise e - except ( - 
TransportProtocolError, - TransportServerError, - asyncio.TimeoutError, - ) as e: - await self._fail(e, clean_close=False) - raise e - - # Run the after_init hook of the subclass - await self._after_initialize() - - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - - else: - raise TransportAlreadyConnected("Transport is already connected") - - log.debug("connect: done") - async def _clean_close(self) -> None: - """Coroutine which will: - - - send stop messages for each active subscription to the server - - send the connection terminate message - """ - log.debug(f"Listeners: {self.listeners}") - - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): - print(f"Listener {query_id} send_stop: {listener.send_stop}") - - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") - - # Calling the subclass hook - await self._connection_terminate() - - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: - - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ - - log.debug("_close_coro: starting") - - try: - - try: - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: 
- # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel keep alive task exception: " + repr(exc) - ) - - try: - # Calling the subclass close hook - await self._close_hook() - except Exception as exc: # pragma: no cover - log.warning("_close_coro close_hook exception: " + repr(exc)) - - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e - - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close() - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) - - log.debug("_close_coro: sending exception to listeners") - - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) - - log.debug("_close_coro: close websocket connection") - - try: - assert self.websocket is not None - - await self.websocket.close() - self.websocket = None - except Exception as exc: - log.warning("_close_coro websocket close exception: " + repr(exc)) - - log.debug("_close_coro: close aiohttp session") - - if ( - self.client_session_args - and self.client_session_args.get("connector_owner") is False - ): - - log.debug("connector_owner is False -> not closing connector") - - else: - try: - assert self.session is not None - - closed_event = AIOHTTPTransport.create_aiohttp_closed_event( - self.session - ) - await self.session.close() - try: - await asyncio.wait_for( - closed_event.wait(), self.ssl_close_timeout - ) - except asyncio.TimeoutError: - pass - except Exception as exc: # pragma: no cover - log.warning("_close_coro session close exception: " + repr(exc)) - - self.session = None - - 
log.debug("_close_coro: aiohttp session closed") - - try: - assert self.receive_data_task is not None - - self.receive_data_task.cancel() - with suppress(asyncio.CancelledError): - await self.receive_data_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel receive data task exception: " + repr(exc) - ) - - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) - - finally: - - log.debug("_close_coro: final cleanup") - - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None - self.receive_data_task = None - self._wait_closed.set() - - log.debug("_close_coro: exiting") - - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) - - if self.close_task is None: - - if self._wait_closed.is_set(): - log.debug("_fail started but transport is already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) - - async def close(self) -> None: - log.debug("close: starting") - - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() - - log.debug("close: done") - - async def wait_closed(self) -> None: - log.debug("wait_close: starting") - - if not self._wait_closed.is_set(): - await self._wait_closed.wait() - - log.debug("wait_close: done") - - async def execute( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - - Send a query but close the async generator as soon as we have the first answer. 
- - The result is sent as an ExecutionResult object. - """ - first_result = None - - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False + # Instanciate a AIOHTTPWebSocketAdapter to indicate the use + # of the aiohttp dependency for this transport + self.adapter: AIOHTTPWebSocketsAdapter = AIOHTTPWebSocketsAdapter( + url=url, + headers=headers, + ssl=ssl, + session=session, + client_session_args=client_session_args, + connect_args=connect_args, + heartbeat=heartbeat, + auth=auth, + origin=origin, + params=params, + proxy=proxy, + proxy_auth=proxy_auth, + proxy_headers=proxy_headers, + websocket_close_timeout=websocket_close_timeout, + receive_timeout=receive_timeout, + ssl_close_timeout=ssl_close_timeout, ) - async for result in generator: - first_result = result - break - - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) - - return first_result - - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. - - The query can be a graphql query, mutation or subscription. - - The results are sent as an ExecutionResult object. 
- """ - - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, + subprotocols=subprotocols, ) - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener - - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() - - try: - # Loop over the received answers - while True: - - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. - answer_type, execution_result = await listener.get() - - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result - - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break - - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False + @property + def headers(self) -> Optional[LooseHeaders]: + return self.adapter.headers - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) + @property + def ssl(self) -> Optional[Union[SSLContext, Literal[False], Fingerprint]]: + return self.adapter.ssl diff --git 
a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index e0f5c031..f35cefe5 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -95,7 +95,7 @@ def __init__( ) # Using the same 'graphql-ws' protocol as the apollo protocol - self.supported_subprotocols = [ + self.adapter.subprotocols = [ WebsocketsTransport.APOLLO_SUBPROTOCOL, ] self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py new file mode 100644 index 00000000..d9af7c50 --- /dev/null +++ b/gql/transport/common/adapters/aiohttp.py @@ -0,0 +1,269 @@ +import asyncio +import logging +from ssl import SSLContext +from typing import Any, Dict, Literal, Mapping, Optional, Union + +import aiohttp +from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp.typedefs import LooseHeaders, StrOrURL +from multidict import CIMultiDictProxy + +from ...exceptions import TransportConnectionClosed, TransportProtocolError +from ..aiohttp_closed_event import create_aiohttp_closed_event +from .connection import AdapterConnection + +log = logging.getLogger("gql.transport.common.adapters.aiohttp") + + +class AIOHTTPWebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the aiohttp library.""" + + def __init__( + self, + url: StrOrURL, + *, + headers: Optional[LooseHeaders] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + session: Optional[aiohttp.ClientSession] = None, + client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Optional[Dict[str, Any]] = None, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, + websocket_close_timeout: float = 10.0, + receive_timeout: 
Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: SSL validation mode. ``True`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + :param session: Optional aiohttp opened session. + :param client_session_args: Dict of extra args passed to + `aiohttp.ClientSession`_ + :param connect_args: Dict of extra args passed to + `aiohttp.ClientSession.ws_connect`_ + + :param float heartbeat: Send low level `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + :param auth: An object that represents HTTP Basic Authorization. + :class:`~aiohttp.BasicAuth` (optional) + :param str origin: Origin header to send to server(optional) + :param params: Mapping, iterable of tuple of *key*/*value* pairs or + string to be sent as parameters in the query + string of the new request. Ignored for subsequent + redirected requests (optional) + + Allowed values are: + + - :class:`collections.abc.Mapping` e.g. :class:`dict`, + :class:`multidict.MultiDict` or + :class:`multidict.MultiDictProxy` + - :class:`collections.abc.Iterable` e.g. :class:`tuple` or + :class:`list` + - :class:`str` with preferably url-encoded content + (**Warning:** content will not be encoded by *aiohttp*) + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) + :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP + Basic Authorization (optional) + :param float websocket_close_timeout: Timeout for websocket to close. 
+ ``10`` seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. ``None`` (unlimited) + seconds by default + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly + """ + super().__init__( + url=str(url), + connect_args=connect_args, + ) + + self._headers: Optional[LooseHeaders] = headers + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl + + self.session: Optional[aiohttp.ClientSession] = session + self._using_external_session = True if self.session else False + + if client_session_args is None: + client_session_args = {} + self.client_session_args = client_session_args + + self.heartbeat: Optional[float] = heartbeat + self.auth: Optional[BasicAuth] = auth + self.origin: Optional[str] = origin + self.params: Optional[Mapping[str, str]] = params + + self.proxy: Optional[StrOrURL] = proxy + self.proxy_auth: Optional[BasicAuth] = proxy_auth + self.proxy_headers: Optional[LooseHeaders] = proxy_headers + + self.websocket_close_timeout: float = websocket_close_timeout + self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + + self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None + self._response_headers: Optional[CIMultiDictProxy[str]] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + # Create a session if necessary + if self.session is None: + client_session_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + client_session_args.update(self.client_session_args) # type: ignore + + self.session = aiohttp.ClientSession(**client_session_args) + + connect_args: Dict[str, Any] = { + "url": self.url, + "headers": self.headers, + "auth": self.auth, + "heartbeat": self.heartbeat, + "origin": self.origin, + "params": self.params, + "proxy": self.proxy, + "proxy_auth": 
self.proxy_auth, + "proxy_headers": self.proxy_headers, + "timeout": self.websocket_close_timeout, + "receive_timeout": self.receive_timeout, + } + + if self.subprotocols: + connect_args["protocols"] = self.subprotocols + + if self.ssl is not None: + connect_args["ssl"] = self.ssl + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + try: + self.websocket = await self.session.ws_connect( + **connect_args, + ) + except Exception as e: + raise TransportConnectionClosed("Connect failed") from e + + self._response_headers = self.websocket._response.headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionClosed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + try: + await self.websocket.send_str(message) + except ConnectionResetError as e: + raise TransportConnectionClosed("Connection was closed") from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. 
+ + Returns: + String message received + + Raises: + TransportConnectionClosed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionClosed("Connection is already closed") + + while True: + ws_message = await self.websocket.receive() + + # Ignore low-level ping and pong received + if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): + break + + if ws_message.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + WSMsgType.ERROR, + ): + raise TransportConnectionClosed("Connection was closed") + elif ws_message.type is WSMsgType.BINARY: + raise TransportProtocolError("Binary data received in the websocket") + + assert ws_message.type is WSMsgType.TEXT + + answer: str = ws_message.data + + return answer + + async def _close_session(self) -> None: + """Close the aiohttp session.""" + + assert self.session is not None + + closed_event = create_aiohttp_closed_event(self.session) + await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass + finally: + self.session = None + + async def close(self) -> None: + """Close the WebSocket connection.""" + + if self.websocket: + websocket = self.websocket + self.websocket = None + try: + await websocket.close() + except Exception as exc: # pragma: no cover + log.warning("websocket.close() exception: " + repr(exc)) + + if self.session and not self._using_external_session: + await self._close_session() + + @property + def headers(self) -> Optional[LooseHeaders]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._headers: + return self._headers + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. 
+ + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers) + return {} diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py index cf361b8d..f3d77421 100644 --- a/gql/transport/common/adapters/connection.py +++ b/gql/transport/common/adapters/connection.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Dict +from typing import Any, Dict, List, Optional class AdapterConnection(abc.ABC): @@ -10,6 +10,17 @@ class AdapterConnection(abc.ABC): url: str connect_args: Dict[str, Any] + subprotocols: Optional[List[str]] + + def __init__(self, url: str, connect_args: Optional[Dict[str, Any]]): + """Initialize the connection adapter.""" + self.url: str = url + + if connect_args is None: + connect_args = {} + self.connect_args = connect_args + + self.subprotocols = None @abc.abstractmethod async def connect(self) -> None: diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index 4494e256..383d4def 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -1,14 +1,16 @@ +import logging from ssl import SSLContext from typing import Any, Dict, Optional, Union import websockets from websockets.client import WebSocketClientProtocol from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import WebSocketException from ...exceptions import TransportConnectionClosed, TransportProtocolError from .connection import AdapterConnection +log = logging.getLogger("gql.transport.common.adapters.websockets") + class WebSocketsAdapter(AdapterConnection): """AdapterConnection implementation using the websockets library.""" @@ -26,16 +28,17 @@ def __init__( :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. :param headers: Dict of HTTP Headers. :param ssl: ssl_context of the connection. 
Use ssl=False to disable encryption - :param connect_args: Other parameters forwarded to websockets.connect + :param connect_args: Other parameters forwarded to + `websockets.connect `_ """ - self.url: str = url - self._headers: Optional[HeadersLike] = headers - self.ssl: Union[SSLContext, bool] = ssl + super().__init__( + url=url, + connect_args=connect_args, + ) - if connect_args is None: - connect_args = {} - - self.connect_args = connect_args + self._headers: Optional[HeadersLike] = headers + self.ssl = ssl self.websocket: Optional[WebSocketClientProtocol] = None self._response_headers: Optional[Headers] = None @@ -57,14 +60,17 @@ async def connect(self) -> None: "extra_headers": self.headers, } + if self.subprotocols: + connect_args["subprotocols"] = self.subprotocols + # Adding custom parameters passed from init connect_args.update(self.connect_args) # Connection to the specified url try: self.websocket = await websockets.client.connect(self.url, **connect_args) - except WebSocketException as e: - raise TransportConnectionClosed("Connection was closed") from e + except Exception as e: + raise TransportConnectionClosed("Connect failed") from e self._response_headers = self.websocket.response_headers @@ -82,7 +88,7 @@ async def send(self, message: str) -> None: try: await self.websocket.send(message) - except WebSocketException as e: + except Exception as e: raise TransportConnectionClosed("Connection was closed") from e async def receive(self) -> str: @@ -102,9 +108,7 @@ async def receive(self) -> str: # Wait for the next websocket frame. 
Can raise ConnectionClosed try: data = await self.websocket.recv() - except WebSocketException as e: - # When the connection is closed, make sure to clean up resources - self.websocket = None + except Exception as e: raise TransportConnectionClosed("Connection was closed") from e # websocket.recv() can return either str or bytes @@ -124,14 +128,14 @@ async def close(self) -> None: await websocket.close() @property - def headers(self) -> Dict[str, str]: + def headers(self) -> Optional[HeadersLike]: """Get the response headers from the WebSocket connection. Returns: Dictionary of response headers """ if self._headers: - return dict(self._headers) + return self._headers return {} @property diff --git a/gql/transport/common/aiohttp_closed_event.py b/gql/transport/common/aiohttp_closed_event.py new file mode 100644 index 00000000..412448f9 --- /dev/null +++ b/gql/transport/common/aiohttp_closed_event.py @@ -0,0 +1,59 @@ +import asyncio +import functools + +from aiohttp import ClientSession + + +def create_aiohttp_closed_event(session: ClientSession) -> asyncio.Event: + """Work around aiohttp issue that doesn't properly close transports on exit. + + See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 + + Returns: + An event that will be set once all transports have been properly closed. + """ + + ssl_transports = 0 + all_is_lost = asyncio.Event() + + def connection_lost(exc, orig_lost): + nonlocal ssl_transports + + try: + orig_lost(exc) + finally: + ssl_transports -= 1 + if ssl_transports == 0: + all_is_lost.set() + + def eof_received(orig_eof_received): + try: # pragma: no cover + orig_eof_received() + except AttributeError: # pragma: no cover + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. 
+ pass + + assert session.connector is not None + + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + + ssl_transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + proto.connection_lost = functools.partial( + connection_lost, orig_lost=orig_lost + ) + proto.eof_received = functools.partial( + eof_received, orig_eof_received=orig_eof_received + ) + + if ssl_transports == 0: + all_is_lost.set() + + return all_is_lost diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index 84ba7656..f004d240 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -6,7 +6,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast -from .common.adapters.websockets import AdapterConnection +from .common.adapters.connection import AdapterConnection from .common.base import SubscriptionTransportBase from .exceptions import ( TransportConnectionClosed, @@ -80,7 +80,7 @@ def __init__( self.GRAPHQLWS_SUBPROTOCOL, ] - self.adapter.connect_args.update({"subprotocols": subprotocols}) + self.adapter.subprotocols = subprotocols # Initialize the generic SubscriptionTransportBase parent class super().__init__( diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 8ee44d2c..e4e56fcd 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -7,7 +7,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -148,7 +148,7 @@ async def test_aiohttp_websocket_sending_invalid_data( invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send_str(invalid_data) + await session.transport.adapter.websocket.send_str(invalid_data) 
await asyncio.sleep(2 * MS) @@ -289,7 +289,7 @@ async def test_aiohttp_websocket_server_closing_directly(event_loop, server): sample_transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async with Client(transport=sample_transport): pass @@ -309,7 +309,7 @@ async def test_aiohttp_websocket_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index b234d296..8f3567a7 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -6,6 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, + TransportConnectionClosed, TransportProtocolError, TransportQueryError, ) @@ -247,7 +248,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_directly( transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async with Client(transport=transport): pass @@ -267,7 +268,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index d40d15ce..79cf506d 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import 
TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -390,7 +390,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -839,7 +839,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except ConnectionResetError: + except TransportConnectionClosed: pass await asyncio.sleep(50 * MS) diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index d76d646f..30b35d73 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportQueryError, TransportServerError, ) @@ -60,7 +61,14 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = AIOHTTPWebsocketsTransport(url=url, websocket_close_timeout=10) + transport = AIOHTTPWebsocketsTransport( + url=url, + websocket_close_timeout=10, + headers={"test": "1234"}, + ) + + assert transport.response_headers == {} + assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: @@ -84,7 +92,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -135,7 +143,7 @@ async def 
test_aiohttp_websocket_using_ssl_connection( assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -166,19 +174,26 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( **extra_args, ) - with pytest.raises(ClientConnectorCertificateError) as exc_info: + if verify_https == "explicitely_enabled": + assert transport.ssl is True + + with pytest.raises(TransportConnectionClosed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, ClientConnectorCertificateError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -380,13 +395,13 @@ async def test_aiohttp_websocket_multiple_connections_in_series( await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -519,8 +534,8 @@ async def test_aiohttp_websocket_connect_failed_with_authentication_in_connectio await session.execute(query1) - assert transport.session is None - assert transport.websocket is None + assert transport.adapter.session is None + assert transport._connected is False @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) @@ -564,7 +579,7 @@ def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): assert africa["code"] == "AF" # Check client is disconnect here - assert 
transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -753,6 +768,6 @@ async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_se assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False await connector.close() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 9d2d652b..188e006e 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionClosed, TransportServerError from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef @@ -381,7 +381,7 @@ async def test_aiohttp_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionClosed): async for result in session.subscribe(subscription): @@ -772,14 +772,12 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s subscription = gql(subscription_str.format(count=count)) async with client as session: - session.transport.websocket._writer._closing = True + session.transport.adapter.websocket._writer._closing = True - with pytest.raises(ConnectionResetError) as e: + with pytest.raises(TransportConnectionClosed): async for _ in session.subscribe(subscription): pass - assert e.value.args[0] == "Cannot write to closing transport" - @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -798,9 +796,7 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) async with client as 
session: - session.transport.websocket = None - with pytest.raises(TransportClosed) as e: + session.transport.adapter.websocket = None + with pytest.raises(TransportConnectionClosed): async for _ in session.subscribe(subscription): pass - - assert e.value.args[0] == "WebSocket connection is closed" diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 320d1da3..732c0e14 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -1,6 +1,7 @@ import pytest from gql import Client, gql +from gql.transport.exceptions import TransportConnectionClosed from .conftest import get_localhost_ssl_context_client @@ -71,14 +72,10 @@ async def test_phoenix_channel_query(event_loop, server, query_str): assert africa["code"] == "AF" -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_phoenix_channel_query_ssl( - event_loop, ws_ssl_server, query_str, verify_https -): +async def test_phoenix_channel_query_ssl(event_loop, ws_ssl_server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -89,12 +86,9 @@ async def test_phoenix_channel_query_ssl( extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", @@ -138,13 +132,17 @@ async def test_phoenix_channel_query_ssl_self_cert_fail( query = gql(query_str) - with pytest.raises(SSLCertVerificationError) as exc_info: + with pytest.raises(TransportConnectionClosed) as 
exc_info: async with Client(transport=transport) as session: await session.execute(query) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) query2_str = """ diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 7aa853bf..f7e92840 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionClosed, TransportQueryError, TransportServerError, ) @@ -88,11 +89,9 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport._connected is False -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_https): +async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): import websockets from gql.transport.websockets import WebsocketsTransport @@ -103,19 +102,16 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = WebsocketsTransport(url=url, **extra_args) async with Client(transport=transport) as session: assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol + transport.adapter.websocket, websockets.client.WebSocketClientProtocol ) query1 = gql(query1_str) 
@@ -160,16 +156,20 @@ async def test_websocket_using_ssl_connection_self_cert_fail( if verify_https == "explicitely_enabled": assert transport.ssl is True - with pytest.raises(SSLCertVerificationError) as exc_info: + with pytest.raises(TransportConnectionClosed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here assert transport._connected is False From 750e695315380b9a1e3adb9bdf361d2b8c018c26 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 10 Mar 2025 17:00:12 +0100 Subject: [PATCH 200/239] Fix PyPy tests --- gql/transport/common/base.py | 10 +++++-- tests/conftest.py | 4 +++ ...iohttp_websocket_graphqlws_subscription.py | 28 +++++++++++++++---- tests/test_aiohttp_websocket_subscription.py | 6 +++- tests/test_client.py | 4 +++ tests/test_graphqlws_subscription.py | 28 +++++++++++++++---- tests/test_phoenix_channel_query.py | 6 +++- tests/test_phoenix_channel_subscription.py | 13 +++++++-- tests/test_websocket_subscription.py | 17 +++++++++-- 9 files changed, 95 insertions(+), 21 deletions(-) diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 40d0b4cb..2a4d4d65 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -317,6 +317,8 @@ async def subscribe( if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False + if isinstance(e, GeneratorExit): + raise e finally: log.debug(f"In subscribe finally for query_id {query_id}") @@ -345,6 +347,11 @@ async def execute( first_result = result break + # Apparently, on pypy the GeneratorExit exception is not raised after a break + # --> the clean_close has to time out + # We still need to manually 
close the async generator + await generator.aclose() + if first_result is None: raise TransportQueryError( "Query completed without any answer received from the server" @@ -445,7 +452,6 @@ async def _clean_close(self, e: Exception) -> None: # Send 'stop' message for all current queries for query_id, listener in self.listeners.items(): - if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False @@ -556,7 +562,7 @@ async def wait_closed(self) -> None: try: await asyncio.wait_for(self._wait_closed.wait(), self.close_timeout) except asyncio.TimeoutError: - log.debug("Timer close_timeout fired in wait_closed") + log.warning("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") diff --git a/tests/conftest.py b/tests/conftest.py index 664fe8c9..f9e11dab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import logging import os import pathlib +import platform import re import ssl import sys @@ -19,6 +20,9 @@ all_transport_dependencies = ["aiohttp", "requests", "httpx", "websockets", "botocore"] +PyPy = platform.python_implementation() == "PyPy" + + def pytest_addoption(parser): parser.addoption( "--run-online", diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 79cf506d..e97da29a 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -10,7 +10,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportConnectionClosed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the aiohttp AND websockets marker pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] @@ -260,7 +260,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for 
result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,6 +275,9 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @@ -847,23 +851,33 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( # Then with the same session handle, we make a subscription or an execute # which will detect that the transport is closed so that the client could # try to reconnect + generator = None try: if execute_instead_of_subscribe: print("\nEXECUTION_2\n") await session.execute(subscription) else: print("\nSUBSCRIPTION_2\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: pass - except TransportClosed: + except (TransportClosed, TransportConnectionClosed): + if generator: + await generator.aclose() pass - await asyncio.sleep(50 * MS) + timeout = 50 + + if PyPy: + timeout = 500 + + await asyncio.sleep(timeout * MS) # And finally with the same session handle, we make a subscription # which works correctly print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -871,6 +885,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( assert number == count count -= 1 + await generator.aclose() + assert count == -1 await client.close_async() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 188e006e..61270fe1 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ 
b/tests/test_aiohttp_websocket_subscription.py @@ -250,7 +250,8 @@ async def test_aiohttp_websocket_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -264,6 +265,9 @@ async def test_aiohttp_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) diff --git a/tests/test_client.py b/tests/test_client.py index 1e794558..e5edec8b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -280,3 +280,7 @@ async def test_async_transport_close_on_schema_retrieval_failure(): pass assert client.transport.session is None + + import asyncio + + await asyncio.sleep(1) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 1b8f7ccb..8284fea8 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -10,7 +10,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportConnectionClosed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -260,7 +260,8 @@ async def test_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,6 +275,9 @@ async def test_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await 
generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @@ -843,23 +847,33 @@ async def test_graphqlws_subscription_reconnecting_session( # Then with the same session handle, we make a subscription or an execute # which will detect that the transport is closed so that the client could # try to reconnect + generator = None try: if execute_instead_of_subscribe: print("\nEXECUTION_2\n") await session.execute(subscription) else: print("\nSUBSCRIPTION_2\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: pass - except TransportClosed: + except (TransportClosed, TransportConnectionClosed): + if generator: + await generator.aclose() pass - await asyncio.sleep(50 * MS) + timeout = 50 + + if PyPy: + timeout = 500 + + await asyncio.sleep(timeout * MS) # And finally with the same session handle, we make a subscription # which works correctly print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -867,6 +881,8 @@ async def test_graphqlws_subscription_reconnecting_session( assert number == count count -= 1 + await generator.aclose() + assert count == -1 await client.close_async() diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 732c0e14..16d4e4f4 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -216,8 +216,12 @@ async def test_phoenix_channel_subscription(event_loop, server, query_str): first_result = None query = gql(query_str) async with Client(transport=transport) as session: - async for result in session.subscribe(query): + generator = session.subscribe(query) + async for result in generator: first_result = result break + # Using aclose here to make it stop cleanly on 
pypy + await generator.aclose() + print("Client received:", first_result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 3be4b07d..35ca665b 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -201,7 +201,9 @@ async def test_phoenix_channel_subscription( subscription = gql(subscription_str.format(count=count)) async with Client(transport=sample_transport) as session: - async for result in session.subscribe(subscription): + + generator = session.subscribe(subscription) + async for result in generator: number = result["countdown"]["number"] print(f"Number received: {number}") @@ -212,6 +214,9 @@ async def test_phoenix_channel_subscription( count -= 1 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + assert count == end_count @@ -378,7 +383,8 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): subscription = gql(heartbeat_subscription_str) async with Client(transport=sample_transport) as session: i = 0 - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: heartbeat_count = result["heartbeat"]["heartbeat_count"] print(f"Heartbeat count received: {heartbeat_count}") @@ -387,3 +393,6 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): break i += 1 + + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 3efe63a6..927db4e9 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -11,7 +11,7 @@ from gql import Client, gql from gql.transport.exceptions import TransportConnectionClosed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in 
this file with the websockets marker pytestmark = pytest.mark.websockets @@ -181,7 +181,8 @@ async def test_websocket_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -195,6 +196,9 @@ async def test_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -413,7 +417,14 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + + keep_alive_timeout = 20 * MS + if PyPy: + keep_alive_timeout = 200 * MS + + sample_transport = WebsocketsTransport( + url=url, keep_alive_timeout=keep_alive_timeout + ) client = Client(transport=sample_transport) From 7fb869a6abd8fb035e00f3706a9e104f5eb655eb Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 11 Mar 2025 11:31:24 +0100 Subject: [PATCH 201/239] Renaming TransportConnectionClosed to TransportConnectionFailed --- gql/transport/common/adapters/aiohttp.py | 16 ++++++++-------- gql/transport/common/adapters/connection.py | 4 ++-- gql/transport/common/adapters/websockets.py | 16 ++++++++-------- gql/transport/common/base.py | 12 ++++++------ gql/transport/exceptions.py | 2 +- gql/transport/phoenix_channel_websockets.py | 4 ++-- gql/transport/websockets_protocol.py | 4 ++-- tests/test_aiohttp_websocket_exceptions.py | 6 +++--- ...est_aiohttp_websocket_graphqlws_exceptions.py | 6 +++--- ...t_aiohttp_websocket_graphqlws_subscription.py | 8 ++++---- tests/test_aiohttp_websocket_query.py | 4 ++-- tests/test_aiohttp_websocket_subscription.py | 8 ++++---- 
tests/test_graphqlws_exceptions.py | 6 +++--- tests/test_graphqlws_subscription.py | 8 ++++---- tests/test_phoenix_channel_query.py | 4 ++-- tests/test_websocket_exceptions.py | 6 +++--- tests/test_websocket_query.py | 4 ++-- tests/test_websocket_subscription.py | 4 ++-- tests/test_websockets_adapter.py | 6 +++--- 19 files changed, 64 insertions(+), 64 deletions(-) diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py index d9af7c50..f2dff699 100644 --- a/gql/transport/common/adapters/aiohttp.py +++ b/gql/transport/common/adapters/aiohttp.py @@ -8,7 +8,7 @@ from aiohttp.typedefs import LooseHeaders, StrOrURL from multidict import CIMultiDictProxy -from ...exceptions import TransportConnectionClosed, TransportProtocolError +from ...exceptions import TransportConnectionFailed, TransportProtocolError from ..aiohttp_closed_event import create_aiohttp_closed_event from .connection import AdapterConnection @@ -160,7 +160,7 @@ async def connect(self) -> None: **connect_args, ) except Exception as e: - raise TransportConnectionClosed("Connect failed") from e + raise TransportConnectionFailed("Connect failed") from e self._response_headers = self.websocket._response.headers @@ -171,15 +171,15 @@ async def send(self, message: str) -> None: message: String message to send Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") try: await self.websocket.send_str(message) except ConnectionResetError as e: - raise TransportConnectionClosed("Connection was closed") from e + raise TransportConnectionFailed("Connection was closed") from e async def receive(self) -> str: """Receive message from the WebSocket server. 
@@ -188,12 +188,12 @@ async def receive(self) -> str: String message received Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed TransportProtocolError: If protocol error or binary data received """ # It is possible that the websocket has been already closed in another task if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") while True: ws_message = await self.websocket.receive() @@ -208,7 +208,7 @@ async def receive(self) -> str: WSMsgType.CLOSING, WSMsgType.ERROR, ): - raise TransportConnectionClosed("Connection was closed") + raise TransportConnectionFailed("Connection was closed") elif ws_message.type is WSMsgType.BINARY: raise TransportProtocolError("Binary data received in the websocket") diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py index f3d77421..ac178bc6 100644 --- a/gql/transport/common/adapters/connection.py +++ b/gql/transport/common/adapters/connection.py @@ -35,7 +35,7 @@ async def send(self, message: str) -> None: message: String message to send Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed """ pass # pragma: no cover @@ -47,7 +47,7 @@ async def receive(self) -> str: String message received Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed TransportProtocolError: If protocol error or binary data received """ pass # pragma: no cover diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index 383d4def..c2524fb4 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -6,7 +6,7 @@ from websockets.client import WebSocketClientProtocol from websockets.datastructures import Headers, HeadersLike -from ...exceptions import 
TransportConnectionClosed, TransportProtocolError +from ...exceptions import TransportConnectionFailed, TransportProtocolError from .connection import AdapterConnection log = logging.getLogger("gql.transport.common.adapters.websockets") @@ -70,7 +70,7 @@ async def connect(self) -> None: try: self.websocket = await websockets.client.connect(self.url, **connect_args) except Exception as e: - raise TransportConnectionClosed("Connect failed") from e + raise TransportConnectionFailed("Connect failed") from e self._response_headers = self.websocket.response_headers @@ -81,15 +81,15 @@ async def send(self, message: str) -> None: message: String message to send Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") try: await self.websocket.send(message) except Exception as e: - raise TransportConnectionClosed("Connection was closed") from e + raise TransportConnectionFailed("Connection was closed") from e async def receive(self) -> str: """Receive message from the WebSocket server. @@ -98,18 +98,18 @@ async def receive(self) -> str: String message received Raises: - TransportConnectionClosed: If connection closed + TransportConnectionFailed: If connection closed TransportProtocolError: If protocol error or binary data received """ # It is possible that the websocket has been already closed in another task if self.websocket is None: - raise TransportConnectionClosed("Connection is already closed") + raise TransportConnectionFailed("Connection is already closed") # Wait for the next websocket frame. 
Can raise ConnectionClosed try: data = await self.websocket.recv() except Exception as e: - raise TransportConnectionClosed("Connection was closed") from e + raise TransportConnectionFailed("Connection was closed") from e # websocket.recv() can return either str or bytes # In our case, we should receive only str here diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 2a4d4d65..770a8b34 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -11,7 +11,7 @@ from ..exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -134,7 +134,7 @@ async def _send(self, message: str) -> None: try: await self.adapter.send(message) log.info(">>> %s", message) - except TransportConnectionClosed as e: + except TransportConnectionFailed as e: await self._fail(e, clean_close=False) raise e @@ -146,7 +146,7 @@ async def _receive(self) -> str: raise TransportClosed("Transport is already closed") # Wait for the next frame. 
- # Can raise TransportConnectionClosed or TransportProtocolError + # Can raise TransportConnectionFailed or TransportProtocolError answer: str = await self.adapter.receive() log.info("<<< %s", answer) @@ -211,7 +211,7 @@ async def _receive_data_loop(self) -> None: # Wait the next answer from the server try: answer = await self._receive() - except (TransportConnectionClosed, TransportProtocolError) as e: + except (TransportConnectionFailed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break except TransportClosed: @@ -296,7 +296,7 @@ async def subscribe( while True: # Wait for the answer from the queue of this query_id - # This can raise TransportError or TransportConnectionClosed + # This can raise TransportError or TransportConnectionFailed answer_type, execution_result = await listener.get() # If the received answer contains data, @@ -402,7 +402,7 @@ async def connect(self) -> None: # if no ACKs are received within the ack_timeout try: await self._initialize() - except TransportConnectionClosed as e: + except TransportConnectionFailed as e: raise e except ( TransportProtocolError, diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 27cefe2f..3e63f0bc 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -61,7 +61,7 @@ class TransportClosed(TransportError): """ -class TransportConnectionClosed(TransportError): +class TransportConnectionFailed(TransportError): """Transport adapter connection closed. This exception is by the connection adapter code when a connection closed. 
diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 0c1bd62b..3885fcac 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -8,7 +8,7 @@ from .common.adapters.websockets import WebSocketsAdapter from .common.base import SubscriptionTransportBase from .exceptions import ( - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -127,7 +127,7 @@ async def heartbeat_coro(): } ) ) - except TransportConnectionClosed: # pragma: no cover + except TransportConnectionFailed: # pragma: no cover return self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index f004d240..3348c576 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -9,7 +9,7 @@ from .common.adapters.connection import AdapterConnection from .common.base import SubscriptionTransportBase from .exceptions import ( - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -508,7 +508,7 @@ async def _close_hook(self): if self.send_ping_task is not None: log.debug("_close_hook: cancelling send_ping_task") self.send_ping_task.cancel() - with suppress(asyncio.CancelledError, TransportConnectionClosed): + with suppress(asyncio.CancelledError, TransportConnectionFailed): log.debug("_close_hook: awaiting send_ping_task") await self.send_ping_task self.send_ping_task = None diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index e4e56fcd..81c79ba7 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -7,7 +7,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportConnectionClosed, + 
TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -289,7 +289,7 @@ async def test_aiohttp_websocket_server_closing_directly(event_loop, server): sample_transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -309,7 +309,7 @@ async def test_aiohttp_websocket_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index 8f3567a7..f87682d2 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -248,7 +248,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_directly( transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=transport): pass @@ -268,7 +268,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index e97da29a..f380948c 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, 
gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -394,7 +394,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -843,7 +843,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except TransportConnectionClosed: + except TransportConnectionFailed: pass await asyncio.sleep(50 * MS) @@ -861,7 +861,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( generator = session.subscribe(subscription) async for result in generator: pass - except (TransportClosed, TransportConnectionClosed): + except (TransportClosed, TransportConnectionFailed): if generator: await generator.aclose() pass diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 30b35d73..8786d58d 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -9,7 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -177,7 +177,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( if verify_https == "explicitely_enabled": assert transport.ssl is True - with pytest.raises(TransportConnectionClosed) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with 
Client(transport=transport) as session: query1 = gql(query1_str) diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 61270fe1..4ea11a7b 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef @@ -385,7 +385,7 @@ async def test_aiohttp_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -778,7 +778,7 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s async with client as session: session.transport.adapter.websocket._writer._closing = True - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass @@ -801,6 +801,6 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) async with client as session: session.transport.adapter.websocket = None - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index cce31d59..3b6bd901 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, - TransportConnectionClosed, + 
TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -241,7 +241,7 @@ async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -261,7 +261,7 @@ async def test_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 8284fea8..d4bed34f 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,7 +8,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -394,7 +394,7 @@ async def test_graphqlws_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -839,7 +839,7 @@ async def test_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except TransportConnectionClosed: + except TransportConnectionFailed: pass await asyncio.sleep(50 * MS) @@ -857,7 +857,7 @@ async def test_graphqlws_subscription_reconnecting_session( generator = session.subscribe(subscription) async for result in generator: pass - except (TransportClosed, TransportConnectionClosed): + except (TransportClosed, 
TransportConnectionFailed): if generator: await generator.aclose() pass diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 16d4e4f4..56d28875 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -1,7 +1,7 @@ import pytest from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed +from gql.transport.exceptions import TransportConnectionFailed from .conftest import get_localhost_ssl_context_client @@ -132,7 +132,7 @@ async def test_phoenix_channel_query_ssl_self_cert_fail( query = gql(query_str) - with pytest.raises(TransportConnectionClosed) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: await session.execute(query) diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index f9f1f8db..68b2fe52 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -9,7 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -280,7 +280,7 @@ async def test_websocket_server_closing_directly(event_loop, server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -298,7 +298,7 @@ async def test_websocket_server_closing_after_ack(event_loop, client_and_server) query = gql("query { hello }") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index f7e92840..b1e3c07a 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -9,7 +9,7 
@@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, - TransportConnectionClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -156,7 +156,7 @@ async def test_websocket_using_ssl_connection_self_cert_fail( if verify_https == "explicitely_enabled": assert transport.ssl is True - with pytest.raises(TransportConnectionClosed) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 927db4e9..6f291218 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportConnectionClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -315,7 +315,7 @@ async def test_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py index f266ce29..85fbf00a 100644 --- a/tests/test_websockets_adapter.py +++ b/tests/test_websockets_adapter.py @@ -4,7 +4,7 @@ from graphql import print_ast from gql import gql -from gql.transport.exceptions import TransportConnectionClosed +from gql.transport.exceptions import TransportConnectionFailed # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -91,8 +91,8 @@ async def test_websockets_adapter_edge_cases(event_loop, server): # Second close call is ignored await adapter.close() - with 
pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await adapter.send("Blah") - with pytest.raises(TransportConnectionClosed): + with pytest.raises(TransportConnectionFailed): await adapter.receive() From fdc3b28786ef8a31c9616d9c824442d9c3c88740 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 11 Mar 2025 19:41:09 +0100 Subject: [PATCH 202/239] chore: Update pytest to 8.3.4 and pytest-asyncio to 0.25.3 (#537) --- pyproject.toml | 3 + setup.py | 4 +- tests/conftest.py | 23 ++-- tests/custom_scalars/test_money.py | 28 +++-- tests/starwars/test_introspection.py | 2 +- tests/starwars/test_validation.py | 8 +- tests/test_aiohttp.py | 100 ++++++++--------- tests/test_aiohttp_online.py | 6 +- tests/test_aiohttp_websocket_exceptions.py | 36 +++---- ..._aiohttp_websocket_graphqlws_exceptions.py | 20 ++-- ...iohttp_websocket_graphqlws_subscription.py | 32 +++--- tests/test_aiohttp_websocket_query.py | 46 ++++---- tests/test_aiohttp_websocket_subscription.py | 34 +++--- tests/test_appsync_http.py | 4 +- tests/test_appsync_websockets.py | 26 ++--- tests/test_async_client_validation.py | 14 ++- tests/test_graphqlws_exceptions.py | 26 ++--- tests/test_graphqlws_subscription.py | 34 +++--- tests/test_http_async_sync.py | 4 +- tests/test_httpx.py | 86 +++++++-------- tests/test_httpx_async.py | 88 +++++++-------- tests/test_httpx_online.py | 6 +- tests/test_phoenix_channel_exceptions.py | 14 ++- tests/test_phoenix_channel_query.py | 8 +- tests/test_phoenix_channel_subscription.py | 10 +- tests/test_requests.py | 102 +++++++----------- tests/test_requests_batch.py | 54 ++++------ tests/test_websocket_exceptions.py | 28 +++-- tests/test_websocket_query.py | 50 ++++----- tests/test_websocket_subscription.py | 30 +++--- tests/test_websockets_adapter.py | 4 +- 31 files changed, 397 insertions(+), 533 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9b631e08..122cec88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,3 
+7,6 @@ dynamic = ["authors", "classifiers", "dependencies", "description", "entry-point [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" diff --git a/setup.py b/setup.py index a44c2e01..6b4c1fd2 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,8 @@ tests_requires = [ "parse==1.15.0", - "pytest==7.4.2", - "pytest-asyncio==0.21.1", + "pytest==8.3.4", + "pytest-asyncio==0.25.3", "pytest-console-scripts==1.4.1", "pytest-cov==5.0.0", "vcrpy==7.0.0", diff --git a/tests/conftest.py b/tests/conftest.py index f9e11dab..5b8807ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -639,9 +639,9 @@ async def client_and_server(server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, server @@ -659,9 +659,9 @@ async def aiohttp_client_and_server(server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, server @@ -681,9 +681,9 @@ async def aiohttp_client_and_aiohttp_ws_server(aiohttp_ws_server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async 
with Client(transport=transport) as session: # Yield both client session and server yield session, server @@ -699,12 +699,12 @@ async def client_and_graphqlws_server(graphqlws_server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url=url, subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], ) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, graphqlws_server @@ -720,12 +720,12 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport( + transport = AIOHTTPWebsocketsTransport( url=url, subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], ) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, graphqlws_server @@ -733,11 +733,12 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): @pytest_asyncio.fixture async def run_sync_test(): - async def run_sync_test_inner(event_loop, server, test_function): + async def run_sync_test_inner(server, test_function): """This function will run the test in a different Thread. This allows us to run sync code while aiohttp server can still run. 
""" + event_loop = asyncio.get_running_loop() executor = ThreadPoolExecutor(max_workers=2) test_task = event_loop.run_in_executor(executor, test_function) diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 374c70e6..cf4ca45d 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -491,7 +491,7 @@ async def make_sync_money_transport(aiohttp_server): @pytest.mark.asyncio -async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server): +async def test_custom_scalar_in_output_with_transport(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -509,7 +509,7 @@ async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server @pytest.mark.asyncio -async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_server): +async def test_custom_scalar_in_input_query_with_transport(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -531,9 +531,7 @@ async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_s @pytest.mark.asyncio -async def test_custom_scalar_in_input_variable_values_with_transport( - event_loop, aiohttp_server -): +async def test_custom_scalar_in_input_variable_values_with_transport(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -556,7 +554,7 @@ async def test_custom_scalar_in_input_variable_values_with_transport( @pytest.mark.asyncio async def test_custom_scalar_in_input_variable_values_split_with_transport( - event_loop, aiohttp_server + aiohttp_server, ): transport = await make_money_transport(aiohttp_server) @@ -581,7 +579,7 @@ async def test_custom_scalar_in_input_variable_values_split_with_transport( @pytest.mark.asyncio -async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): +async def test_custom_scalar_serialize_variables(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -603,7 +601,7 @@ 
async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): @pytest.mark.asyncio -async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_server): +async def test_custom_scalar_serialize_variables_no_schema(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -623,7 +621,7 @@ async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_s @pytest.mark.asyncio async def test_custom_scalar_serialize_variables_schema_from_introspection( - event_loop, aiohttp_server + aiohttp_server, ): transport = await make_money_transport(aiohttp_server) @@ -656,7 +654,7 @@ async def test_custom_scalar_serialize_variables_schema_from_introspection( @pytest.mark.asyncio -async def test_update_schema_scalars(event_loop, aiohttp_server): +async def test_update_schema_scalars(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -735,7 +733,7 @@ def test_update_schema_scalars_scalar_type_is_not_a_scalar_in_schema(): @pytest.mark.asyncio @pytest.mark.requests async def test_custom_scalar_serialize_variables_sync_transport( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): server, transport = await make_sync_money_transport(aiohttp_server) @@ -754,13 +752,13 @@ def test_code(): print(f"result = {result!r}") assert result["toEuros"] == 5 - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio @pytest.mark.requests async def test_custom_scalar_serialize_variables_sync_transport_2( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): server, transport = await make_sync_money_transport(aiohttp_server) @@ -783,7 +781,7 @@ def test_code(): assert results[0]["toEuros"] == 5 assert results[1]["toEuros"] == 5 - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) def test_serialize_value_with_invalid_type(): @@ -818,7 +816,7 @@ def 
test_serialize_value_with_nullable_type(): @pytest.mark.asyncio -async def test_gql_cli_print_schema(event_loop, aiohttp_server, capsys): +async def test_gql_cli_print_schema(aiohttp_server, capsys): from gql.cli import get_parser, main diff --git a/tests/starwars/test_introspection.py b/tests/starwars/test_introspection.py index c3063808..0d8369c0 100644 --- a/tests/starwars/test_introspection.py +++ b/tests/starwars/test_introspection.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio -async def test_starwars_introspection_args(event_loop, aiohttp_server): +async def test_starwars_introspection_args(aiohttp_server): transport = await make_starwars_transport(aiohttp_server) diff --git a/tests/starwars/test_validation.py b/tests/starwars/test_validation.py index 1ca8a2bb..38676836 100644 --- a/tests/starwars/test_validation.py +++ b/tests/starwars/test_validation.py @@ -1,3 +1,5 @@ +import copy + import pytest from gql import Client, gql @@ -62,7 +64,8 @@ def introspection_schema(): @pytest.fixture def introspection_schema_empty_directives(): - introspection = StarWarsIntrospection + # Create a deep copy to avoid modifying the original + introspection = copy.deepcopy(StarWarsIntrospection) # Simulate an empty dictionary for directives introspection["__schema"]["directives"] = [] @@ -72,7 +75,8 @@ def introspection_schema_empty_directives(): @pytest.fixture def introspection_schema_no_directives(): - introspection = StarWarsIntrospection + # Create a deep copy to avoid modifying the original + introspection = copy.deepcopy(StarWarsIntrospection) # Simulate no directives key del introspection["__schema"]["directives"] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 81af20ff..88c4db98 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -45,7 +45,7 @@ @pytest.mark.asyncio -async def test_aiohttp_query(event_loop, aiohttp_server): +async def test_aiohttp_query(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import 
AIOHTTPTransport @@ -84,7 +84,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_ignore_backend_content_type(event_loop, aiohttp_server): +async def test_aiohttp_ignore_backend_content_type(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -113,7 +113,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_cookies(event_loop, aiohttp_server): +async def test_aiohttp_cookies(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -146,7 +146,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_error_code_401(event_loop, aiohttp_server): +async def test_aiohttp_error_code_401(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -177,7 +177,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_error_code_429(event_loop, aiohttp_server): +async def test_aiohttp_error_code_429(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -224,7 +224,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_error_code_500(event_loop, aiohttp_server): +async def test_aiohttp_error_code_500(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -259,7 +259,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("query_error", transport_query_error_responses) -async def test_aiohttp_error_code(event_loop, aiohttp_server, query_error): +async def test_aiohttp_error_code(aiohttp_server, query_error): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -314,7 +314,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("param", invalid_protocol_responses) -async def test_aiohttp_invalid_protocol(event_loop, aiohttp_server, param): +async def test_aiohttp_invalid_protocol(aiohttp_server, param): 
from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -342,7 +342,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_subscribe_not_supported(event_loop, aiohttp_server): +async def test_aiohttp_subscribe_not_supported(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -367,7 +367,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_cannot_connect_twice(event_loop, aiohttp_server): +async def test_aiohttp_cannot_connect_twice(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -389,7 +389,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_cannot_execute_if_not_connected(event_loop, aiohttp_server): +async def test_aiohttp_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -411,7 +411,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_extra_args(event_loop, aiohttp_server): +async def test_aiohttp_extra_args(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -458,7 +458,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_variable_values(event_loop, aiohttp_server): +async def test_aiohttp_query_variable_values(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -490,7 +490,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_variable_values_fix_issue_292(event_loop, aiohttp_server): +async def test_aiohttp_query_variable_values_fix_issue_292(aiohttp_server): """Allow to specify variable_values without keyword. 
See https://github.com/graphql-python/gql/issues/292""" @@ -524,9 +524,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_execute_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_aiohttp_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -548,13 +546,11 @@ def test_code(): client.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio -async def test_aiohttp_subscribe_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_aiohttp_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -585,7 +581,7 @@ def test_code(): for result in client.subscribe(query): pass - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_server_answer = '{"data":{"success":true}}' @@ -640,7 +636,7 @@ async def single_upload_handler(request): @pytest.mark.asyncio -async def test_aiohttp_file_upload(event_loop, aiohttp_server): +async def test_aiohttp_file_upload(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -705,7 +701,7 @@ async def single_upload_handler_with_content_type(request): @pytest.mark.asyncio -async def test_aiohttp_file_upload_with_content_type(event_loop, aiohttp_server): +async def test_aiohttp_file_upload_with_content_type(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -743,9 +739,7 @@ async def test_aiohttp_file_upload_with_content_type(event_loop, aiohttp_server) @pytest.mark.asyncio -async def test_aiohttp_file_upload_without_session( - event_loop, aiohttp_server, run_sync_test -): +async def test_aiohttp_file_upload_without_session(aiohttp_server, 
run_sync_test): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -778,7 +772,7 @@ def test_code(): assert success - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) # This is a sample binary file content containing all possible byte values @@ -813,7 +807,7 @@ async def binary_upload_handler(request): @pytest.mark.asyncio -async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): +async def test_aiohttp_binary_file_upload(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -848,7 +842,7 @@ async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): @pytest.mark.asyncio -async def test_aiohttp_stream_reader_upload(event_loop, aiohttp_server): +async def test_aiohttp_stream_reader_upload(aiohttp_server): from aiohttp import web, ClientSession from gql.transport.aiohttp import AIOHTTPTransport @@ -885,7 +879,7 @@ async def binary_data_handler(request): @pytest.mark.asyncio -async def test_aiohttp_async_generator_upload(event_loop, aiohttp_server): +async def test_aiohttp_async_generator_upload(aiohttp_server): import aiofiles from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -948,7 +942,7 @@ async def file_sender(file_name): @pytest.mark.asyncio -async def test_aiohttp_file_upload_two_files(event_loop, aiohttp_server): +async def test_aiohttp_file_upload_two_files(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1039,7 +1033,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_file_upload_list_of_two_files(event_loop, aiohttp_server): +async def test_aiohttp_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1111,7 +1105,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_using_cli(event_loop, aiohttp_server, monkeypatch, capsys): 
+async def test_aiohttp_using_cli(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1148,7 +1142,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.script_launch_mode("subprocess") async def test_aiohttp_using_cli_ep( - event_loop, aiohttp_server, monkeypatch, script_runner, run_sync_test + aiohttp_server, monkeypatch, script_runner, run_sync_test ): from aiohttp import web @@ -1181,13 +1175,11 @@ def test_code(): assert received_answer == expected_answer - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio -async def test_aiohttp_using_cli_invalid_param( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_aiohttp_using_cli_invalid_param(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1221,9 +1213,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_using_cli_invalid_query( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_aiohttp_using_cli_invalid_query(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1261,7 +1251,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_with_extensions(event_loop, aiohttp_server): +async def test_aiohttp_query_with_extensions(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1290,9 +1280,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_aiohttp_query_https( - event_loop, ssl_aiohttp_server, ssl_close_timeout, verify_https -): +async def test_aiohttp_query_https(ssl_aiohttp_server, ssl_close_timeout, verify_https): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1339,7 +1327,7 @@ async def handler(request): 
@pytest.mark.skip(reason="We will change the default to fix this in a future version") @pytest.mark.asyncio -async def test_aiohttp_query_https_self_cert_fail(event_loop, ssl_aiohttp_server): +async def test_aiohttp_query_https_self_cert_fail(ssl_aiohttp_server): """By default, we should verify the ssl certificate""" from aiohttp.client_exceptions import ClientConnectorCertificateError from aiohttp import web @@ -1372,7 +1360,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_https_self_cert_warn(event_loop, ssl_aiohttp_server): +async def test_aiohttp_query_https_self_cert_warn(ssl_aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1397,7 +1385,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server): +async def test_aiohttp_error_fetching_schema(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1440,7 +1428,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_reconnecting_session(event_loop, aiohttp_server): +async def test_aiohttp_reconnecting_session(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1478,9 +1466,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("retries", [False, lambda e: e]) -async def test_aiohttp_reconnecting_session_retries( - event_loop, aiohttp_server, retries -): +async def test_aiohttp_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1512,7 +1498,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_reconnecting_session_start_connecting_task_twice( - event_loop, aiohttp_server, caplog + aiohttp_server, caplog ): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1546,7 +1532,7 @@ async def handler(request): 
@pytest.mark.asyncio -async def test_aiohttp_json_serializer(event_loop, aiohttp_server, caplog): +async def test_aiohttp_json_serializer(aiohttp_server, caplog): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1602,7 +1588,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_json_deserializer(event_loop, aiohttp_server): +async def test_aiohttp_json_deserializer(aiohttp_server): from aiohttp import web from decimal import Decimal from functools import partial @@ -1641,7 +1627,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): +async def test_aiohttp_connector_owner_false(aiohttp_server): from aiohttp import web, TCPConnector from gql.transport.aiohttp import AIOHTTPTransport diff --git a/tests/test_aiohttp_online.py b/tests/test_aiohttp_online.py index 39b8a9d2..7cacd921 100644 --- a/tests/test_aiohttp_online.py +++ b/tests/test_aiohttp_online.py @@ -11,7 +11,7 @@ @pytest.mark.aiohttp @pytest.mark.online @pytest.mark.asyncio -async def test_aiohttp_simple_query(event_loop): +async def test_aiohttp_simple_query(): from gql.transport.aiohttp import AIOHTTPTransport @@ -56,7 +56,7 @@ async def test_aiohttp_simple_query(event_loop): @pytest.mark.aiohttp @pytest.mark.online @pytest.mark.asyncio -async def test_aiohttp_invalid_query(event_loop): +async def test_aiohttp_invalid_query(): from gql.transport.aiohttp import AIOHTTPTransport @@ -85,7 +85,7 @@ async def test_aiohttp_invalid_query(event_loop): @pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.asyncio -async def test_aiohttp_two_queries_in_parallel_using_two_tasks(event_loop): +async def test_aiohttp_two_queries_in_parallel_using_two_tasks(): from gql.transport.aiohttp import AIOHTTPTransport diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 
81c79ba7..801af6b9 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -40,9 +40,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_aiohttp_websocket_invalid_query( - event_loop, aiohttp_client_and_server, query_str -): +async def test_aiohttp_websocket_invalid_query(aiohttp_client_and_server, query_str): session, server = aiohttp_client_and_server @@ -82,7 +80,7 @@ async def server_invalid_subscription(ws): @pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) async def test_aiohttp_websocket_invalid_subscription( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -115,9 +113,7 @@ async def server_no_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_aiohttp_websocket_server_does_not_send_ack( - event_loop, server, query_str -): +async def test_aiohttp_websocket_server_does_not_send_ack(server, query_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -141,7 +137,7 @@ async def server_connection_error(ws): @pytest.mark.parametrize("server", [server_connection_error], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_aiohttp_websocket_sending_invalid_data( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -171,7 +167,7 @@ async def server_invalid_payload(ws): @pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) @pytest.mark.parametrize("query_str", 
[invalid_query_str]) async def test_aiohttp_websocket_sending_invalid_payload( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -241,9 +237,7 @@ async def monkey_patch_send_query( ], indirect=True, ) -async def test_aiohttp_websocket_transport_protocol_errors( - event_loop, aiohttp_client_and_server -): +async def test_aiohttp_websocket_transport_protocol_errors(aiohttp_client_and_server): session, server = aiohttp_client_and_server @@ -261,7 +255,7 @@ async def server_without_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_without_ack], indirect=True) -async def test_aiohttp_websocket_server_does_not_ack(event_loop, server): +async def test_aiohttp_websocket_server_does_not_ack(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -280,7 +274,7 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) -async def test_aiohttp_websocket_server_closing_directly(event_loop, server): +async def test_aiohttp_websocket_server_closing_directly(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -301,9 +295,7 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) -async def test_aiohttp_websocket_server_closing_after_ack( - event_loop, aiohttp_client_and_server -): +async def test_aiohttp_websocket_server_closing_after_ack(aiohttp_client_and_server): session, server = aiohttp_client_and_server @@ -325,9 +317,7 @@ async def server_sending_invalid_query_errors(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_aiohttp_websocket_server_sending_invalid_query_errors( - event_loop, server -): +async def 
test_aiohttp_websocket_server_sending_invalid_query_errors(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -342,7 +332,7 @@ async def test_aiohttp_websocket_server_sending_invalid_query_errors( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_aiohttp_websocket_non_regression_bug_105(event_loop, server): +async def test_aiohttp_websocket_non_regression_bug_105(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport # This test will check a fix to a race condition which happens if the user is trying @@ -373,9 +363,7 @@ async def client_connect(client): @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) -async def test_aiohttp_websocket_using_cli_invalid_query( - event_loop, server, monkeypatch, capsys -): +async def test_aiohttp_websocket_using_cli_invalid_query(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index f87682d2..a7548cce 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -40,7 +40,7 @@ @pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_aiohttp_websocket_graphqlws_invalid_query( - event_loop, client_and_aiohttp_websocket_graphql_server, query_str + client_and_aiohttp_websocket_graphql_server, query_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -83,7 +83,7 @@ async def server_invalid_subscription(ws): ) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) async def test_aiohttp_websocket_graphqlws_invalid_subscription( - 
event_loop, client_and_aiohttp_websocket_graphql_server, query_str + client_and_aiohttp_websocket_graphql_server, query_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -111,7 +111,7 @@ async def server_no_ack(ws): @pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( - event_loop, graphqlws_server, query_str + graphqlws_server, query_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -143,7 +143,7 @@ async def server_invalid_query(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) async def test_aiohttp_websocket_graphqlws_sending_invalid_query( - event_loop, client_and_aiohttp_websocket_graphql_server + client_and_aiohttp_websocket_graphql_server, ): session, server = client_and_aiohttp_websocket_graphql_server @@ -197,7 +197,7 @@ async def test_aiohttp_websocket_graphqlws_sending_invalid_query( indirect=True, ) async def test_aiohttp_websocket_graphqlws_transport_protocol_errors( - event_loop, client_and_aiohttp_websocket_graphql_server + client_and_aiohttp_websocket_graphql_server, ): session, server = client_and_aiohttp_websocket_graphql_server @@ -216,9 +216,7 @@ async def server_without_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) -async def test_aiohttp_websocket_graphqlws_server_does_not_ack( - event_loop, graphqlws_server -): +async def test_aiohttp_websocket_graphqlws_server_does_not_ack(graphqlws_server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -237,9 +235,7 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) -async def 
test_aiohttp_websocket_graphqlws_server_closing_directly( - event_loop, graphqlws_server -): +async def test_aiohttp_websocket_graphqlws_server_closing_directly(graphqlws_server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -261,7 +257,7 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( - event_loop, client_and_aiohttp_websocket_graphql_server + client_and_aiohttp_websocket_graphql_server, ): session, _ = client_and_aiohttp_websocket_graphql_server diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index f380948c..8863ead9 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -229,7 +229,7 @@ async def server_countdown_disconnect(ws): @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -252,7 +252,7 @@ async def test_aiohttp_websocket_graphqlws_subscription( @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_break( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -283,7 +283,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( 
@pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -322,7 +322,7 @@ async def cancel_task_coro(): @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_close_transport( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -387,7 +387,7 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, _ = client_and_aiohttp_websocket_graphql_server @@ -408,7 +408,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -438,7 +438,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( ) 
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -468,7 +468,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -502,7 +502,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_time ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -537,7 +537,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_time ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -573,7 +573,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_ok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets 
import AIOHTTPWebsocketsTransport @@ -606,7 +606,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_nok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_manual_pings_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -648,7 +648,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_manual_pings_with_payloa ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_manual_pong_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -781,7 +781,7 @@ def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_running_in_thread( - event_loop, graphqlws_server, subscription_str, run_sync_test + graphqlws_server, subscription_str, run_sync_test ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -805,7 +805,7 @@ def test_code(): assert count == -1 - await run_sync_test(event_loop, graphqlws_server, test_code) + await run_sync_test(graphqlws_server, test_code) @pytest.mark.asyncio @@ -815,7 +815,7 @@ def test_code(): @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( - event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe + graphqlws_server, subscription_str, execute_instead_of_subscribe ): from gql.transport.exceptions import TransportClosed diff --git 
a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 8786d58d..deb425f7 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -51,9 +51,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_starting_client_in_context_manager( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_starting_client_in_context_manager(aiohttp_ws_server): server = aiohttp_ws_server from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -101,7 +99,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_aiohttp_websocket_using_ssl_connection( - event_loop, ws_ssl_server, ssl_close_timeout, verify_https + ws_ssl_server, ssl_close_timeout, verify_https ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -152,7 +150,7 @@ async def test_aiohttp_websocket_using_ssl_connection( @pytest.mark.parametrize("ssl_close_timeout", [10]) @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( - event_loop, ws_ssl_server, ssl_close_timeout, verify_https + ws_ssl_server, ssl_close_timeout, verify_https ): from aiohttp.client_exceptions import ClientConnectorCertificateError @@ -200,9 +198,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( @pytest.mark.websockets @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_aiohttp_websocket_simple_query( - event_loop, aiohttp_client_and_server, query_str -): +async def test_aiohttp_websocket_simple_query(aiohttp_client_and_server, query_str): session, server = aiohttp_client_and_server @@ 
-225,7 +221,7 @@ async def test_aiohttp_websocket_simple_query( ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_two_queries_in_series( - event_loop, aiohttp_client_and_aiohttp_ws_server, query_str + aiohttp_client_and_aiohttp_ws_server, query_str ): session, server = aiohttp_client_and_aiohttp_ws_server @@ -262,7 +258,7 @@ async def server1_two_queries_in_parallel(ws): @pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_two_queries_in_parallel( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -310,7 +306,7 @@ async def server_closing_while_we_are_doing_something_else(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_server_closing_after_first_query( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -342,7 +338,7 @@ async def test_aiohttp_websocket_server_closing_after_first_query( ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_ignore_invalid_id( - event_loop, aiohttp_client_and_aiohttp_ws_server, query_str + aiohttp_client_and_aiohttp_ws_server, query_str ): session, server = aiohttp_client_and_aiohttp_ws_server @@ -378,9 +374,7 @@ async def assert_client_is_working(session): @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_multiple_connections_in_series( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_multiple_connections_in_series(aiohttp_ws_server): server = aiohttp_ws_server @@ -406,9 +400,7 @@ async def test_aiohttp_websocket_multiple_connections_in_series( @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], 
indirect=True) -async def test_aiohttp_websocket_multiple_connections_in_parallel( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_multiple_connections_in_parallel(aiohttp_ws_server): server = aiohttp_ws_server @@ -431,7 +423,7 @@ async def task_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) async def test_aiohttp_websocket_trying_to_connect_to_already_connected_transport( - event_loop, aiohttp_ws_server + aiohttp_ws_server, ): server = aiohttp_ws_server @@ -482,7 +474,7 @@ async def server_with_authentication_in_connection_init_payload(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_connect_success_with_authentication_in_connection_init( - event_loop, server, query_str + server, query_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -518,7 +510,7 @@ async def test_aiohttp_websocket_connect_success_with_authentication_in_connecti @pytest.mark.parametrize("query_str", [query1_str]) @pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) async def test_aiohttp_websocket_connect_failed_with_authentication_in_connection_init( - event_loop, server, query_str, init_payload + server, query_str, init_payload ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -584,9 +576,7 @@ def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_add_extra_parameters_to_connect( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_add_extra_parameters_to_connect(aiohttp_ws_server): server = aiohttp_ws_server @@ -628,7 +618,7 @@ async def server_sending_keep_alive_before_connection_ack(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_non_regression_bug_108( - event_loop, 
aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): # This test will check that we now ignore keepalive message @@ -653,7 +643,7 @@ async def test_aiohttp_websocket_non_regression_bug_108( @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) @pytest.mark.parametrize("transport_arg", [[], ["--transport=aiohttp_websockets"]]) async def test_aiohttp_websocket_using_cli( - event_loop, aiohttp_ws_server, transport_arg, monkeypatch, capsys + aiohttp_ws_server, transport_arg, monkeypatch, capsys ): """ @@ -717,7 +707,7 @@ async def test_aiohttp_websocket_using_cli( ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_simple_query_with_extensions( - event_loop, aiohttp_client_and_aiohttp_ws_server, query_str + aiohttp_client_and_aiohttp_ws_server, query_str ): session, server = aiohttp_client_and_aiohttp_ws_server @@ -731,7 +721,7 @@ async def test_aiohttp_websocket_simple_query_with_extensions( @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_server): +async def test_aiohttp_websocket_connector_owner_false(aiohttp_ws_server): server = aiohttp_ws_server diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 4ea11a7b..5beb023e 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -194,7 +194,7 @@ async def keepalive_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -217,7 +217,7 @@ async def test_aiohttp_websocket_subscription( 
@pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_get_execution_result( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -242,7 +242,7 @@ async def test_aiohttp_websocket_subscription_get_execution_result( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_break( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -273,7 +273,7 @@ async def test_aiohttp_websocket_subscription_break( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_task_cancel( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -312,7 +312,7 @@ async def cancel_task_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_close_transport( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, _ = aiohttp_client_and_server @@ -377,7 +377,7 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_server_connection_closed( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = 
aiohttp_client_and_server @@ -401,7 +401,7 @@ async def test_aiohttp_websocket_subscription_server_connection_closed( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_slow_consumer( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -426,7 +426,7 @@ async def test_aiohttp_websocket_subscription_slow_consumer( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_operation_name( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -457,7 +457,7 @@ async def test_aiohttp_websocket_subscription_with_operation_name( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_keepalive( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -480,7 +480,7 @@ async def test_aiohttp_websocket_subscription_with_keepalive( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( - event_loop, server, subscription_str + server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -510,7 +510,7 @@ async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( @pytest.mark.parametrize("server", [server_countdown], indirect=True) 
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( - event_loop, server, subscription_str + server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -688,7 +688,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_running_in_thread( - event_loop, server, subscription_str, run_sync_test + server, subscription_str, run_sync_test ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -712,7 +712,7 @@ def test_code(): assert count == -1 - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio @@ -726,9 +726,7 @@ def test_code(): {"schema": StarWarsTypeDef}, ], ) -async def test_async_aiohttp_client_validation( - event_loop, server, subscription_str, client_params -): +async def test_async_aiohttp_client_validation(server, subscription_str, client_params): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -763,7 +761,7 @@ async def test_async_aiohttp_client_validation( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_subscribe_on_closing_transport(event_loop, server, subscription_str): +async def test_subscribe_on_closing_transport(server, subscription_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -786,7 +784,7 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", 
[countdown_subscription_str]) -async def test_subscribe_on_null_transport(event_loop, server, subscription_str): +async def test_subscribe_on_null_transport(server, subscription_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index ca3a3fcb..2a6c9ca7 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -8,9 +8,7 @@ @pytest.mark.asyncio @pytest.mark.aiohttp @pytest.mark.botocore -async def test_appsync_iam_mutation( - event_loop, aiohttp_server, fake_credentials_factory -): +async def test_appsync_iam_mutation(aiohttp_server, fake_credentials_factory): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncIAMAuthentication diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 88bae8b6..7aa96292 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -404,7 +404,7 @@ async def default_transport_test(transport): @pytest.mark.asyncio @pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) -async def test_appsync_subscription_api_key(event_loop, server): +async def test_appsync_subscription_api_key(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -424,7 +424,7 @@ async def test_appsync_subscription_api_key(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_iam_with_token(event_loop, server): +async def test_appsync_subscription_iam_with_token(server): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -451,7 +451,7 @@ async def test_appsync_subscription_iam_with_token(event_loop, 
server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_iam_without_token(event_loop, server): +async def test_appsync_subscription_iam_without_token(server): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -477,7 +477,7 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_execute_method_not_allowed(event_loop, server): +async def test_appsync_execute_method_not_allowed(server): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -524,7 +524,7 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore -async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): +async def test_appsync_fetch_schema_from_transport_not_allowed(): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -552,7 +552,7 @@ async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): @pytest.mark.asyncio @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_api_key_unauthorized(event_loop, server): +async def test_appsync_subscription_api_key_unauthorized(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -577,7 +577,7 @@ async def test_appsync_subscription_api_key_unauthorized(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) 
-async def test_appsync_subscription_iam_not_allowed(event_loop, server): +async def test_appsync_subscription_iam_not_allowed(server): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -616,9 +616,7 @@ async def test_appsync_subscription_iam_not_allowed(event_loop, server): @pytest.mark.parametrize( "server", [realtime_appsync_server_not_json_answer], indirect=True ) -async def test_appsync_subscription_server_sending_a_not_json_answer( - event_loop, server -): +async def test_appsync_subscription_server_sending_a_not_json_answer(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -644,9 +642,7 @@ async def test_appsync_subscription_server_sending_a_not_json_answer( @pytest.mark.parametrize( "server", [realtime_appsync_server_error_without_id], indirect=True ) -async def test_appsync_subscription_server_sending_an_error_without_an_id( - event_loop, server -): +async def test_appsync_subscription_server_sending_an_error_without_an_id(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -670,9 +666,7 @@ async def test_appsync_subscription_server_sending_an_error_without_an_id( @pytest.mark.asyncio @pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) -async def test_appsync_subscription_variable_values_and_operation_name( - event_loop, server -): +async def test_appsync_subscription_variable_values_and_operation_name(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index acfabe0e..be214134 100644 --- a/tests/test_async_client_validation.py +++ 
b/tests/test_async_client_validation.py @@ -85,9 +85,7 @@ async def server_starwars(ws): {"schema": StarWarsTypeDef}, ], ) -async def test_async_client_validation( - event_loop, server, subscription_str, client_params -): +async def test_async_client_validation(server, subscription_str, client_params): from gql.transport.websockets import WebsocketsTransport @@ -133,7 +131,7 @@ async def test_async_client_validation( ], ) async def test_async_client_validation_invalid_query( - event_loop, server, subscription_str, client_params + server, subscription_str, client_params ): from gql.transport.websockets import WebsocketsTransport @@ -166,7 +164,7 @@ async def test_async_client_validation_invalid_query( [{"schema": StarWarsSchema, "introspection": StarWarsIntrospection}], ) async def test_async_client_validation_different_schemas_parameters_forbidden( - event_loop, server, subscription_str, client_params + server, subscription_str, client_params ): from gql.transport.websockets import WebsocketsTransport @@ -192,7 +190,7 @@ async def test_async_client_validation_different_schemas_parameters_forbidden( @pytest.mark.asyncio @pytest.mark.parametrize("server", [hero_server_answers], indirect=True) async def test_async_client_validation_fetch_schema_from_server_valid_query( - event_loop, client_and_server + client_and_server, ): session, server = client_and_server client = session.client @@ -230,7 +228,7 @@ async def test_async_client_validation_fetch_schema_from_server_valid_query( @pytest.mark.asyncio @pytest.mark.parametrize("server", [hero_server_answers], indirect=True) async def test_async_client_validation_fetch_schema_from_server_invalid_query( - event_loop, client_and_server + client_and_server, ): session, server = client_and_server @@ -256,7 +254,7 @@ async def test_async_client_validation_fetch_schema_from_server_invalid_query( @pytest.mark.asyncio @pytest.mark.parametrize("server", [hero_server_answers], indirect=True) async def 
test_async_client_validation_fetch_schema_from_server_with_client_argument( - event_loop, server + server, ): from gql.transport.websockets import WebsocketsTransport diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index 3b6bd901..2e3514d1 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -39,9 +39,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_graphqlws_invalid_query( - event_loop, client_and_graphqlws_server, query_str -): +async def test_graphqlws_invalid_query(client_and_graphqlws_server, query_str): session, server = client_and_graphqlws_server @@ -82,9 +80,7 @@ async def server_invalid_subscription(ws): "graphqlws_server", [server_invalid_subscription], indirect=True ) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) -async def test_graphqlws_invalid_subscription( - event_loop, client_and_graphqlws_server, query_str -): +async def test_graphqlws_invalid_subscription(client_and_graphqlws_server, query_str): session, server = client_and_graphqlws_server @@ -110,9 +106,7 @@ async def server_no_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_graphqlws_server_does_not_send_ack( - event_loop, graphqlws_server, query_str -): +async def test_graphqlws_server_does_not_send_ack(graphqlws_server, query_str): from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -142,7 +136,7 @@ async def server_invalid_query(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) -async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_server): +async def 
test_graphqlws_sending_invalid_query(client_and_graphqlws_server): session, server = client_and_graphqlws_server @@ -194,9 +188,7 @@ async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_ ], indirect=True, ) -async def test_graphqlws_transport_protocol_errors( - event_loop, client_and_graphqlws_server -): +async def test_graphqlws_transport_protocol_errors(client_and_graphqlws_server): session, server = client_and_graphqlws_server @@ -214,7 +206,7 @@ async def server_without_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) -async def test_graphqlws_server_does_not_ack(event_loop, graphqlws_server): +async def test_graphqlws_server_does_not_ack(graphqlws_server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -233,7 +225,7 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) -async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): +async def test_graphqlws_server_closing_directly(graphqlws_server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -253,9 +245,7 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) -async def test_graphqlws_server_closing_after_ack( - event_loop, client_and_graphqlws_server -): +async def test_graphqlws_server_closing_after_ack(client_and_graphqlws_server): session, server = client_and_graphqlws_server diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index d4bed34f..2735fbb0 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -228,9 +228,7 @@ async def 
server_countdown_disconnect(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_graphqlws_subscription( - event_loop, client_and_graphqlws_server, subscription_str -): +async def test_graphqlws_subscription(client_and_graphqlws_server, subscription_str): session, server = client_and_graphqlws_server @@ -252,7 +250,7 @@ async def test_graphqlws_subscription( @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_break( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -283,7 +281,7 @@ async def test_graphqlws_subscription_break( @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_task_cancel( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -322,7 +320,7 @@ async def cancel_task_coro(): @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_close_transport( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -387,7 +385,7 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_server_connection_closed( - event_loop, client_and_graphqlws_server, subscription_str + 
client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -410,7 +408,7 @@ async def test_graphqlws_subscription_server_connection_closed( @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_operation_name( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -440,7 +438,7 @@ async def test_graphqlws_subscription_with_operation_name( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_keepalive( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -470,7 +468,7 @@ async def test_graphqlws_subscription_with_keepalive( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_keepalive_with_timeout_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -502,7 +500,7 @@ async def test_graphqlws_subscription_with_keepalive_with_timeout_ok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_keepalive_with_timeout_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -535,7 +533,7 @@ async def test_graphqlws_subscription_with_keepalive_with_timeout_nok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_ping_interval_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from 
gql.transport.websockets import WebsocketsTransport @@ -571,7 +569,7 @@ async def test_graphqlws_subscription_with_ping_interval_ok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_ping_interval_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -604,7 +602,7 @@ async def test_graphqlws_subscription_with_ping_interval_nok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_manual_pings_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -646,7 +644,7 @@ async def test_graphqlws_subscription_manual_pings_with_payload( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_manual_pong_answers_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -777,7 +775,7 @@ def test_graphqlws_subscription_sync_graceful_shutdown( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_running_in_thread( - event_loop, graphqlws_server, subscription_str, run_sync_test + graphqlws_server, subscription_str, run_sync_test ): from gql.transport.websockets import WebsocketsTransport @@ -801,7 +799,7 @@ def test_code(): assert count == -1 - await run_sync_test(event_loop, graphqlws_server, test_code) + await run_sync_test(graphqlws_server, test_code) @pytest.mark.asyncio @@ -811,7 +809,7 @@ def test_code(): @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) async def test_graphqlws_subscription_reconnecting_session( - 
event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe + graphqlws_server, subscription_str, execute_instead_of_subscribe ): from gql.transport.websockets import WebsocketsTransport diff --git a/tests/test_http_async_sync.py b/tests/test_http_async_sync.py index 19b6cfa2..45efd7f5 100644 --- a/tests/test_http_async_sync.py +++ b/tests/test_http_async_sync.py @@ -7,7 +7,7 @@ @pytest.mark.online @pytest.mark.asyncio @pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) -async def test_async_client_async_transport(event_loop, fetch_schema_from_transport): +async def test_async_client_async_transport(fetch_schema_from_transport): from gql.transport.aiohttp import AIOHTTPTransport @@ -51,7 +51,7 @@ async def test_async_client_async_transport(event_loop, fetch_schema_from_transp @pytest.mark.online @pytest.mark.asyncio @pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) -async def test_async_client_sync_transport(event_loop, fetch_schema_from_transport): +async def test_async_client_sync_transport(fetch_schema_from_transport): from gql.transport.requests import RequestsHTTPTransport diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 8ef57a84..c15872d7 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -36,7 +36,7 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_query(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -74,15 +74,13 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_httpx_query_https( - event_loop, ssl_aiohttp_server, run_sync_test, 
verify_https -): +async def test_httpx_query_https(ssl_aiohttp_server, run_sync_test, verify_https): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -134,14 +132,14 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_httpx_query_https_self_cert_fail( - event_loop, ssl_aiohttp_server, run_sync_test, verify_https + ssl_aiohttp_server, run_sync_test, verify_https ): """By default, we should verify the ssl certificate""" from aiohttp import web @@ -186,12 +184,12 @@ def test_code(): assert expected_error in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cookies(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_cookies(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -223,12 +221,12 @@ def test_code(): assert africa["code"] == "AF" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_401(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -258,12 +256,12 @@ def test_code(): assert "Client error '401 Unauthorized'" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_429(event_loop, aiohttp_server, run_sync_test): +async def 
test_httpx_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -312,7 +310,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_500(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -336,7 +334,7 @@ def test_code(): with pytest.raises(TransportServerError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_error_answer = '{"errors": ["Error 1", "Error 2"]}' @@ -344,7 +342,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_code(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -369,7 +367,7 @@ def test_code(): with pytest.raises(TransportQueryError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) invalid_protocol_responses = [ @@ -382,9 +380,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("response", invalid_protocol_responses) -async def test_httpx_invalid_protocol( - event_loop, aiohttp_server, response, run_sync_test -): +async def test_httpx_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -407,12 +403,12 @@ def test_code(): with pytest.raises(TransportProtocolError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_cannot_connect_twice(aiohttp_server, 
run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -433,14 +429,12 @@ def test_code(): with pytest.raises(TransportAlreadyConnected): session.transport.connect() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_execute_if_not_connected( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -461,7 +455,7 @@ def test_code(): with pytest.raises(TransportClosed): transport.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_answer_with_extensions = ( @@ -477,7 +471,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_with_extensions(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -503,7 +497,7 @@ def test_code(): assert execution_result.extensions["key1"] == "val1" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_server_answer = '{"data":{"success":true}}' @@ -532,7 +526,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_file_upload(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -587,14 +581,12 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_with_content_type( - event_loop, aiohttp_server, run_sync_test -): 
+async def test_httpx_file_upload_with_content_type(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -655,14 +647,12 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_additional_headers( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -719,12 +709,12 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_binary_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_binary_file_upload(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -785,7 +775,7 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_mutation_2_operations = ( @@ -797,7 +787,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_two_files(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -883,7 +873,7 @@ def test_code(): f1.close() f2.close() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_mutation_3_operations = ( @@ -895,9 +885,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_list_of_two_files( - event_loop, aiohttp_server, run_sync_test -): +async def 
test_httpx_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -981,12 +969,12 @@ def test_code(): f1.close() f2.close() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_fetching_schema(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXTransport @@ -1028,4 +1016,4 @@ def test_code(): assert expected_error in str(exc_info.value) assert transport.client is None - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 47744538..44764ea4 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -46,7 +46,7 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query(event_loop, aiohttp_server): +async def test_httpx_query(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -86,7 +86,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_ignore_backend_content_type(event_loop, aiohttp_server): +async def test_httpx_ignore_backend_content_type(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -116,7 +116,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cookies(event_loop, aiohttp_server): +async def test_httpx_cookies(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -150,7 +150,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_401(event_loop, aiohttp_server): +async def test_httpx_error_code_401(aiohttp_server): from aiohttp import 
web from gql.transport.httpx import HTTPXAsyncTransport @@ -182,7 +182,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_429(event_loop, aiohttp_server): +async def test_httpx_error_code_429(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -230,7 +230,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_500(event_loop, aiohttp_server): +async def test_httpx_error_code_500(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -266,7 +266,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("query_error", transport_query_error_responses) -async def test_httpx_error_code(event_loop, aiohttp_server, query_error): +async def test_httpx_error_code(aiohttp_server, query_error): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -322,7 +322,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("param", invalid_protocol_responses) -async def test_httpx_invalid_protocol(event_loop, aiohttp_server, param): +async def test_httpx_invalid_protocol(aiohttp_server, param): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -351,7 +351,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_subscribe_not_supported(event_loop, aiohttp_server): +async def test_httpx_subscribe_not_supported(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -377,7 +377,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server): +async def test_httpx_cannot_connect_twice(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -400,7 +400,7 @@ async 
def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_execute_if_not_connected(event_loop, aiohttp_server): +async def test_httpx_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -423,7 +423,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_extra_args(event_loop, aiohttp_server): +async def test_httpx_extra_args(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport import httpx @@ -468,7 +468,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_variable_values(event_loop, aiohttp_server): +async def test_httpx_query_variable_values(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -501,7 +501,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_variable_values_fix_issue_292(event_loop, aiohttp_server): +async def test_httpx_query_variable_values_fix_issue_292(aiohttp_server): """Allow to specify variable_values without keyword. 
See https://github.com/graphql-python/gql/issues/292""" @@ -536,9 +536,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_execute_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -560,14 +558,12 @@ def test_code(): client.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_subscribe_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -598,7 +594,7 @@ def test_code(): for result in client.subscribe(query): pass - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_server_answer = '{"data":{"success":true}}' @@ -654,7 +650,7 @@ async def single_upload_handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload(event_loop, aiohttp_server): +async def test_httpx_file_upload(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -690,9 +686,7 @@ async def test_httpx_file_upload(event_loop, aiohttp_server): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_without_session( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -725,7 +719,7 @@ def test_code(): assert success - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) # This is a sample binary file content containing 
all possible byte values @@ -761,7 +755,7 @@ async def binary_upload_handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_binary_file_upload(event_loop, aiohttp_server): +async def test_httpx_binary_file_upload(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -819,7 +813,7 @@ async def test_httpx_binary_file_upload(event_loop, aiohttp_server): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_two_files(event_loop, aiohttp_server): +async def test_httpx_file_upload_two_files(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -911,7 +905,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_list_of_two_files(event_loop, aiohttp_server): +async def test_httpx_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -984,7 +978,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_using_cli(event_loop, aiohttp_server, monkeypatch, capsys): +async def test_httpx_using_cli(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1022,7 +1016,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.script_launch_mode("subprocess") async def test_httpx_using_cli_ep( - event_loop, aiohttp_server, monkeypatch, script_runner, run_sync_test + aiohttp_server, monkeypatch, script_runner, run_sync_test ): from aiohttp import web @@ -1055,14 +1049,12 @@ def test_code(): assert received_answer == expected_answer - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_using_cli_invalid_param( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_httpx_using_cli_invalid_param(aiohttp_server, monkeypatch, capsys): 
from aiohttp import web async def handler(request): @@ -1097,9 +1089,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_using_cli_invalid_query( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_httpx_using_cli_invalid_query(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1138,7 +1128,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_with_extensions(event_loop, aiohttp_server): +async def test_httpx_query_with_extensions(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1167,7 +1157,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_httpx_query_https(event_loop, ssl_aiohttp_server, verify_https): +async def test_httpx_query_https(ssl_aiohttp_server, verify_https): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1210,9 +1200,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) -async def test_httpx_query_https_self_cert_fail( - event_loop, ssl_aiohttp_server, verify_https -): +async def test_httpx_query_https_self_cert_fail(ssl_aiohttp_server, verify_https): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport from httpx import ConnectError @@ -1250,7 +1238,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_fetching_schema(event_loop, aiohttp_server): +async def test_httpx_error_fetching_schema(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1294,7 +1282,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_reconnecting_session(event_loop, aiohttp_server): 
+async def test_httpx_reconnecting_session(aiohttp_server): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1333,7 +1321,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("retries", [False, lambda e: e]) -async def test_httpx_reconnecting_session_retries(event_loop, aiohttp_server, retries): +async def test_httpx_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1366,7 +1354,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_reconnecting_session_start_connecting_task_twice( - event_loop, aiohttp_server, caplog + aiohttp_server, caplog ): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1401,7 +1389,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_json_serializer(event_loop, aiohttp_server, caplog): +async def test_httpx_json_serializer(aiohttp_server, caplog): from aiohttp import web from gql.transport.httpx import HTTPXAsyncTransport @@ -1458,7 +1446,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_json_deserializer(event_loop, aiohttp_server): +async def test_httpx_json_deserializer(aiohttp_server): from aiohttp import web from decimal import Decimal from functools import partial diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py index 23d28dcc..3b08fa18 100644 --- a/tests/test_httpx_online.py +++ b/tests/test_httpx_online.py @@ -11,7 +11,7 @@ @pytest.mark.httpx @pytest.mark.online @pytest.mark.asyncio -async def test_httpx_simple_query(event_loop): +async def test_httpx_simple_query(): from gql.transport.httpx import HTTPXAsyncTransport @@ -56,7 +56,7 @@ async def test_httpx_simple_query(event_loop): @pytest.mark.httpx @pytest.mark.online @pytest.mark.asyncio -async def test_httpx_invalid_query(event_loop): 
+async def test_httpx_invalid_query(): from gql.transport.httpx import HTTPXAsyncTransport @@ -85,7 +85,7 @@ async def test_httpx_invalid_query(event_loop): @pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.asyncio -async def test_httpx_two_queries_in_parallel_using_two_tasks(event_loop): +async def test_httpx_two_queries_in_parallel_using_two_tasks(): from gql.transport.httpx import HTTPXAsyncTransport diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index c042ce01..2a312d71 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -161,7 +161,7 @@ async def no_connection_ack_phoenix_server(ws): indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_query_protocol_error(event_loop, server, query_str): +async def test_phoenix_channel_query_protocol_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -191,7 +191,7 @@ async def test_phoenix_channel_query_protocol_error(event_loop, server, query_st indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_query_error(event_loop, server, query_str): +async def test_phoenix_channel_query_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -407,9 +407,7 @@ async def phoenix_server(ws): indirect=True, ) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_subscription_protocol_error( - event_loop, server, query_str -): +async def test_phoenix_channel_subscription_protocol_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -439,7 +437,7 @@ async def test_phoenix_channel_subscription_protocol_error( indirect=True, ) 
@pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_server_error(event_loop, server, query_str): +async def test_phoenix_channel_server_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -468,7 +466,7 @@ async def test_phoenix_channel_server_error(event_loop, server, query_str): indirect=True, ) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): +async def test_phoenix_channel_unsubscribe_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -498,7 +496,7 @@ async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): indirect=True, ) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_unsubscribe_error_forcing(event_loop, server, query_str): +async def test_phoenix_channel_unsubscribe_error_forcing(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 56d28875..621f648e 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -52,7 +52,7 @@ async def query_server(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [query_server], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_query(event_loop, server, query_str): +async def test_phoenix_channel_query(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -75,7 +75,7 @@ async def test_phoenix_channel_query(event_loop, server, query_str): @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def 
test_phoenix_channel_query_ssl(event_loop, ws_ssl_server, query_str): +async def test_phoenix_channel_query_ssl(ws_ssl_server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -108,7 +108,7 @@ async def test_phoenix_channel_query_ssl(event_loop, ws_ssl_server, query_str): @pytest.mark.parametrize("query_str", [query1_str]) @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_phoenix_channel_query_ssl_self_cert_fail( - event_loop, ws_ssl_server, query_str, verify_https + ws_ssl_server, query_str, verify_https ): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -204,7 +204,7 @@ async def subscription_server(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [subscription_server], indirect=True) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_subscription(event_loop, server, query_str): +async def test_phoenix_channel_subscription(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 35ca665b..25ca0f0b 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -173,9 +173,7 @@ async def stopping_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("end_count", [0, 5]) -async def test_phoenix_channel_subscription( - event_loop, server, subscription_str, end_count -): +async def test_phoenix_channel_subscription(server, subscription_str, end_count): """Parameterized test. :param end_count: Target count at which the test will 'break' to unsubscribe. 
@@ -223,9 +221,7 @@ async def test_phoenix_channel_subscription( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_phoenix_channel_subscription_no_break( - event_loop, server, subscription_str -): +async def test_phoenix_channel_subscription_no_break(server, subscription_str): import logging from gql.transport.phoenix_channel_websockets import ( @@ -369,7 +365,7 @@ async def heartbeat_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [phoenix_heartbeat_server], indirect=True) @pytest.mark.parametrize("subscription_str", [heartbeat_subscription_str]) -async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): +async def test_phoenix_channel_heartbeat(server, subscription_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) diff --git a/tests/test_requests.py b/tests/test_requests.py index 95db0b3f..8f3b0b7a 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -40,7 +40,7 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query(event_loop, aiohttp_server, run_sync_test): +async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -78,15 +78,13 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_requests_query_https( - event_loop, ssl_aiohttp_server, run_sync_test, verify_https -): +async def test_requests_query_https(ssl_aiohttp_server, run_sync_test, verify_https): from aiohttp import web from gql.transport.requests import 
RequestsHTTPTransport import warnings @@ -142,14 +140,14 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_requests_query_https_self_cert_fail( - event_loop, ssl_aiohttp_server, run_sync_test, verify_https + ssl_aiohttp_server, run_sync_test, verify_https ): """By default, we should verify the ssl certificate""" from aiohttp import web @@ -192,12 +190,12 @@ def test_code(): assert expected_error in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): +async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -229,12 +227,12 @@ def test_code(): assert africa["code"] == "AF" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -264,12 +262,12 @@ def test_code(): assert "401 Client Error: Unauthorized" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import 
RequestsHTTPTransport @@ -318,7 +316,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -342,7 +340,7 @@ def test_code(): with pytest.raises(TransportServerError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_error_answer = '{"errors": ["Error 1", "Error 2"]}' @@ -350,7 +348,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -375,7 +373,7 @@ def test_code(): with pytest.raises(TransportQueryError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) invalid_protocol_responses = [ @@ -388,9 +386,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("response", invalid_protocol_responses) -async def test_requests_invalid_protocol( - event_loop, aiohttp_server, response, run_sync_test -): +async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -413,12 +409,12 @@ def test_code(): with pytest.raises(TransportProtocolError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): +async def test_requests_cannot_connect_twice(aiohttp_server, run_sync_test): from aiohttp import web from 
gql.transport.requests import RequestsHTTPTransport @@ -439,14 +435,12 @@ def test_code(): with pytest.raises(TransportAlreadyConnected): session.transport.connect() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cannot_execute_if_not_connected( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -467,7 +461,7 @@ def test_code(): with pytest.raises(TransportClosed): transport.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_answer_with_extensions = ( @@ -483,9 +477,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query_with_extensions( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -511,7 +503,7 @@ def test_code(): assert execution_result.extensions["key1"] == "val1" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_server_answer = '{"data":{"success":true}}' @@ -540,7 +532,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_requests_file_upload(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -595,14 +587,12 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload_with_content_type( - event_loop, aiohttp_server, 
run_sync_test -): +async def test_requests_file_upload_with_content_type(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -663,14 +653,12 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload_additional_headers( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -727,12 +715,12 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_binary_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_requests_binary_file_upload(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -793,7 +781,7 @@ def test_code(): assert execution_result.data["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_mutation_2_operations = ( @@ -805,9 +793,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload_two_files( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -893,7 +879,7 @@ def test_code(): f1.close() f2.close() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) file_upload_mutation_3_operations = ( @@ -905,9 +891,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def 
test_requests_file_upload_list_of_two_files( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -991,14 +975,12 @@ def test_code(): f1.close() f2.close() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_fetching_schema( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -1040,14 +1022,12 @@ def test_code(): assert expected_error in str(exc_info.value) assert transport.session is None - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_json_serializer( - event_loop, aiohttp_server, run_sync_test, caplog -): +async def test_requests_json_serializer(aiohttp_server, run_sync_test, caplog): import json from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -1091,7 +1071,7 @@ def test_code(): expected_log = '"query":"query getContinents' assert expected_log in caplog.text - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query_float_str = """ @@ -1107,7 +1087,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_test): +async def test_requests_json_deserializer(aiohttp_server, run_sync_test): import json from aiohttp import web from decimal import Decimal @@ -1146,4 +1126,4 @@ def test_code(): assert pi == Decimal("3.141592653589793238462643383279502884197") - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) diff --git 
a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 4d8bf27e..dbd3dfa5 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -48,7 +48,7 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query(event_loop, aiohttp_server, run_sync_test): +async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -86,14 +86,12 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query_auto_batch_enabled( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_query_auto_batch_enabled(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -134,13 +132,13 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_query_auto_batch_enabled_two_requests( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -194,12 +192,12 @@ def test_thread(): for thread in threads: thread.join() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): +async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -233,12 +231,12 @@ def test_code(): assert africa["code"] 
== "AF" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -268,13 +266,13 @@ def test_code(): assert "401 Client Error: Unauthorized" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_401_auto_batch_enabled( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -308,12 +306,12 @@ def test_code(): assert "401 Client Error: Unauthorized" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -362,7 +360,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -386,7 +384,7 @@ def test_code(): with pytest.raises(TransportServerError): session.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' @@ -394,7 +392,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def 
test_requests_error_code(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -419,7 +417,7 @@ def test_code(): with pytest.raises(TransportQueryError): session.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) invalid_protocol_responses = [ @@ -437,9 +435,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("response", invalid_protocol_responses) -async def test_requests_invalid_protocol( - event_loop, aiohttp_server, response, run_sync_test -): +async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -462,14 +458,12 @@ def test_code(): with pytest.raises(TransportProtocolError): session.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cannot_execute_if_not_connected( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -492,7 +486,7 @@ def test_code(): with pytest.raises(TransportClosed): transport.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_answer_with_extensions_list = ( @@ -508,9 +502,7 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query_with_extensions( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -537,7 +529,7 @@ def 
test_code(): assert execution_results[0].extensions["key1"] == "val1" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) ONLINE_URL = "https://countries.trevorblades.com/" diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 68b2fe52..9c43965f 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -42,7 +42,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_invalid_query(event_loop, client_and_server, query_str): +async def test_websocket_invalid_query(client_and_server, query_str): session, server = client_and_server @@ -81,7 +81,7 @@ async def server_invalid_subscription(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) -async def test_websocket_invalid_subscription(event_loop, client_and_server, query_str): +async def test_websocket_invalid_subscription(client_and_server, query_str): session, server = client_and_server @@ -113,7 +113,7 @@ async def server_no_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_server_does_not_send_ack(event_loop, server, query_str): +async def test_websocket_server_does_not_send_ack(server, query_str): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -136,7 +136,7 @@ async def server_connection_error(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_connection_error], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_sending_invalid_data(event_loop, 
client_and_server, query_str): +async def test_websocket_sending_invalid_data(client_and_server, query_str): session, server = client_and_server @@ -164,9 +164,7 @@ async def server_invalid_payload(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_sending_invalid_payload( - event_loop, client_and_server, query_str -): +async def test_websocket_sending_invalid_payload(client_and_server, query_str): session, server = client_and_server @@ -235,7 +233,7 @@ async def monkey_patch_send_query( ], indirect=True, ) -async def test_websocket_transport_protocol_errors(event_loop, client_and_server): +async def test_websocket_transport_protocol_errors(client_and_server): session, server = client_and_server @@ -253,7 +251,7 @@ async def server_without_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_without_ack], indirect=True) -async def test_websocket_server_does_not_ack(event_loop, server): +async def test_websocket_server_does_not_ack(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -272,7 +270,7 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) -async def test_websocket_server_closing_directly(event_loop, server): +async def test_websocket_server_closing_directly(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -292,7 +290,7 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) -async def test_websocket_server_closing_after_ack(event_loop, client_and_server): +async def test_websocket_server_closing_after_ack(client_and_server): session, server = client_and_server @@ -319,7 +317,7 @@ async def 
server_sending_invalid_query_errors(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_websocket_server_sending_invalid_query_errors(event_loop, server): +async def test_websocket_server_sending_invalid_query_errors(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -334,7 +332,7 @@ async def test_websocket_server_sending_invalid_query_errors(event_loop, server) @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_websocket_non_regression_bug_105(event_loop, server): +async def test_websocket_non_regression_bug_105(server): from gql.transport.websockets import WebsocketsTransport # This test will check a fix to a race condition which happens if the user is trying @@ -363,9 +361,7 @@ async def client_connect(client): @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) -async def test_websocket_using_cli_invalid_query( - event_loop, server, monkeypatch, capsys -): +async def test_websocket_using_cli_invalid_query(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index b1e3c07a..919f6bdb 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -51,7 +51,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_starting_client_in_context_manager(event_loop, server): +async def test_websocket_starting_client_in_context_manager(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -91,7 +91,7 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): @pytest.mark.asyncio 
@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): +async def test_websocket_using_ssl_connection(ws_ssl_server): import websockets from gql.transport.websockets import WebsocketsTransport @@ -136,7 +136,7 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_websocket_using_ssl_connection_self_cert_fail( - event_loop, ws_ssl_server, verify_https + ws_ssl_server, verify_https ): from gql.transport.websockets import WebsocketsTransport from ssl import SSLCertVerificationError @@ -178,7 +178,7 @@ async def test_websocket_using_ssl_connection_self_cert_fail( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_simple_query(event_loop, client_and_server, query_str): +async def test_websocket_simple_query(client_and_server, query_str): session, server = client_and_server @@ -198,9 +198,7 @@ async def test_websocket_simple_query(event_loop, client_and_server, query_str): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_answers_in_series], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_series( - event_loop, client_and_server, query_str -): +async def test_websocket_two_queries_in_series(client_and_server, query_str): session, server = client_and_server @@ -234,9 +232,7 @@ async def server1_two_queries_in_parallel(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_parallel( - event_loop, client_and_server, query_str -): +async def 
test_websocket_two_queries_in_parallel(client_and_server, query_str): session, server = client_and_server @@ -281,9 +277,7 @@ async def server_closing_while_we_are_doing_something_else(ws): "server", [server_closing_while_we_are_doing_something_else], indirect=True ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_server_closing_after_first_query( - event_loop, client_and_server, query_str -): +async def test_websocket_server_closing_after_first_query(client_and_server, query_str): session, server = client_and_server @@ -311,7 +305,7 @@ async def test_websocket_server_closing_after_first_query( @pytest.mark.asyncio @pytest.mark.parametrize("server", [ignore_invalid_id_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_ignore_invalid_id(event_loop, client_and_server, query_str): +async def test_websocket_ignore_invalid_id(client_and_server, query_str): session, server = client_and_server @@ -346,7 +340,7 @@ async def assert_client_is_working(session): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_multiple_connections_in_series(event_loop, server): +async def test_websocket_multiple_connections_in_series(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -369,7 +363,7 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_multiple_connections_in_parallel(event_loop, server): +async def test_websocket_multiple_connections_in_parallel(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -388,9 +382,7 @@ async def task_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def 
test_websocket_trying_to_connect_to_already_connected_transport( - event_loop, server -): +async def test_websocket_trying_to_connect_to_already_connected_transport(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -437,7 +429,7 @@ async def server_with_authentication_in_connection_init_payload(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_connect_success_with_authentication_in_connection_init( - event_loop, server, query_str + server, query_str ): from gql.transport.websockets import WebsocketsTransport @@ -472,7 +464,7 @@ async def test_websocket_connect_success_with_authentication_in_connection_init( @pytest.mark.parametrize("query_str", [query1_str]) @pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) async def test_websocket_connect_failed_with_authentication_in_connection_init( - event_loop, server, query_str, init_payload + server, query_str, init_payload ): from gql.transport.websockets import WebsocketsTransport @@ -534,7 +526,7 @@ def test_websocket_execute_sync(server): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_add_extra_parameters_to_connect(event_loop, server): +async def test_websocket_add_extra_parameters_to_connect(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -566,9 +558,7 @@ async def server_sending_keep_alive_before_connection_ack(ws): "server", [server_sending_keep_alive_before_connection_ack], indirect=True ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_non_regression_bug_108( - event_loop, client_and_server, query_str -): +async def test_websocket_non_regression_bug_108(client_and_server, query_str): # This test will check that we now ignore keepalive message # arriving before the connection_ack @@ -590,7 +580,7 @@ async def 
test_websocket_non_regression_bug_108( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): +async def test_websocket_using_cli(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") @@ -641,9 +631,7 @@ async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers_with_extensions], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_simple_query_with_extensions( - event_loop, client_and_server, query_str -): +async def test_websocket_simple_query_with_extensions(client_and_server, query_str): session, server = client_and_server @@ -656,7 +644,7 @@ async def test_websocket_simple_query_with_extensions( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_adapter_connection_closed(event_loop, server): +async def test_websocket_adapter_connection_closed(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -680,7 +668,7 @@ async def test_websocket_adapter_connection_closed(event_loop, server): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_transport_closed_in_receive(event_loop, server): +async def test_websocket_transport_closed_in_receive(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 6f291218..a020e1f5 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -126,7 +126,7 @@ async def keepalive_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", 
[server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription(event_loop, client_and_server, subscription_str): +async def test_websocket_subscription(client_and_server, subscription_str): session, server = client_and_server @@ -148,7 +148,7 @@ async def test_websocket_subscription(event_loop, client_and_server, subscriptio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_get_execution_result( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -172,9 +172,7 @@ async def test_websocket_subscription_get_execution_result( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_break( - event_loop, client_and_server, subscription_str -): +async def test_websocket_subscription_break(client_and_server, subscription_str): session, server = client_and_server @@ -203,9 +201,7 @@ async def test_websocket_subscription_break( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_task_cancel( - event_loop, client_and_server, subscription_str -): +async def test_websocket_subscription_task_cancel(client_and_server, subscription_str): session, server = client_and_server @@ -243,7 +239,7 @@ async def cancel_task_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_close_transport( - event_loop, client_and_server, subscription_str + client_and_server, 
subscription_str ): session, server = client_and_server @@ -308,7 +304,7 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_server_connection_closed( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -331,7 +327,7 @@ async def test_websocket_subscription_server_connection_closed( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_slow_consumer( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -356,7 +352,7 @@ async def test_websocket_subscription_slow_consumer( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_operation_name( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -387,7 +383,7 @@ async def test_websocket_subscription_with_operation_name( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_keepalive( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -410,7 +406,7 @@ async def test_websocket_subscription_with_keepalive( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_keepalive_with_timeout_ok( - event_loop, server, subscription_str + server, subscription_str ): from 
gql.transport.websockets import WebsocketsTransport @@ -447,7 +443,7 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_keepalive_with_timeout_nok( - event_loop, server, subscription_str + server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -623,7 +619,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_running_in_thread( - event_loop, server, subscription_str, run_sync_test + server, subscription_str, run_sync_test ): from gql.transport.websockets import WebsocketsTransport @@ -647,4 +643,4 @@ def test_code(): assert count == -1 - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py index 85fbf00a..f070f497 100644 --- a/tests/test_websockets_adapter.py +++ b/tests/test_websockets_adapter.py @@ -33,7 +33,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websockets_adapter_simple_query(event_loop, server): +async def test_websockets_adapter_simple_query(server): from gql.transport.common.adapters.websockets import WebSocketsAdapter url = f"ws://{server.hostname}:{server.port}/graphql" @@ -65,7 +65,7 @@ async def test_websockets_adapter_simple_query(event_loop, server): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websockets_adapter_edge_cases(event_loop, server): +async def test_websockets_adapter_edge_cases(server): from gql.transport.common.adapters.websockets import 
WebSocketsAdapter url = f"ws://{server.hostname}:{server.port}/graphql" From 5a9f98360d8f1717d91117aad8e9519411dfbf82 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 11 Mar 2025 21:34:01 +0100 Subject: [PATCH 203/239] Set ssl=True by default for AIOHTTPTransport (#538) --- docs/code_examples/fastapi_async.py | 1 + docs/code_examples/httpx_async_trio.py | 1 + gql/transport/aiohttp.py | 33 ++++---------------------- tests/test_aiohttp.py | 11 +++------ 4 files changed, 10 insertions(+), 36 deletions(-) diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 80920252..3bedd187 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -10,6 +10,7 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse + from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport diff --git a/docs/code_examples/httpx_async_trio.py b/docs/code_examples/httpx_async_trio.py index b76dab42..058b952b 100644 --- a/docs/code_examples/httpx_async_trio.py +++ b/docs/code_examples/httpx_async_trio.py @@ -1,4 +1,5 @@ import trio + from gql import Client, gql from gql.transport.httpx import HTTPXAsyncTransport diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index c1302794..b581e311 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -2,19 +2,8 @@ import io import json import logging -import warnings from ssl import SSLContext -from typing import ( - Any, - AsyncGenerator, - Callable, - Dict, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -57,7 +46,7 @@ def __init__( headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, - ssl: Union[SSLContext, bool, Fingerprint, str] = 
"ssl_warning", + ssl: Union[SSLContext, bool, Fingerprint] = True, timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, json_serialize: Callable = json.dumps, @@ -71,7 +60,8 @@ def __init__( :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed Or Appsync Authentication class - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param ssl: ssl_context of the connection. + Use ssl=False to not verify ssl certificates. :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection to close properly :param json_serialize: Json serializer callable. @@ -88,20 +78,7 @@ def __init__( self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth - - if ssl == "ssl_warning": - ssl = False - if str(url).startswith("https"): - warnings.warn( - "WARNING: By default, AIOHTTPTransport does not verify" - " ssl certificates. This will be fixed in the next major version." 
- " You can set ssl=True to force the ssl certificate verification" - " or ssl=False to disable this warning" - ) - - self.ssl: Union[SSLContext, bool, Fingerprint] = cast( - Union[SSLContext, bool, Fingerprint], ssl - ) + self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 88c4db98..e843db6c 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1325,7 +1325,6 @@ async def handler(request): assert africa["code"] == "AF" -@pytest.mark.skip(reason="We will change the default to fix this in a future version") @pytest.mark.asyncio async def test_aiohttp_query_https_self_cert_fail(ssl_aiohttp_server): """By default, we should verify the ssl certificate""" @@ -1360,7 +1359,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_https_self_cert_warn(ssl_aiohttp_server): +async def test_aiohttp_query_https_self_cert_default(ssl_aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport @@ -1375,13 +1374,9 @@ async def handler(request): assert str(url).startswith("https://") - expected_warning = ( - "WARNING: By default, AIOHTTPTransport does not verify ssl certificates." - " This will be fixed in the next major version." 
- ) + transport = AIOHTTPTransport(url=url) - with pytest.warns(Warning, match=expected_warning): - AIOHTTPTransport(url=url, timeout=10) + assert transport.ssl is True @pytest.mark.asyncio From c6937eb1a2834e68106ac57f8070f45731b14051 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 11 Mar 2025 22:23:10 +0100 Subject: [PATCH 204/239] Chore bump test dependencies (#539) * Bump parse to 1.20.2 * Bump pytest-cov to 6.0.0 --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 6b4c1fd2..e8be1ef6 100644 --- a/setup.py +++ b/setup.py @@ -14,11 +14,11 @@ ] tests_requires = [ - "parse==1.15.0", + "parse==1.20.2", "pytest==8.3.4", "pytest-asyncio==0.25.3", "pytest-console-scripts==1.4.1", - "pytest-cov==5.0.0", + "pytest-cov==6.0.0", "vcrpy==7.0.0", "aiofiles", ] From b1c976d9a2caa65e8178512e85927e0f1f5b094b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 12 Mar 2025 14:49:58 +0100 Subject: [PATCH 205/239] Chore Bumping all the dev dependencies to latest versions (#540) * Bump black to 25.1.0 and run make check * Bump flake8 to 7.1.2 * Bump isort to 6.0.1 * Running make check with new isort * Bump mypy to 1.15 * Moving mypy config to pyproject.toml * Fix new mypy errors * mypy disallow_incomplete_defs = true * Bump sphinx-argparse to 0.5.2 * Fix Sphinx make docs warnings * Fix moved intersphinx url for yarl and multidict --- Makefile | 2 +- docs/code_examples/aiohttp_async_dsl.py | 2 + docs/code_examples/console_async.py | 9 +- docs/code_examples/fastapi_async.py | 2 + docs/conf.py | 6 +- docs/modules/gql.rst | 6 +- .../transport_common_adapters_aiohttp.rst | 7 + .../transport_common_adapters_connection.rst | 7 + .../transport_common_adapters_websockets.rst | 7 + docs/modules/transport_common_base.rst | 7 + docs/modules/transport_websockets_base.rst | 7 - .../modules/transport_websockets_protocol.rst | 7 + gql/cli.py | 3 +- gql/client.py | 242 ++++++++---------- gql/dsl.py | 15 +- gql/transport/aiohttp.py 
| 22 +- gql/transport/appsync_websockets.py | 4 +- gql/transport/common/adapters/aiohttp.py | 4 +- gql/transport/common/base.py | 12 +- gql/transport/httpx.py | 9 +- gql/transport/local_schema.py | 10 +- gql/transport/phoenix_channel_websockets.py | 4 +- gql/transport/requests.py | 27 +- gql/transport/transport.py | 12 +- gql/transport/websockets_protocol.py | 2 +- gql/utilities/node_tree.py | 4 +- gql/utilities/parse_result.py | 4 +- gql/utilities/update_schema_enum.py | 2 +- gql/utilities/update_schema_scalars.py | 8 +- gql/utils.py | 12 +- pyproject.toml | 10 + setup.cfg | 11 - setup.py | 10 +- tests/conftest.py | 28 +- tests/custom_scalars/test_enum_colors.py | 39 ++- tests/custom_scalars/test_money.py | 20 +- tests/fixtures/aws/fake_signer.py | 4 +- .../test_dsl_directives.py | 5 + tests/starwars/fixtures.py | 3 +- tests/starwars/schema.py | 6 +- tests/starwars/test_dsl.py | 17 +- tests/starwars/test_parse_results.py | 4 +- tests/starwars/test_query.py | 2 +- tests/starwars/test_validation.py | 6 +- tests/test_aiohttp.py | 46 +++- tests/test_aiohttp_websocket_exceptions.py | 2 +- ...iohttp_websocket_graphqlws_subscription.py | 4 +- tests/test_aiohttp_websocket_query.py | 9 +- tests/test_aiohttp_websocket_subscription.py | 4 + tests/test_appsync_auth.py | 11 +- tests/test_appsync_http.py | 4 +- tests/test_appsync_websockets.py | 15 +- tests/test_cli.py | 4 +- tests/test_client.py | 38 ++- tests/test_graphqlws_subscription.py | 4 +- tests/test_httpx.py | 25 +- tests/test_httpx_async.py | 45 +++- tests/test_phoenix_channel_exceptions.py | 7 +- tests/test_phoenix_channel_query.py | 3 +- tests/test_requests.py | 36 ++- tests/test_requests_batch.py | 16 +- tests/test_transport.py | 1 + tests/test_transport_batch.py | 1 + tests/test_websocket_exceptions.py | 5 +- tests/test_websocket_online.py | 35 +-- tests/test_websocket_query.py | 12 +- tests/test_websocket_subscription.py | 4 + tests/test_websockets_adapter.py | 6 +- tox.ini | 2 +- 69 files changed, 621 
insertions(+), 348 deletions(-) create mode 100644 docs/modules/transport_common_adapters_aiohttp.rst create mode 100644 docs/modules/transport_common_adapters_connection.rst create mode 100644 docs/modules/transport_common_adapters_websockets.rst create mode 100644 docs/modules/transport_common_base.rst delete mode 100644 docs/modules/transport_websockets_base.rst create mode 100644 docs/modules/transport_websockets_protocol.rst diff --git a/Makefile b/Makefile index 59d08bac..9af372f7 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ tests_websockets: pytest tests --websockets-only check: - isort --recursive $(SRC_PYTHON) + isort $(SRC_PYTHON) black $(SRC_PYTHON) flake8 $(SRC_PYTHON) mypy $(SRC_PYTHON) diff --git a/docs/code_examples/aiohttp_async_dsl.py b/docs/code_examples/aiohttp_async_dsl.py index 958ea490..2c4804db 100644 --- a/docs/code_examples/aiohttp_async_dsl.py +++ b/docs/code_examples/aiohttp_async_dsl.py @@ -17,6 +17,8 @@ async def main(): # GQL will fetch the schema just after the establishment of the first session async with client as session: + assert client.schema is not None + # Instantiate the root of the DSL Schema as ds ds = DSLSchema(client.schema) diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py index 9a5e94e5..6c0b86d0 100644 --- a/docs/code_examples/console_async.py +++ b/docs/code_examples/console_async.py @@ -1,8 +1,11 @@ import asyncio import logging +from typing import Optional from aioconsole import ainput + from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.aiohttp import AIOHTTPTransport logging.basicConfig(level=logging.INFO) @@ -21,7 +24,7 @@ def __init__(self): self._client = Client( transport=AIOHTTPTransport(url="https://countries.trevorblades.com/") ) - self._session = None + self._session: Optional[AsyncClientSession] = None self.get_continent_name_query = gql(GET_CONTINENT_NAME) @@ -34,11 +37,13 @@ 
async def close(self): async def get_continent_name(self, code): params = {"code": code} + assert self._session is not None + answer = await self._session.execute( self.get_continent_name_query, variable_values=params ) - return answer.get("continent").get("name") + return answer.get("continent").get("name") # type: ignore async def main(): diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 3bedd187..f4a5c14b 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -12,6 +12,7 @@ from fastapi.responses import HTMLResponse from gql import Client, gql +from gql.client import ReconnectingAsyncClientSession from gql.transport.aiohttp import AIOHTTPTransport logging.basicConfig(level=logging.DEBUG) @@ -91,6 +92,7 @@ async def get_continent(continent_code): raise HTTPException(status_code=404, detail="Continent not found") try: + assert isinstance(client.session, ReconnectingAsyncClientSession) result = await client.session.execute( query, variable_values={"code": continent_code} ) diff --git a/docs/conf.py b/docs/conf.py index 94daf942..8289ef4b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,11 +83,11 @@ intersphinx_mapping = { 'aiohttp': ('https://docs.aiohttp.org/en/stable/', None), 'graphql': ('https://graphql-core-3.readthedocs.io/en/latest/', None), - 'multidict': ('https://multidict.readthedocs.io/en/stable/', None), + 'multidict': ('https://multidict.aio-libs.org/en/stable/', None), 'python': ('https://docs.python.org/3/', None), 'requests': ('https://requests.readthedocs.io/en/latest/', None), 'websockets': ('https://websockets.readthedocs.io/en/11.0.3/', None), - 'yarl': 
('https://yarl.readthedocs.io/en/stable/', None), + 'yarl': ('https://yarl.aio-libs.org/en/stable/', None), } nitpick_ignore = [ @@ -100,6 +100,8 @@ ('py:class', 'asyncio.locks.Event'), # aiohttp: should be fixed + # See issue: https://github.com/aio-libs/aiohttp/issues/10468 + ('py:class', 'aiohttp.client.ClientSession'), ('py:class', 'aiohttp.client_reqrep.Fingerprint'), ('py:class', 'aiohttp.helpers.BasicAuth'), diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index b7c13c7c..035f196f 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -24,11 +24,15 @@ Sub-Packages transport_aiohttp_websockets transport_appsync_auth transport_appsync_websockets + transport_common_base + transport_common_adapters_connection + transport_common_adapters_aiohttp + transport_common_adapters_websockets transport_exceptions transport_phoenix_channel_websockets transport_requests transport_httpx transport_websockets - transport_websockets_base + transport_websockets_protocol dsl utilities diff --git a/docs/modules/transport_common_adapters_aiohttp.rst b/docs/modules/transport_common_adapters_aiohttp.rst new file mode 100644 index 00000000..537c8673 --- /dev/null +++ b/docs/modules/transport_common_adapters_aiohttp.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.aiohttp +===================================== + +.. currentmodule:: gql.transport.common.adapters.aiohttp + +.. automodule:: gql.transport.common.adapters.aiohttp + :member-order: bysource diff --git a/docs/modules/transport_common_adapters_connection.rst b/docs/modules/transport_common_adapters_connection.rst new file mode 100644 index 00000000..ffa1a1b3 --- /dev/null +++ b/docs/modules/transport_common_adapters_connection.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.connection +======================================== + +.. 
currentmodule:: gql.transport.common.adapters.connection + +.. automodule:: gql.transport.common.adapters.connection + :member-order: bysource diff --git a/docs/modules/transport_common_adapters_websockets.rst b/docs/modules/transport_common_adapters_websockets.rst new file mode 100644 index 00000000..4005694c --- /dev/null +++ b/docs/modules/transport_common_adapters_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.websockets +======================================== + +.. currentmodule:: gql.transport.common.adapters.websockets + +.. automodule:: gql.transport.common.adapters.websockets + :member-order: bysource diff --git a/docs/modules/transport_common_base.rst b/docs/modules/transport_common_base.rst new file mode 100644 index 00000000..4a7ec15a --- /dev/null +++ b/docs/modules/transport_common_base.rst @@ -0,0 +1,7 @@ +gql.transport.common.base +========================= + +.. currentmodule:: gql.transport.common.base + +.. automodule:: gql.transport.common.base + :member-order: bysource diff --git a/docs/modules/transport_websockets_base.rst b/docs/modules/transport_websockets_base.rst deleted file mode 100644 index 548351eb..00000000 --- a/docs/modules/transport_websockets_base.rst +++ /dev/null @@ -1,7 +0,0 @@ -gql.transport.websockets_base -============================= - -.. currentmodule:: gql.transport.websockets_base - -.. automodule:: gql.transport.websockets_base - :member-order: bysource diff --git a/docs/modules/transport_websockets_protocol.rst b/docs/modules/transport_websockets_protocol.rst new file mode 100644 index 00000000..b835abee --- /dev/null +++ b/docs/modules/transport_websockets_protocol.rst @@ -0,0 +1,7 @@ +gql.transport.websockets_protocol +================================= + +.. currentmodule:: gql.transport.websockets_protocol + +.. 
automodule:: gql.transport.websockets_protocol + :member-order: bysource diff --git a/gql/cli.py b/gql/cli.py index 91c67873..9ae92e83 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -391,9 +391,10 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt) else: - from gql.transport.appsync_auth import AppSyncIAMAuthentication from botocore.exceptions import NoRegionError + from gql.transport.appsync_auth import AppSyncIAMAuthentication + try: auth = AppSyncIAMAuthentication(host=url.host) except NoRegionError: diff --git a/gql/client.py b/gql/client.py index c52a00b2..faf3230a 100644 --- a/gql/client.py +++ b/gql/client.py @@ -131,7 +131,10 @@ def __init__( self.introspection: Optional[IntrospectionQuery] = introspection # GraphQL transport chosen - self.transport: Optional[Union[Transport, AsyncTransport]] = transport + assert ( + transport is not None + ), "You need to provide either a transport or a schema to the Client." 
+ self.transport: Union[Transport, AsyncTransport] = transport # Flag to indicate that we need to fetch the schema from the transport # On async transports, we fetch the schema before executing the first query @@ -149,10 +152,10 @@ def __init__( self.batch_max = batch_max @property - def batching_enabled(self): + def batching_enabled(self) -> bool: return self.batch_interval != 0 - def validate(self, document: DocumentNode): + def validate(self, document: DocumentNode) -> None: """:meta private:""" assert ( self.schema @@ -162,7 +165,9 @@ def validate(self, document: DocumentNode): if validation_errors: raise validation_errors[0] - def _build_schema_from_introspection(self, execution_result: ExecutionResult): + def _build_schema_from_introspection( + self, execution_result: ExecutionResult + ) -> None: if execution_result.errors: raise TransportQueryError( ( @@ -189,9 +194,8 @@ def execute_sync( parse_result: Optional[bool] = ..., *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute_sync( @@ -203,9 +207,8 @@ def execute_sync( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute_sync( @@ -217,9 +220,8 @@ def execute_sync( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... 
# pragma: no cover def execute_sync( self, @@ -229,7 +231,7 @@ def execute_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" with self as session: @@ -251,9 +253,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch_sync( @@ -263,9 +264,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch_sync( @@ -275,9 +275,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch_sync( self, @@ -286,7 +285,7 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """:meta private:""" with self as session: @@ -308,9 +307,8 @@ async def execute_async( parse_result: Optional[bool] = ..., *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... 
# pragma: no cover @overload async def execute_async( @@ -322,9 +320,8 @@ async def execute_async( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute_async( @@ -336,9 +333,8 @@ async def execute_async( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover async def execute_async( self, @@ -348,7 +344,7 @@ async def execute_async( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" async with self as session: @@ -372,9 +368,8 @@ def execute( parse_result: Optional[bool] = ..., *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute( @@ -386,9 +381,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute( @@ -400,9 +394,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... 
# pragma: no cover def execute( self, @@ -412,7 +405,7 @@ def execute( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """Execute the provided document AST against the remote server using the transport provided during init. @@ -487,9 +480,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -499,9 +491,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch( @@ -511,9 +502,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -522,7 +512,7 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """Execute multiple GraphQL requests in a batch against the remote server using the transport provided during init. @@ -568,9 +558,8 @@ def subscribe_async( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... 
# pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @overload def subscribe_async( @@ -582,9 +571,8 @@ def subscribe_async( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe_async( @@ -596,11 +584,10 @@ def subscribe_async( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover async def subscribe_async( self, @@ -610,7 +597,7 @@ async def subscribe_async( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ]: @@ -639,9 +626,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Generator[Dict[str, Any], None, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> Generator[Dict[str, Any], None, None]: ... # pragma: no cover @overload def subscribe( @@ -653,9 +639,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> Generator[ExecutionResult, None, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> Generator[ExecutionResult, None, None]: ... # pragma: no cover @overload def subscribe( @@ -667,11 +652,10 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] - ]: - ... # pragma: no cover + ]: ... 
# pragma: no cover def subscribe( self, @@ -682,7 +666,7 @@ def subscribe( parse_result: Optional[bool] = None, *, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] ]: @@ -770,6 +754,8 @@ async def connect_async(self, reconnecting=False, **kwargs): self.transport, AsyncTransport ), "Only a transport of type AsyncTransport can be used asynchronously" + self.session: Union[AsyncClientSession, SyncClientSession] + if reconnecting: self.session = ReconnectingAsyncClientSession(client=self, **kwargs) await self.session.start_connecting_task() @@ -825,6 +811,8 @@ def connect_sync(self): if not hasattr(self, "session"): self.session = SyncClientSession(client=self) + assert isinstance(self.session, SyncClientSession) + self.session.connect() # Get schema from transport if needed @@ -846,6 +834,8 @@ def close_sync(self): If batching is enabled, this will block until the remaining queries in the batching queue have been processed. """ + assert isinstance(self.session, SyncClientSession) + self.session.close() def __enter__(self): @@ -873,7 +863,7 @@ def _execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Execute the provided document AST synchronously using the sync transport, returning an ExecutionResult object. @@ -944,9 +934,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute( @@ -958,9 +947,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... 
# pragma: no cover @overload def execute( @@ -972,9 +960,8 @@ def execute( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute( self, @@ -984,7 +971,7 @@ def execute( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """Execute the provided document AST synchronously using the sync transport. @@ -1040,7 +1027,7 @@ def _execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, validate_document: Optional[bool] = True, - **kwargs, + **kwargs: Any, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch, using the sync transport, returning a list of ExecutionResult objects. @@ -1067,9 +1054,11 @@ def _execute_batch( serialize_variables is None and self.client.serialize_variables ): requests = [ - req.serialize_variable_values(self.client.schema) - if req.variable_values is not None - else req + ( + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req + ) for req in requests ] @@ -1096,9 +1085,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -1108,9 +1096,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... 
# pragma: no cover @overload def execute_batch( @@ -1120,9 +1107,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -1131,7 +1117,7 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """Execute multiple GraphQL requests in a batch, using the sync transport. This method sends the requests to the server all at once. @@ -1312,7 +1298,7 @@ async def _subscribe( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Coroutine to subscribe asynchronously to the provided document AST asynchronously using the async transport, @@ -1349,13 +1335,13 @@ async def _subscribe( ) # Subscribe to the transport - inner_generator: AsyncGenerator[ - ExecutionResult, None - ] = self.transport.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, + inner_generator: AsyncGenerator[ExecutionResult, None] = ( + self.transport.subscribe( + document, + variable_values=variable_values, + operation_name=operation_name, + **kwargs, + ) ) # Keep a reference to the inner generator @@ -1390,9 +1376,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[Dict[str, Any], None]: ... 
# pragma: no cover @overload def subscribe( @@ -1404,9 +1389,8 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe( @@ -1418,11 +1402,10 @@ def subscribe( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover async def subscribe( self, @@ -1432,7 +1415,7 @@ async def subscribe( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ]: @@ -1491,7 +1474,7 @@ async def _execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Coroutine to execute the provided document AST asynchronously using the async transport, returning an ExecutionResult object. @@ -1557,9 +1540,8 @@ async def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload async def execute( @@ -1571,9 +1553,8 @@ async def execute( parse_result: Optional[bool] = ..., *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute( @@ -1585,9 +1566,8 @@ async def execute( parse_result: Optional[bool] = ..., *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... 
# pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover async def execute( self, @@ -1597,7 +1577,7 @@ async def execute( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """Coroutine to execute the provided document AST asynchronously using the async transport. @@ -1775,7 +1755,7 @@ async def _execute_once( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent method _execute but requesting a reconnection if we receive a TransportClosed exception. @@ -1803,7 +1783,7 @@ async def _execute( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent, but with optional retries and requesting a reconnection if we receive a TransportClosed exception. @@ -1825,7 +1805,7 @@ async def _subscribe( operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Same Async generator as parent method _subscribe but requesting a reconnection if we receive a TransportClosed exception. diff --git a/gql/dsl.py b/gql/dsl.py index be2b5a7e..e5b5131e 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -2,6 +2,7 @@ .. 
image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa :alt: UML diagram """ + import logging import re from abc import ABC, abstractmethod @@ -338,7 +339,7 @@ def select( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", - ): + ) -> Any: r"""Select the fields which should be added. :param \*fields: fields or fragments @@ -595,9 +596,11 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: VariableDefinitionNode( type=var.ast_variable_type, variable=var.ast_variable_name, - default_value=None - if var.default_value is None - else ast_from_value(var.default_value, var.type), + default_value=( + None + if var.default_value is None + else ast_from_value(var.default_value, var.type) + ), directives=(), ) for var in self.variables.values() @@ -836,10 +839,10 @@ def name(self): """:meta private:""" return self.ast_field.name.value - def __call__(self, **kwargs) -> "DSLField": + def __call__(self, **kwargs: Any) -> "DSLField": return self.args(**kwargs) - def args(self, **kwargs) -> "DSLField": + def args(self, **kwargs: Any) -> "DSLField": r"""Set the arguments of a field The arguments are parsed to be stored in the AST of this field. 
diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index b581e311..76b46c35 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -3,7 +3,17 @@ import json import logging from ssl import SSLContext -from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + NoReturn, + Optional, + Tuple, + Type, + Union, +) import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -102,9 +112,9 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": None - if isinstance(self.auth, AppSyncAuthentication) - else self.auth, + "auth": ( + None if isinstance(self.auth, AppSyncAuthentication) else self.auth + ), "json_serialize": self.json_serialize, } @@ -262,7 +272,9 @@ async def execute( # Saving latest response headers in the transport self.response_headers = resp.headers - async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): + async def raise_response_error( + resp: aiohttp.ClientResponse, reason: str + ) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index f35cefe5..a6a7d180 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -29,7 +29,7 @@ class AppSyncWebsocketsTransport(SubscriptionTransportBase): on a websocket connection. 
""" - auth: Optional[AppSyncAuthentication] + auth: AppSyncAuthentication def __init__( self, @@ -72,7 +72,7 @@ def __init__( # May raise NoRegionError or NoCredentialsError or ImportError auth = AppSyncIAMAuthentication(host=host, session=session) - self.auth = auth + self.auth: AppSyncAuthentication = auth self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.init_payload: Dict[str, Any] = {} diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py index f2dff699..736f2a3e 100644 --- a/gql/transport/common/adapters/aiohttp.py +++ b/gql/transport/common/adapters/aiohttp.py @@ -50,9 +50,9 @@ def __init__( certificate validation. :param session: Optional aiohttp opened session. :param client_session_args: Dict of extra args passed to - `aiohttp.ClientSession`_ + :class:`aiohttp.ClientSession` :param connect_args: Dict of extra args passed to - `aiohttp.ClientSession.ws_connect`_ + :meth:`aiohttp.ClientSession.ws_connect` :param float heartbeat: Send low level `ping` message every `heartbeat` seconds and wait `pong` response, close diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 770a8b34..a3d025c0 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -95,29 +95,29 @@ async def _initialize(self): """ pass # pragma: no cover - async def _stop_listener(self, query_id: int): + async def _stop_listener(self, query_id: int) -> None: """Hook to stop to listen to a specific query. Will send a stop message in some subclasses. """ pass # pragma: no cover - async def _after_connect(self): + async def _after_connect(self) -> None: """Hook to add custom code for subclasses after the connection has been established. """ pass # pragma: no cover - async def _after_initialize(self): + async def _after_initialize(self) -> None: """Hook to add custom code for subclasses after the initialization has been done. 
""" pass # pragma: no cover - async def _close_hook(self): + async def _close_hook(self) -> None: """Hook to add custom code for subclasses for the connection close""" pass # pragma: no cover - async def _connection_terminate(self): + async def _connection_terminate(self) -> None: """Hook to add custom code for subclasses after the initialization has been done. """ @@ -430,7 +430,7 @@ async def connect(self) -> None: log.debug("connect: done") - def _remove_listener(self, query_id) -> None: + def _remove_listener(self, query_id: int) -> None: """After exiting from a subscription, remove the listener and signal an event if this was the last listener for the client. """ diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 811601b8..4c5d33d0 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -7,6 +7,7 @@ Callable, Dict, List, + NoReturn, Optional, Tuple, Type, @@ -39,7 +40,7 @@ def __init__( url: Union[str, httpx.URL], json_serialize: Callable = json.dumps, json_deserialize: Callable = json.loads, - **kwargs, + **kwargs: Any, ): """Initialize the transport with the given httpx parameters. 
@@ -93,7 +94,9 @@ def _prepare_request( return post_args - def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: + def _prepare_file_uploads( + self, variable_values: Dict[str, Any], payload: Dict[str, Any] + ) -> Dict[str, Any]: # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( @@ -163,7 +166,7 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: extensions=result.get("extensions"), ) - def _raise_response_error(self, response: httpx.Response, reason: str): + def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 04ed4ff1..19760ad6 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,6 +1,6 @@ import asyncio from inspect import isawaitable -from typing import AsyncGenerator, Awaitable, cast +from typing import Any, AsyncGenerator, Awaitable, cast from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe @@ -31,8 +31,8 @@ async def close(self): async def execute( self, document: DocumentNode, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> ExecutionResult: """Execute the provided document AST for on a local GraphQL Schema.""" @@ -58,8 +58,8 @@ async def _await_if_necessary(obj): async def subscribe( self, document: DocumentNode, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Send a subscription and receive the results using an async generator diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 3885fcac..8a975b73 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ 
b/gql/transport/phoenix_channel_websockets.py @@ -42,7 +42,7 @@ def __init__( channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, ack_timeout: Optional[Union[int, float]] = 10, - **kwargs, + **kwargs: Any, ) -> None: """Initialize the transport with the given parameters. @@ -244,7 +244,7 @@ def _required_value(d: Any, key: str, label: str) -> Any: return value def _required_subscription_id( - d: Any, label: str, must_exist: bool = False, must_not_exist=False + d: Any, label: str, must_exist: bool = False, must_not_exist: bool = False ) -> str: subscription_id = str(_required_value(d, "subscriptionId", label)) if must_exist and (subscription_id not in self.subscriptions): diff --git a/gql/transport/requests.py b/gql/transport/requests.py index bd370908..44f8a362 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,13 +1,25 @@ import io import json import logging -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + NoReturn, + Optional, + Tuple, + Type, + Union, +) import requests from graphql import DocumentNode, ExecutionResult, print_ast from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar +from requests.structures import CaseInsensitiveDict from requests_toolbelt.multipart.encoder import MultipartEncoder from gql.transport import Transport @@ -100,9 +112,9 @@ def __init__( self.json_deserialize: Callable = json_deserialize self.kwargs = kwargs - self.session = None + self.session: Optional[requests.Session] = None - self.response_headers = None + self.response_headers: Optional[CaseInsensitiveDict[str]] = None def connect(self): if self.session is None: @@ -159,7 +171,7 @@ def execute( # type: ignore if operation_name: payload["operationName"] = operation_name - post_args = { + post_args: Dict[str, Any] = { "headers": self.headers, 
"auth": self.auth, "cookies": self.cookies, @@ -219,7 +231,7 @@ def execute( # type: ignore if post_args["headers"] is None: post_args["headers"] = {} else: - post_args["headers"] = {**post_args["headers"]} + post_args["headers"] = dict(post_args["headers"]) post_args["headers"]["Content-Type"] = data.content_type @@ -247,7 +259,7 @@ def execute( # type: ignore ) self.response_headers = response.headers - def raise_response_error(resp: requests.Response, reason: str): + def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases @@ -255,7 +267,8 @@ def raise_response_error(resp: requests.Response, reason: str): # Raise a HTTPError if response status is 400 or higher resp.raise_for_status() except requests.HTTPError as e: - raise TransportServerError(str(e), e.response.status_code) from e + status_code = e.response.status_code if e.response is not None else None + raise TransportServerError(str(e), status_code) from e result_text = resp.text raise TransportProtocolError( diff --git a/gql/transport/transport.py b/gql/transport/transport.py index a5bd7100..49d0aa34 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,5 +1,5 @@ import abc -from typing import List +from typing import Any, List from graphql import DocumentNode, ExecutionResult @@ -8,7 +8,9 @@ class Transport(abc.ABC): @abc.abstractmethod - def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: + def execute( + self, document: DocumentNode, *args: Any, **kwargs: Any + ) -> ExecutionResult: """Execute GraphQL query. Execute the provided document AST for either a remote or local GraphQL Schema. 
@@ -23,8 +25,8 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: def execute_batch( self, reqs: List[GraphQLRequest], - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch. @@ -35,7 +37,7 @@ def execute_batch( """ raise NotImplementedError( "This Transport has not implemented the execute_batch method" - ) # pragma: no cover + ) def connect(self): """Establish a session with the transport.""" diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index 3348c576..61a4bb85 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -194,7 +194,7 @@ async def _send_complete_message(self, query_id: int) -> None: await self._send(complete_message) - async def _stop_listener(self, query_id: int): + async def _stop_listener(self, query_id: int) -> None: """Stop the listener corresponding to the query_id depending on the detected backend protocol. diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py index 4313188e..08fb1bf5 100644 --- a/gql/utilities/node_tree.py +++ b/gql/utilities/node_tree.py @@ -8,7 +8,7 @@ def _node_tree_recursive( *, indent: int = 0, ignored_keys: List, -): +) -> str: assert ignored_keys is not None @@ -65,7 +65,7 @@ def node_tree( ignore_loc: bool = True, ignore_block: bool = True, ignored_keys: Optional[List] = None, -): +) -> str: """Method which returns a tree of Node elements as a String. Useful to debug deep DocumentNode instances created by gql or dsl_gql. 
diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 02355425..f9bc2e0c 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -44,7 +44,7 @@ } -def _ignore_non_null(type_: GraphQLType): +def _ignore_non_null(type_: GraphQLType) -> GraphQLType: """Removes the GraphQLNonNull wrappings around types.""" if isinstance(type_, GraphQLNonNull): return type_.of_type @@ -153,6 +153,8 @@ def get_current_result_type(self, path): list_level = self.inside_list_level + assert field_type is not None + result_type = _ignore_non_null(field_type) if self.in_first_field(path): diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py index 80c73862..6f7ba0ce 100644 --- a/gql/utilities/update_schema_enum.py +++ b/gql/utilities/update_schema_enum.py @@ -9,7 +9,7 @@ def update_schema_enum( name: str, values: Union[Dict[str, Any], Type[Enum]], use_enum_values: bool = False, -): +) -> None: """Update in the schema the GraphQLEnumType corresponding to the given name. Example:: diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index db3adb17..c2c1b4e8 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -3,7 +3,9 @@ from graphql import GraphQLScalarType, GraphQLSchema -def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): +def update_schema_scalar( + schema: GraphQLSchema, name: str, scalar: GraphQLScalarType +) -> None: """Update the scalar in a schema with the scalar provided. 
:param schema: the GraphQL schema @@ -36,7 +38,9 @@ def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalar setattr(schema_scalar, "parse_literal", scalar.parse_literal) -def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): +def update_schema_scalars( + schema: GraphQLSchema, scalars: List[GraphQLScalarType] +) -> None: """Update the scalars in a schema with the scalars provided. :param schema: the GraphQL schema diff --git a/gql/utils.py b/gql/utils.py index b4265ce1..6a7d0791 100644 --- a/gql/utils.py +++ b/gql/utils.py @@ -25,17 +25,17 @@ def recurse_extract(path, obj): """ nonlocal files if isinstance(obj, list): - nulled_obj = [] + nulled_list = [] for key, value in enumerate(obj): value = recurse_extract(f"{path}.{key}", value) - nulled_obj.append(value) - return nulled_obj + nulled_list.append(value) + return nulled_list elif isinstance(obj, dict): - nulled_obj = {} + nulled_dict = {} for key, value in obj.items(): value = recurse_extract(f"{path}.{key}", value) - nulled_obj[key] = value - return nulled_obj + nulled_dict[key] = value + return nulled_dict elif isinstance(obj, file_classes): # extract obj from its parent and put it into files instead. 
files[path] = obj diff --git a/pyproject.toml b/pyproject.toml index 122cec88..f5eb5c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,5 +8,15 @@ dynamic = ["authors", "classifiers", "dependencies", "description", "entry-point requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.isort] +extra_standard_library = "ssl" +known_first_party = "gql" +profile = "black" + [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "function" + +[tool.mypy] +ignore_missing_imports = true +check_untyped_defs = true +disallow_incomplete_defs = true diff --git a/setup.cfg b/setup.cfg index 66380493..533b80f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,16 +1,5 @@ [flake8] max-line-length = 88 -[isort] -known_standard_library = ssl -known_first_party = gql -multi_line_output = 3 -include_trailing_comma = True -line_length = 88 -not_skip = __init__.py - -[mypy] -ignore_missing_imports = true - [tool:pytest] norecursedirs = venv .venv .tox .git .cache .mypy_cache .pytest_cache diff --git a/setup.py b/setup.py index e8be1ef6..f000136c 100644 --- a/setup.py +++ b/setup.py @@ -24,15 +24,15 @@ ] dev_requires = [ - "black==22.3.0", + "black==25.1.0", "check-manifest>=0.42,<1", - "flake8==7.1.1", - "isort==4.3.21", - "mypy==1.10", + "flake8==7.1.2", + "isort==6.0.1", + "mypy==1.15", "sphinx>=7.0.0,<8;python_version<='3.9'", "sphinx>=8.1.0,<9;python_version>'3.9'", "sphinx_rtd_theme>=3.0.2,<4", - "sphinx-argparse==0.4.0", + "sphinx-argparse==0.5.2", "types-aiofiles", "types-requests", ] + tests_requires diff --git a/tests/conftest.py b/tests/conftest.py index 5b8807ae..70a050d5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,10 +10,11 @@ import tempfile import types from concurrent.futures import ThreadPoolExecutor -from typing import Union +from typing import Callable, Iterable, List, Union, cast import pytest import pytest_asyncio +from _pytest.fixtures import SubRequest from gql import Client @@ -219,7 +220,7 @@ async def start(self, 
handler, extra_serve_args=None): self.server = await self.start_server # Get hostname and port - hostname, port = self.server.sockets[0].getsockname()[:2] + hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore assert hostname == "127.0.0.1" self.hostname = hostname @@ -250,7 +251,7 @@ def __init__(self, with_ssl=False): if with_ssl: _, self.ssl_context = get_localhost_ssl_context() - def get_default_server_handler(answers): + def get_default_server_handler(answers: Iterable[str]) -> Callable: async def default_server_handler(request): import aiohttp @@ -291,7 +292,7 @@ async def default_server_handler(request): elif msg.type == WSMsgType.ERROR: print(f"WebSocket connection closed with: {ws.exception()}") - raise ws.exception() + raise ws.exception() # type: ignore elif msg.type in ( WSMsgType.CLOSE, WSMsgType.CLOSED, @@ -341,7 +342,8 @@ async def start(self, handler): await self.site.start() # Retrieve the actual port the server is listening on - sockets = self.site._server.sockets + assert self.site._server is not None + sockets = self.site._server.sockets # type: ignore if sockets: self.port = sockets[0].getsockname()[1] protocol = "https" if self.with_ssl else "http" @@ -448,7 +450,7 @@ async def send_connection_ack(ws): class TemporaryFile: """Class used to generate temporary files for the tests""" - def __init__(self, content: Union[str, bytearray]): + def __init__(self, content: Union[str, bytearray, bytes]): mode = "w" if isinstance(content, str) else "wb" @@ -474,24 +476,30 @@ def __exit__(self, type, value, traceback): os.unlink(self.filename) -def get_aiohttp_ws_server_handler(request): +def get_aiohttp_ws_server_handler( + request: SubRequest, +) -> Callable: """Get the server handler for the aiohttp websocket server. Either get it from test or use the default server handler if the test provides only an array of answers. 
""" + server_handler: Callable + if isinstance(request.param, types.FunctionType): server_handler = request.param else: - answers = request.param + answers = cast(List[str], request.param) server_handler = AIOHTTPWebsocketServer.get_default_server_handler(answers) return server_handler -def get_server_handler(request): +def get_server_handler( + request: SubRequest, +) -> Callable: """Get the server handler. Either get it from test or use the default server handler @@ -501,7 +509,7 @@ def get_server_handler(request): from websockets.exceptions import ConnectionClosed if isinstance(request.param, types.FunctionType): - server_handler = request.param + server_handler: Callable = request.param else: answers = request.param diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 2f15a8ca..3526d548 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional import pytest from graphql import ( @@ -6,6 +7,7 @@ GraphQLEnumType, GraphQLField, GraphQLList, + GraphQLNamedType, GraphQLNonNull, GraphQLObjectType, GraphQLSchema, @@ -251,19 +253,30 @@ def test_list_of_list_of_list(): def test_update_schema_enum(): - assert schema.get_type("Color").parse_value("RED") == Color.RED + color_type: Optional[GraphQLNamedType] + + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == Color.RED # Using values update_schema_enum(schema, "Color", Color, use_enum_values=True) - assert schema.get_type("Color").parse_value("RED") == 0 - assert schema.get_type("Color").serialize(1) == "GREEN" + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == 0 + assert color_type.serialize(1) == "GREEN" update_schema_enum(schema, "Color", 
Color) - assert schema.get_type("Color").parse_value("RED") == Color.RED - assert schema.get_type("Color").serialize(Color.RED) == "RED" + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == Color.RED + assert color_type.serialize(Color.RED) == "RED" def test_update_schema_enum_errors(): @@ -273,20 +286,22 @@ def test_update_schema_enum_errors(): assert "Enum Corlo not found in schema!" in str(exc_info) - with pytest.raises(TypeError) as exc_info: - update_schema_enum(schema, "Color", 6) + with pytest.raises(TypeError) as exc_info2: + update_schema_enum(schema, "Color", 6) # type: ignore - assert "Invalid type for enum values: " in str(exc_info) + assert "Invalid type for enum values: " in str(exc_info2) - with pytest.raises(TypeError) as exc_info: + with pytest.raises(TypeError) as exc_info3: update_schema_enum(schema, "RootQueryType", Color) - assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str(exc_info) + assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str( + exc_info3 + ) - with pytest.raises(KeyError) as exc_info: + with pytest.raises(KeyError) as exc_info4: update_schema_enum(schema, "Color", {"RED": Color.RED}) - assert 'Enum key "GREEN" not found in provided values!' in str(exc_info) + assert 'Enum key "GREEN" not found in provided values!' 
in str(exc_info4) def test_parse_results_with_operation_type(): diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index cf4ca45d..39f1a1cb 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -441,9 +441,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: [ { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } for result in results ] @@ -453,9 +453,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: return web.json_response( { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } ) @@ -680,14 +680,14 @@ async def test_update_schema_scalars(aiohttp_server): def test_update_schema_scalars_invalid_scalar(): with pytest.raises(TypeError) as exc_info: - update_schema_scalars(schema, [int]) + update_schema_scalars(schema, [int]) # type: ignore exception = exc_info.value assert str(exception) == "Scalars should be instances of GraphQLScalarType." 
with pytest.raises(TypeError) as exc_info: - update_schema_scalar(schema, "test", int) + update_schema_scalar(schema, "test", int) # type: ignore exception = exc_info.value @@ -697,7 +697,7 @@ def test_update_schema_scalars_invalid_scalar(): def test_update_schema_scalars_invalid_scalar_argument(): with pytest.raises(TypeError) as exc_info: - update_schema_scalars(schema, MoneyScalar) + update_schema_scalars(schema, MoneyScalar) # type: ignore exception = exc_info.value @@ -787,7 +787,7 @@ def test_code(): def test_serialize_value_with_invalid_type(): with pytest.raises(GraphQLError) as exc_info: - serialize_value("Not a valid type", 50) + serialize_value("Not a valid type", 50) # type: ignore exception = exc_info.value diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index c0177a32..61e80fa0 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -12,10 +12,10 @@ def _fake_signer_factory(request=None): class FakeSigner: - def __init__(self, request=None) -> None: + def __init__(self, request=None): self.request = request - def add_auth(self, request) -> None: + def add_auth(self, request): """ A fake for getting a request object that :return: diff --git a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py index b31ade7f..e4653d48 100644 --- a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py +++ b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py @@ -1,3 +1,5 @@ +from graphql import GraphQLSchema + from gql import Client, gql from gql.dsl import DSLFragment, DSLQuery, DSLSchema, dsl_gql, print_ast from gql.utilities import node_tree @@ -34,6 +36,9 @@ def test_issue_447(): client = Client(schema=schema_str) + + assert isinstance(client.schema, GraphQLSchema) + ds = DSLSchema(client.schema) sprite = DSLFragment("SpriteUnionAsSprite") diff --git 
a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 59d7ddfa..1d179f60 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -148,9 +148,10 @@ def create_review(episode, review): async def make_starwars_backend(aiohttp_server): from aiohttp import web - from .schema import StarWarsSchema from graphql import graphql_sync + from .schema import StarWarsSchema + async def handler(request): data = await request.json() source = data["query"] diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 4b672ad3..8f1efe99 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -1,4 +1,5 @@ import asyncio +from typing import cast from graphql import ( GraphQLArgument, @@ -14,6 +15,7 @@ GraphQLObjectType, GraphQLSchema, GraphQLString, + IntrospectionQuery, get_introspection_query, graphql_sync, print_schema, @@ -271,6 +273,8 @@ async def resolve_review(review, _info, **_args): ) -StarWarsIntrospection = graphql_sync(StarWarsSchema, get_introspection_query()).data +StarWarsIntrospection = cast( + IntrospectionQuery, graphql_sync(StarWarsSchema, get_introspection_query()).data +) StarWarsTypeDef = print_schema(StarWarsSchema) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 5cd051ba..d96435fc 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -4,6 +4,7 @@ GraphQLError, GraphQLFloat, GraphQLID, + GraphQLInputObjectType, GraphQLInt, GraphQLList, GraphQLNonNull, @@ -53,6 +54,7 @@ def client(): def test_ast_from_value_with_input_type_and_not_mapping_value(): obj_type = StarWarsSchema.get_type("ReviewInput") + assert isinstance(obj_type, GraphQLInputObjectType) assert ast_from_value(8, obj_type) is None @@ -78,7 +80,7 @@ def test_ast_from_value_with_graphqlid(): def test_ast_from_value_with_invalid_type(): with pytest.raises(TypeError) as exc_info: - ast_from_value(4, None) + ast_from_value(4, None) # type: ignore assert "Unexpected input type: None." 
in str(exc_info.value) @@ -114,7 +116,10 @@ def test_ast_from_serialized_value_untyped_typeerror(): def test_variable_to_ast_type_passing_wrapping_type(): - wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("ReviewInput"))) + review_type = StarWarsSchema.get_type("ReviewInput") + assert isinstance(review_type, GraphQLInputObjectType) + + wrapping_type = GraphQLNonNull(GraphQLList(review_type)) variable = DSLVariable("review_input") ast = variable.to_ast_type(wrapping_type) assert ast == NonNullTypeNode( @@ -383,7 +388,7 @@ def test_fetch_luke_aliased(ds): assert query == str(query_dsl) -def test_fetch_name_aliased(ds: DSLSchema): +def test_fetch_name_aliased(ds: DSLSchema) -> None: query = """ human(id: "1000") { my_name: name @@ -394,7 +399,7 @@ def test_fetch_name_aliased(ds: DSLSchema): assert query == str(query_dsl) -def test_fetch_name_aliased_as_kwargs(ds: DSLSchema): +def test_fetch_name_aliased_as_kwargs(ds: DSLSchema) -> None: query = """ human(id: "1000") { my_name: name @@ -787,7 +792,7 @@ def test_dsl_query_all_fields_should_be_instances_of_DSLField(): TypeError, match="Fields should be instances of DSLSelectable. 
Received: ", ): - DSLQuery("I am a string") + DSLQuery("I am a string") # type: ignore def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): @@ -839,7 +844,7 @@ def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( TypeError, match="Operations should be instances of DSLExecutable " ): - dsl_gql("I am a string") + dsl_gql("I am a string") # type: ignore def test_DSLSchema_requires_a_schema(client): diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py index e8f3f8d4..8020b586 100644 --- a/tests/starwars/test_parse_results.py +++ b/tests/starwars/test_parse_results.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import pytest from graphql import GraphQLError @@ -87,7 +89,7 @@ def test_key_not_found_in_result(): # Backend returned an invalid result without the hero key # Should be impossible. In that case, we ignore the missing key - result = {} + result: Dict[str, Any] = {} parsed_result = parse_result(StarWarsSchema, query, result) diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index bf15e11a..7a2a8084 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -336,4 +336,4 @@ def test_query_from_source(client): def test_already_parsed_query(client): query = gql("{ hero { name } }") with pytest.raises(TypeError, match="must be passed as a string"): - gql(query) + gql(query) # type: ignore diff --git a/tests/starwars/test_validation.py b/tests/starwars/test_validation.py index 38676836..75ce4162 100644 --- a/tests/starwars/test_validation.py +++ b/tests/starwars/test_validation.py @@ -79,7 +79,7 @@ def introspection_schema_no_directives(): introspection = copy.deepcopy(StarWarsIntrospection) # Simulate no directives key - del introspection["__schema"]["directives"] + del introspection["__schema"]["directives"] # type: ignore return Client(introspection=introspection) @@ -108,7 +108,7 @@ def validation_errors(client, query): def 
test_incompatible_request_gql(client): with pytest.raises(TypeError): - gql(123) + gql(123) # type: ignore """ The error generated depends on graphql-core version @@ -253,7 +253,7 @@ def test_build_client_schema_invalid_introspection(): from gql.utilities import build_client_schema with pytest.raises(TypeError) as exc_info: - build_client_schema("blah") + build_client_schema("blah") # type: ignore assert ( "Invalid or incomplete introspection result. Ensure that you are passing the " diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index e843db6c..04417c4e 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -47,6 +47,7 @@ @pytest.mark.asyncio async def test_aiohttp_query(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -86,6 +87,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_ignore_backend_content_type(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -115,6 +117,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cookies(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -148,6 +151,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_401(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -179,6 +183,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_429(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -226,6 +231,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_code_500(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -261,6 +267,7 @@ 
async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_aiohttp_error_code(aiohttp_server, query_error): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -316,6 +323,7 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_aiohttp_invalid_protocol(aiohttp_server, param): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport response = param["response"] @@ -344,6 +352,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_subscribe_not_supported(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -369,6 +378,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_connect_twice(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -391,6 +401,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -413,6 +424,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_extra_args(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -460,6 +472,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_variable_values(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -496,6 +509,7 @@ async def test_aiohttp_query_variable_values_fix_issue_292(aiohttp_server): See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.aiohttp import 
AIOHTTPTransport async def handler(request): @@ -526,6 +540,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -552,6 +567,7 @@ def test_code(): @pytest.mark.asyncio async def test_aiohttp_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -638,6 +654,7 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -703,6 +720,7 @@ async def single_upload_handler_with_content_type(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_with_content_type(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -724,7 +742,7 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore params = {"file": f, "other_var": 42} @@ -741,6 +759,7 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): @pytest.mark.asyncio async def test_aiohttp_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -809,6 +828,7 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_aiohttp_binary_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -843,7 +863,8 @@ async def test_aiohttp_binary_file_upload(aiohttp_server): @pytest.mark.asyncio async def 
test_aiohttp_stream_reader_upload(aiohttp_server): - from aiohttp import web, ClientSession + from aiohttp import ClientSession, web + from gql.transport.aiohttp import AIOHTTPTransport async def binary_data_handler(request): @@ -882,6 +903,7 @@ async def binary_data_handler(request): async def test_aiohttp_async_generator_upload(aiohttp_server): import aiofiles from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() @@ -944,6 +966,7 @@ async def file_sender(file_name): @pytest.mark.asyncio async def test_aiohttp_file_upload_two_files(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1035,6 +1058,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1253,6 +1277,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_with_extensions(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1282,6 +1307,7 @@ async def handler(request): @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_aiohttp_query_https(ssl_aiohttp_server, ssl_close_timeout, verify_https): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1328,8 +1354,9 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_query_https_self_cert_fail(ssl_aiohttp_server): """By default, we should verify the ssl certificate""" - from aiohttp.client_exceptions import ClientConnectorCertificateError from aiohttp import web + from aiohttp.client_exceptions import ClientConnectorCertificateError + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1361,6 +1388,7 @@ async def 
handler(request): @pytest.mark.asyncio async def test_aiohttp_query_https_self_cert_default(ssl_aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1382,6 +1410,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_error_fetching_schema(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport error_answer = """ @@ -1425,6 +1454,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_reconnecting_session(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1463,6 +1493,7 @@ async def handler(request): @pytest.mark.parametrize("retries", [False, lambda e: e]) async def test_aiohttp_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1496,6 +1527,7 @@ async def test_aiohttp_reconnecting_session_start_connecting_task_twice( aiohttp_server, caplog ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1529,6 +1561,7 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_json_serializer(aiohttp_server, caplog): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1584,9 +1617,11 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_json_deserializer(aiohttp_server): - from aiohttp import web from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1623,7 +1658,8 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_connector_owner_false(aiohttp_server): - from aiohttp import web, TCPConnector + from aiohttp import TCPConnector, web + from 
gql.transport.aiohttp import AIOHTTPTransport async def handler(request): diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 801af6b9..86c502a9 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -179,7 +179,7 @@ async def monkey_patch_send_query( document, variable_values=None, operation_name=None, - ) -> int: + ): query_id = self.next_query_id self.next_query_id += 1 diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 8863ead9..e8832217 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,6 +8,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -763,6 +764,7 @@ def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -818,8 +820,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( graphqlws_server, subscription_str, execute_instead_of_subscribe ): - from gql.transport.exceptions import TransportClosed from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport + from gql.transport.exceptions import TransportClosed path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index deb425f7..cf91d148 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -1,7 +1,7 @@ import asyncio 
import json import sys -from typing import Dict, Mapping +from typing import Any, Dict, Mapping import pytest @@ -66,7 +66,8 @@ async def test_aiohttp_websocket_starting_client_in_context_manager(aiohttp_ws_s ) assert transport.response_headers == {} - assert transport.headers["test"] == "1234" + assert isinstance(transport.headers, Mapping) + assert transport.headers["test"] == "1234" # type: ignore async with Client(transport=transport) as session: @@ -154,6 +155,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( ): from aiohttp.client_exceptions import ClientConnectorCertificateError + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport server = ws_ssl_server @@ -161,7 +163,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["ssl"] = True @@ -645,7 +647,6 @@ async def test_aiohttp_websocket_non_regression_bug_108( async def test_aiohttp_websocket_using_cli( aiohttp_ws_server, transport_arg, monkeypatch, capsys ): - """ Note: depending on the transport_arg parameter, if there is no transport argument, then we will use WebsocketsTransport if the websockets dependency is installed, diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 5beb023e..83ae3589 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,6 +9,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -228,6 +229,7 @@ async def test_aiohttp_websocket_subscription_get_execution_result( async for result in session.subscribe(subscription, get_execution_result=True): 
assert isinstance(result, ExecutionResult) + assert result.data is not None number = result.data["number"] print(f"Number received: {number}") @@ -669,6 +671,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) interrupt_task = asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -678,6 +681,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( assert count == 4 # Catch interrupt_task exception to remove warning + assert interrupt_task is not None interrupt_task.exception() # Check that the server received a connection_terminate message last diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index cb279ae5..8abb3410 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -23,6 +23,7 @@ def test_appsync_init_with_minimal_args(fake_session_factory): @pytest.mark.botocore def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): @@ -72,6 +73,7 @@ def test_appsync_init_with_apikey_auth(): @pytest.mark.botocore def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -108,10 +110,13 @@ def test_appsync_init_with_iam_auth_and_no_region( - you have the AWS_DEFAULT_REGION environment variable set """ - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.exceptions import NoRegionError import logging + from botocore.exceptions import NoRegionError + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import 
AppSyncWebsocketsTransport + caplog.set_level(logging.WARNING) with pytest.raises(NoRegionError): @@ -120,6 +125,8 @@ def test_appsync_init_with_iam_auth_and_no_region( session._credentials.region = None transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=session) + assert isinstance(transport.auth, AppSyncIAMAuthentication) + # prints the region name in case the test fails print(f"Region found: {transport.auth._region_name}") diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index 2a6c9ca7..536b2fe9 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -9,10 +9,12 @@ @pytest.mark.aiohttp @pytest.mark.botocore async def test_appsync_iam_mutation(aiohttp_server, fake_credentials_factory): + from urllib.parse import urlparse + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncIAMAuthentication - from urllib.parse import urlparse async def handler(request): data = { diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 7aa96292..37cbe460 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -426,9 +426,10 @@ async def test_appsync_subscription_api_key(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_with_token(server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -453,9 +454,10 @@ async def test_appsync_subscription_iam_with_token(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_without_token(server): + from botocore.credentials import Credentials + from 
gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -479,9 +481,10 @@ async def test_appsync_subscription_iam_without_token(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_execute_method_not_allowed(server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -526,9 +529,10 @@ async def test_appsync_execute_method_not_allowed(server): @pytest.mark.botocore async def test_appsync_fetch_schema_from_transport_not_allowed(): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials dummy_credentials = Credentials( access_key=DUMMY_ACCESS_KEY_ID, @@ -579,10 +583,11 @@ async def test_appsync_subscription_api_key_unauthorized(server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_not_allowed(server): + from botocore.credentials import Credentials + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportQueryError - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" diff --git a/tests/test_cli.py b/tests/test_cli.py index dccfcb5a..4c6b7d15 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -286,8 +286,8 @@ async def 
test_cli_main_appsync_websockets_iam(parser, url): ) def test_cli_get_transport_appsync_websockets_api_key(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--api-key", "test-api-key"] @@ -307,8 +307,8 @@ def test_cli_get_transport_appsync_websockets_api_key(parser, url): ) def test_cli_get_transport_appsync_websockets_jwt(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncJWTAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--jwt", "test-jwt"] diff --git a/tests/test_client.py b/tests/test_client.py index e5edec8b..8669b4a3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,10 @@ import os from contextlib import suppress +from typing import Any from unittest import mock import pytest -from graphql import build_ast_schema, parse +from graphql import DocumentNode, ExecutionResult, build_ast_schema, parse from gql import Client, GraphQLRequest, gql from gql.transport import Transport @@ -29,19 +30,27 @@ def http_transport_query(): def test_request_transport_not_implemented(http_transport_query): class RandomTransport(Transport): - def execute(self): - super().execute(http_transport_query) + pass - with pytest.raises(NotImplementedError) as exc_info: - RandomTransport().execute() + with pytest.raises(TypeError) as exc_info: + RandomTransport() # type: ignore - assert "Any Transport subclass must implement execute method" == str(exc_info.value) + assert "Can't instantiate abstract class RandomTransport" in str(exc_info.value) - with pytest.raises(NotImplementedError) as exc_info: - 
RandomTransport().execute_batch([]) + class RandomTransport2(Transport): + def execute( + self, + document: DocumentNode, + *args: Any, + **kwargs: Any, + ) -> ExecutionResult: + return ExecutionResult() + + with pytest.raises(NotImplementedError) as exc_info2: + RandomTransport2().execute_batch([]) assert "This Transport has not implemented the execute_batch method" == str( - exc_info.value + exc_info2.value ) @@ -70,7 +79,7 @@ def test_retries_on_transport(execute_mock): expected_retries = 3 execute_mock.side_effect = NewConnectionError( - "Should be HTTPConnection", "Fake connection error" + "Should be HTTPConnection", "Fake connection error" # type: ignore ) transport = RequestsHTTPTransport( url="http://127.0.0.1:8000/graphql", @@ -109,11 +118,10 @@ def test_retries_on_transport(execute_mock): assert execute_mock.call_count == expected_retries + 1 -def test_no_schema_exception(): +def test_no_schema_no_transport_exception(): with pytest.raises(AssertionError) as exc_info: - client = Client() - client.validate("") - assert "Cannot validate the document locally, you need to pass a schema." in str( + Client() + assert "You need to provide either a transport or a schema to the Client." 
in str( exc_info.value ) @@ -255,6 +263,7 @@ def test_sync_transport_close_on_schema_retrieval_failure(): # transport is closed afterwards pass + assert isinstance(client.transport, RequestsHTTPTransport) assert client.transport.session is None @@ -279,6 +288,7 @@ async def test_async_transport_close_on_schema_retrieval_failure(): # transport is closed afterwards pass + assert isinstance(client.transport, AIOHTTPTransport) assert client.transport.session is None import asyncio diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 2735fbb0..94028d26 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,6 +8,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -757,6 +758,7 @@ def test_graphqlws_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -812,8 +814,8 @@ async def test_graphqlws_subscription_reconnecting_session( graphqlws_server, subscription_str, execute_instead_of_subscribe ): - from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import TransportClosed + from gql.transport.websockets import WebsocketsTransport path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" diff --git a/tests/test_httpx.py b/tests/test_httpx.py index c15872d7..43d74ec6 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -1,4 +1,4 @@ -from typing import Mapping +from typing import Any, Dict, Mapping import pytest @@ -38,6 +38,7 @@ @pytest.mark.asyncio async def test_httpx_query(aiohttp_server, run_sync_test): from aiohttp import web + 
from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -82,6 +83,7 @@ def test_code(): @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_httpx_query_https(ssl_aiohttp_server, run_sync_test, verify_https): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -144,6 +146,7 @@ async def test_httpx_query_https_self_cert_fail( """By default, we should verify the ssl certificate""" from aiohttp import web from httpx import ConnectError + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -162,7 +165,7 @@ async def handler(request): assert str(url).startswith("https://") def test_code(): - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -191,6 +194,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -228,6 +232,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -263,6 +268,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -312,6 +318,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -344,6 +351,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def 
handler(request): @@ -382,6 +390,7 @@ def test_code(): @pytest.mark.parametrize("response", invalid_protocol_responses) async def test_httpx_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -410,6 +419,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -436,6 +446,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -473,6 +484,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -528,6 +540,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -588,6 +601,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_with_content_type(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def single_upload_handler(request): @@ -638,7 +652,7 @@ def test_code(): with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore params = {"file": f, "other_var": 42} execution_result = session._execute( @@ -654,6 +668,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def 
single_upload_handler(request): @@ -716,6 +731,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_binary_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport # This is a sample binary file content containing all possible byte values @@ -789,6 +805,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_2 = """ @@ -887,6 +904,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_3 = """ @@ -976,6 +994,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport error_answer = """ diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 44764ea4..49ea6a24 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1,6 +1,6 @@ import io import json -from typing import Mapping +from typing import Any, Dict, Mapping import pytest @@ -48,6 +48,7 @@ @pytest.mark.asyncio async def test_httpx_query(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -88,6 +89,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_ignore_backend_content_type(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -118,6 +120,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cookies(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -152,6 +155,7 @@ async def handler(request): 
@pytest.mark.asyncio async def test_httpx_error_code_401(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -184,6 +188,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_429(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -232,6 +237,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_code_500(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -268,6 +274,7 @@ async def handler(request): @pytest.mark.parametrize("query_error", transport_query_error_responses) async def test_httpx_error_code(aiohttp_server, query_error): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -324,6 +331,7 @@ async def handler(request): @pytest.mark.parametrize("param", invalid_protocol_responses) async def test_httpx_invalid_protocol(aiohttp_server, param): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport response = param["response"] @@ -353,6 +361,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_subscribe_not_supported(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -379,6 +388,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_connect_twice(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -402,6 +412,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -424,9 +435,10 @@ async def handler(request): @pytest.mark.aiohttp 
@pytest.mark.asyncio async def test_httpx_extra_args(aiohttp_server): + import httpx from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport - import httpx async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -438,8 +450,8 @@ async def handler(request): url = str(server.make_url("/")) # passing extra arguments to httpx.AsyncClient - transport = httpx.AsyncHTTPTransport(retries=2) - transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=transport) + inner_transport = httpx.AsyncHTTPTransport(retries=2) + transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=inner_transport) async with Client(transport=transport) as session: @@ -470,6 +482,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_variable_values(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -507,6 +520,7 @@ async def test_httpx_query_variable_values_fix_issue_292(aiohttp_server): See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -538,6 +552,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -565,6 +580,7 @@ def test_code(): @pytest.mark.asyncio async def test_httpx_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -652,6 +668,7 @@ async def single_upload_handler(request): @pytest.mark.asyncio async def test_httpx_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = 
web.Application() @@ -688,6 +705,7 @@ async def test_httpx_file_upload(aiohttp_server): @pytest.mark.asyncio async def test_httpx_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -757,6 +775,7 @@ async def binary_upload_handler(request): @pytest.mark.asyncio async def test_httpx_binary_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() @@ -815,6 +834,7 @@ async def test_httpx_binary_file_upload(aiohttp_server): @pytest.mark.asyncio async def test_httpx_file_upload_two_files(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -907,6 +927,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1130,6 +1151,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_query_with_extensions(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1159,6 +1181,7 @@ async def handler(request): @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_httpx_query_https(ssl_aiohttp_server, verify_https): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1202,9 +1225,10 @@ async def handler(request): @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_httpx_query_https_self_cert_fail(ssl_aiohttp_server, verify_https): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport from httpx import ConnectError + from gql.transport.httpx import HTTPXAsyncTransport + async def handler(request): return 
web.Response(text=query1_server_answer, content_type="application/json") @@ -1216,7 +1240,7 @@ async def handler(request): assert url.startswith("https://") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -1240,6 +1264,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_error_fetching_schema(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport error_answer = """ @@ -1284,6 +1309,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_reconnecting_session(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1323,6 +1349,7 @@ async def handler(request): @pytest.mark.parametrize("retries", [False, lambda e: e]) async def test_httpx_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1357,6 +1384,7 @@ async def test_httpx_reconnecting_session_start_connecting_task_twice( aiohttp_server, caplog ): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1391,6 +1419,7 @@ async def handler(request): @pytest.mark.asyncio async def test_httpx_json_serializer(aiohttp_server, caplog): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1447,9 +1476,11 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_json_deserializer(aiohttp_server): - from aiohttp import web from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 
2a312d71..09c129b3 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -19,9 +19,7 @@ def ensure_list(s): return ( s if s is None or isinstance(s, list) - else list(s) - if isinstance(s, tuple) - else [s] + else list(s) if isinstance(s, tuple) else [s] ) @@ -360,9 +358,10 @@ def subscription_server( data_answers=default_subscription_data_answer, unsubscribe_answers=default_subscription_unsubscribe_answer, ): - from .conftest import PhoenixChannelServerHelper import json + from .conftest import PhoenixChannelServerHelper + async def phoenix_server(ws): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index 621f648e..7dff7062 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -110,10 +110,11 @@ async def test_phoenix_channel_query_ssl(ws_ssl_server, query_str): async def test_phoenix_channel_query_ssl_self_cert_fail( ws_ssl_server, query_str, verify_https ): + from ssl import SSLCertVerificationError + from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) - from ssl import SSLCertVerificationError path = "/graphql" server = ws_ssl_server diff --git a/tests/test_requests.py b/tests/test_requests.py index 8f3b0b7a..9c0334bd 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,4 +1,4 @@ -from typing import Mapping +from typing import Any, Dict, Mapping import pytest @@ -42,6 +42,7 @@ @pytest.mark.asyncio async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -85,9 +86,11 @@ def test_code(): @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_requests_query_https(ssl_aiohttp_server, run_sync_test, verify_https): + import warnings + from 
aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - import warnings async def handler(request): return web.Response( @@ -151,9 +154,10 @@ async def test_requests_query_https_self_cert_fail( ): """By default, we should verify the ssl certificate""" from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport from requests.exceptions import SSLError + from gql.transport.requests import RequestsHTTPTransport + async def handler(request): return web.Response( text=query1_server_answer, @@ -168,7 +172,7 @@ async def handler(request): url = server.make_url("/") def test_code(): - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -197,6 +201,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -234,6 +239,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -269,6 +275,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -318,6 +325,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -350,6 +358,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -388,6 +397,7 @@ def test_code(): 
@pytest.mark.parametrize("response", invalid_protocol_responses) async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -416,6 +426,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_connect_twice(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -442,6 +453,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -479,6 +491,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -534,6 +547,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -594,6 +608,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_with_content_type(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def single_upload_handler(request): @@ -644,7 +659,7 @@ def test_code(): with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore params = {"file": f, "other_var": 42} execution_result = session._execute( @@ -660,6 +675,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import 
RequestsHTTPTransport async def single_upload_handler(request): @@ -722,6 +738,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_binary_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport # This is a sample binary file content containing all possible byte values @@ -795,6 +812,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_2 = """ @@ -893,6 +911,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_3 = """ @@ -982,6 +1001,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport error_answer = """ @@ -1029,7 +1049,9 @@ def test_code(): @pytest.mark.asyncio async def test_requests_json_serializer(aiohttp_server, run_sync_test, caplog): import json + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -1089,9 +1111,11 @@ def test_code(): @pytest.mark.asyncio async def test_requests_json_deserializer(aiohttp_server, run_sync_test): import json - from aiohttp import web from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index dbd3dfa5..4b9e09b8 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -50,6 +50,7 @@ @pytest.mark.asyncio async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp 
import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -93,6 +94,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_query_auto_batch_enabled(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -140,9 +142,11 @@ def test_code(): async def test_requests_query_auto_batch_enabled_two_requests( aiohttp_server, run_sync_test ): + from threading import Thread + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - from threading import Thread async def handler(request): return web.Response( @@ -199,6 +203,7 @@ def test_thread(): @pytest.mark.asyncio async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -238,6 +243,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -275,6 +281,7 @@ async def test_requests_error_code_401_auto_batch_enabled( aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -313,6 +320,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -362,6 +370,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -394,6 +403,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import 
web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -437,6 +447,7 @@ def test_code(): @pytest.mark.parametrize("response", invalid_protocol_responses) async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -465,6 +476,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -504,6 +516,7 @@ def test_code(): @pytest.mark.asyncio async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -543,6 +556,7 @@ def test_code(): def test_requests_sync_batch_auto(): from threading import Thread + from gql.transport.requests import RequestsHTTPTransport client = Client( diff --git a/tests/test_transport.py b/tests/test_transport.py index d9a3eced..e554955a 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -28,6 +28,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index a9b21e6a..7c108ec3 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -28,6 +28,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 9c43965f..08058aea 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -175,7 +175,7 @@ async def 
monkey_patch_send_query( document, variable_values=None, operation_name=None, - ) -> int: + ): query_id = self.next_query_id self.next_query_id += 1 @@ -366,9 +366,10 @@ async def test_websocket_using_cli_invalid_query(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from gql.cli import main, get_parser import io + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_online.py b/tests/test_websocket_online.py index fa288b6d..c53be5f4 100644 --- a/tests/test_websocket_online.py +++ b/tests/test_websocket_online.py @@ -27,12 +27,10 @@ async def test_websocket_simple_query(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( - url="wss://countries.trevorblades.com/graphql" - ) + transport = WebsocketsTransport(url="wss://countries.trevorblades.com/graphql") # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -68,12 +66,12 @@ async def test_websocket_invalid_query(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -98,12 +96,12 @@ async def test_websocket_sending_invalid_data(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: 
query = gql( """ @@ -122,7 +120,8 @@ async def test_websocket_sending_invalid_data(): invalid_data = "QSDF" print(f">>> {invalid_data}") - await sample_transport.websocket.send(invalid_data) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_data) await asyncio.sleep(2) @@ -134,17 +133,18 @@ async def test_websocket_sending_invalid_payload(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport): + async with Client(transport=transport): invalid_payload = '{"id": "1", "type": "start", "payload": "BLAHBLAH"}' print(f">>> {invalid_payload}") - await sample_transport.websocket.send(invalid_payload) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_payload) await asyncio.sleep(2) @@ -156,12 +156,12 @@ async def test_websocket_sending_invalid_data_while_other_query_is_running(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -190,7 +190,8 @@ async def query_task2(): invalid_data = "QSDF" print(f">>> {invalid_data}") - await sample_transport.websocket.send(invalid_data) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_data) task1 = asyncio.create_task(query_task1()) task2 = asyncio.create_task(query_task2()) @@ -207,12 +208,12 @@ async def test_websocket_two_queries_in_parallel_using_two_tasks(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport 
- sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 919f6bdb..99ff7334 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,7 +1,7 @@ import asyncio import json import sys -from typing import Dict, Mapping +from typing import Any, Dict, Mapping import pytest @@ -60,6 +60,7 @@ async def test_websocket_starting_client_in_context_manager(server): transport = WebsocketsTransport(url=url, headers={"test": "1234"}) assert transport.response_headers == {} + assert isinstance(transport.headers, Mapping) assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: @@ -93,6 +94,7 @@ async def test_websocket_starting_client_in_context_manager(server): @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) async def test_websocket_using_ssl_connection(ws_ssl_server): import websockets + from gql.transport.websockets import WebsocketsTransport server = ws_ssl_server @@ -138,15 +140,16 @@ async def test_websocket_using_ssl_connection(ws_ssl_server): async def test_websocket_using_ssl_connection_self_cert_fail( ws_ssl_server, verify_https ): - from gql.transport.websockets import WebsocketsTransport from ssl import SSLCertVerificationError + from gql.transport.websockets import WebsocketsTransport + server = ws_ssl_server url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["ssl"] = True @@ -585,10 +588,11 @@ async def test_websocket_using_cli(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from 
gql.cli import main, get_parser import io import json + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index a020e1f5..89acd635 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,6 +9,7 @@ from parse import search from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, PyPy, WebSocketServerHelper @@ -160,6 +161,7 @@ async def test_websocket_subscription_get_execution_result( assert isinstance(result, ExecutionResult) + assert result.data is not None number = result.data["number"] print(f"Number received: {number}") @@ -600,6 +602,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) interrupt_task = asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -609,6 +612,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) assert count == 4 # Catch interrupt_task exception to remove warning + assert interrupt_task is not None interrupt_task.exception() # Check that the server received a connection_terminate message last diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py index f070f497..f0448c79 100644 --- a/tests/test_websockets_adapter.py +++ b/tests/test_websockets_adapter.py @@ -1,4 +1,5 @@ import json +from typing import Mapping import pytest from graphql import print_ast @@ -73,11 +74,12 @@ async def test_websockets_adapter_edge_cases(server): query = print_ast(gql(query1_str)) print("query=", query) - adapter = WebSocketsAdapter(url, headers={"a": 1}, ssl=False, connect_args={}) + adapter = 
WebSocketsAdapter(url, headers={"a": "r1"}, ssl=False, connect_args={}) await adapter.connect() - assert adapter.headers["a"] == 1 + assert isinstance(adapter.headers, Mapping) + assert adapter.headers["a"] == "r1" assert adapter.ssl is False assert adapter.connect_args == {} assert adapter.response_headers["dummy"] == "test1234" diff --git a/tox.ini b/tox.ini index 8796357b..f6d4b48e 100644 --- a/tox.ini +++ b/tox.ini @@ -47,7 +47,7 @@ commands = basepython = python deps = -e.[dev] commands = - isort --recursive --check-only --diff gql tests + isort --check-only --diff gql tests [testenv:mypy] basepython = python From 2e2ba3f8358323ad12322dc32169658a2815e702 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 12 Mar 2025 17:20:37 +0100 Subject: [PATCH 206/239] Chore bump aiohttp to 3.11.2 (#541) --- gql/transport/common/adapters/aiohttp.py | 10 +++++++--- setup.py | 3 +-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py index 736f2a3e..d5b16a82 100644 --- a/gql/transport/common/adapters/aiohttp.py +++ b/gql/transport/common/adapters/aiohttp.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Literal, Mapping, Optional, Union import aiohttp -from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp import BasicAuth, ClientWSTimeout, Fingerprint, WSMsgType from aiohttp.typedefs import LooseHeaders, StrOrURL from multidict import CIMultiDictProxy @@ -132,6 +132,11 @@ async def connect(self) -> None: self.session = aiohttp.ClientSession(**client_session_args) + ws_timeout = ClientWSTimeout( + ws_receive=self.receive_timeout, + ws_close=self.websocket_close_timeout, + ) + connect_args: Dict[str, Any] = { "url": self.url, "headers": self.headers, @@ -142,8 +147,7 @@ async def connect(self) -> None: "proxy": self.proxy, "proxy_auth": self.proxy_auth, "proxy_headers": self.proxy_headers, - "timeout": self.websocket_close_timeout, - "receive_timeout": 
self.receive_timeout, + "timeout": ws_timeout, } if self.subprotocols: diff --git a/setup.py b/setup.py index f000136c..a36284b0 100644 --- a/setup.py +++ b/setup.py @@ -38,8 +38,7 @@ ] + tests_requires install_aiohttp_requires = [ - "aiohttp>=3.8.0,<4;python_version<='3.11'", - "aiohttp>=3.9.0b0,<4;python_version>'3.11'", + "aiohttp>=3.11.2,<4", ] install_requests_requires = [ From fbe03c4d49009614ae0d232ca2aa320e486f9d90 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 12 Mar 2025 17:42:27 +0100 Subject: [PATCH 207/239] Fix httpx test deprecated warning (#542) --- tests/test_httpx.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 43d74ec6..d129f022 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -11,7 +11,11 @@ TransportServerError, ) -from .conftest import TemporaryFile, get_localhost_ssl_context, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + strip_braces_spaces, +) # Marking all tests in this file with the httpx marker pytestmark = pytest.mark.httpx @@ -105,9 +109,9 @@ def test_code(): extra_args = {} if verify_https == "cert_provided": - cert, _ = get_localhost_ssl_context() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["verify"] = cert.decode() + extra_args["verify"] = ssl_context elif verify_https == "disabled": extra_args["verify"] = False From 886a6b818671e9a62900f2fdfe578cab7ea78f75 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 15 Mar 2025 00:16:21 +0100 Subject: [PATCH 208/239] Upgrade lastest websockets and Exceptions overhaul (#543) --- gql/client.py | 14 +-- gql/transport/common/adapters/aiohttp.py | 11 ++- gql/transport/common/adapters/websockets.py | 22 +++-- gql/transport/common/base.py | 35 +++++--- setup.py | 2 +- tests/conftest.py | 39 +++------ tests/test_aiohttp_online.py | 14 ++- tests/test_aiohttp_websocket_exceptions.py | 29 ++++--- 
..._aiohttp_websocket_graphqlws_exceptions.py | 9 +- ...iohttp_websocket_graphqlws_subscription.py | 87 ++++++++++--------- tests/test_aiohttp_websocket_query.py | 5 +- tests/test_appsync_auth.py | 34 ++++---- tests/test_appsync_http.py | 4 +- tests/test_appsync_websockets.py | 2 +- tests/test_async_client_validation.py | 16 ++-- tests/test_graphqlws_exceptions.py | 39 +++++++-- tests/test_graphqlws_subscription.py | 87 ++++++++++--------- tests/test_http_async_sync.py | 18 ++-- tests/test_httpx_online.py | 14 ++- tests/test_phoenix_channel_exceptions.py | 32 +++---- tests/test_phoenix_channel_subscription.py | 12 +-- tests/test_websocket_exceptions.py | 47 +++++++--- tests/test_websocket_query.py | 13 ++- tests/test_websocket_subscription.py | 30 +++---- 24 files changed, 336 insertions(+), 279 deletions(-) diff --git a/gql/client.py b/gql/client.py index faf3230a..99cd6e46 100644 --- a/gql/client.py +++ b/gql/client.py @@ -35,7 +35,7 @@ from .graphql_request import GraphQLRequest from .transport.async_transport import AsyncTransport -from .transport.exceptions import TransportClosed, TransportQueryError +from .transport.exceptions import TransportConnectionFailed, TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport from .utilities import build_client_schema, get_introspection_query_ast @@ -1730,6 +1730,7 @@ async def _connection_loop(self): # Then wait for the reconnect event self._reconnect_request_event.clear() await self._reconnect_request_event.wait() + await self.transport.close() async def start_connecting_task(self): """Start the task responsible to restart the connection @@ -1758,7 +1759,7 @@ async def _execute_once( **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent method _execute but requesting a - reconnection if we receive a TransportClosed exception. + reconnection if we receive a TransportConnectionFailed exception. 
""" try: @@ -1770,7 +1771,7 @@ async def _execute_once( parse_result=parse_result, **kwargs, ) - except TransportClosed: + except TransportConnectionFailed: self._reconnect_request_event.set() raise @@ -1786,7 +1787,8 @@ async def _execute( **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent, but with optional retries - and requesting a reconnection if we receive a TransportClosed exception. + and requesting a reconnection if we receive a + TransportConnectionFailed exception. """ return await self._execute_with_retries( @@ -1808,7 +1810,7 @@ async def _subscribe( **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Same Async generator as parent method _subscribe but requesting a - reconnection if we receive a TransportClosed exception. + reconnection if we receive a TransportConnectionFailed exception. """ inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe( @@ -1824,7 +1826,7 @@ async def _subscribe( async for result in inner_generator: yield result - except TransportClosed: + except TransportConnectionFailed: self._reconnect_request_event.set() raise diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py index d5b16a82..d2e1a346 100644 --- a/gql/transport/common/adapters/aiohttp.py +++ b/gql/transport/common/adapters/aiohttp.py @@ -178,12 +178,14 @@ async def send(self, message: str) -> None: TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionFailed("Connection is already closed") + raise TransportConnectionFailed("WebSocket connection is already closed") try: await self.websocket.send_str(message) - except ConnectionResetError as e: - raise TransportConnectionFailed("Connection was closed") from e + except Exception as e: + raise TransportConnectionFailed( + f"Error trying to send data: {type(e).__name__}" + ) from e async def receive(self) -> str: """Receive message from the WebSocket server. 
@@ -200,6 +202,9 @@ async def receive(self) -> str: raise TransportConnectionFailed("Connection is already closed") while True: + # Should not raise any exception: + # https://docs.aiohttp.org/en/stable/_modules/aiohttp/client_ws.html + # #ClientWebSocketResponse.receive ws_message = await self.websocket.receive() # Ignore low-level ping and pong received diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index c2524fb4..6d248e71 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Union import websockets -from websockets.client import WebSocketClientProtocol +from websockets import ClientConnection from websockets.datastructures import Headers, HeadersLike from ...exceptions import TransportConnectionFailed, TransportProtocolError @@ -40,7 +40,7 @@ def __init__( self._headers: Optional[HeadersLike] = headers self.ssl = ssl - self.websocket: Optional[WebSocketClientProtocol] = None + self.websocket: Optional[ClientConnection] = None self._response_headers: Optional[Headers] = None async def connect(self) -> None: @@ -57,7 +57,7 @@ async def connect(self) -> None: # Set default arguments used in the websockets.connect call connect_args: Dict[str, Any] = { "ssl": ssl, - "extra_headers": self.headers, + "additional_headers": self.headers, } if self.subprotocols: @@ -68,11 +68,13 @@ async def connect(self) -> None: # Connection to the specified url try: - self.websocket = await websockets.client.connect(self.url, **connect_args) + self.websocket = await websockets.connect(self.url, **connect_args) except Exception as e: raise TransportConnectionFailed("Connect failed") from e - self._response_headers = self.websocket.response_headers + assert self.websocket.response is not None + + self._response_headers = self.websocket.response.headers async def send(self, message: 
str) -> None: """Send message to the WebSocket server. @@ -84,12 +86,14 @@ async def send(self, message: str) -> None: TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionFailed("Connection is already closed") + raise TransportConnectionFailed("WebSocket connection is already closed") try: await self.websocket.send(message) except Exception as e: - raise TransportConnectionFailed("Connection was closed") from e + raise TransportConnectionFailed( + f"Error trying to send data: {type(e).__name__}" + ) from e async def receive(self) -> str: """Receive message from the WebSocket server. @@ -109,7 +113,9 @@ async def receive(self) -> str: try: data = await self.websocket.recv() except Exception as e: - raise TransportConnectionFailed("Connection was closed") from e + raise TransportConnectionFailed( + f"Error trying to receive data: {type(e).__name__}" + ) from e # websocket.recv() can return either str or bytes # In our case, we should receive only str here diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index a3d025c0..cae8f488 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -127,11 +127,13 @@ async def _send(self, message: str) -> None: """Send the provided message to the adapter connection and log the message""" if not self._connected: - raise TransportClosed( - "Transport is not connected" - ) from self.close_exception + if isinstance(self.close_exception, TransportConnectionFailed): + raise self.close_exception + else: + raise TransportConnectionFailed() from self.close_exception try: + # Can raise TransportConnectionFailed await self.adapter.send(message) log.info(">>> %s", message) except TransportConnectionFailed as e: @@ -143,7 +145,7 @@ async def _receive(self) -> str: # It is possible that the connection has been already closed in another task if not self._connected: - raise TransportClosed("Transport is already closed") + raise 
TransportConnectionFailed() from self.close_exception # Wait for the next frame. # Can raise TransportConnectionFailed or TransportProtocolError @@ -214,8 +216,6 @@ async def _receive_data_loop(self) -> None: except (TransportConnectionFailed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break - except TransportClosed: - break # Parse the answer try: @@ -482,6 +482,10 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # We should always have an active websocket connection here assert self._connected + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e + # Properly shut down liveness checker if enabled if self.check_keep_alive_task is not None: # More info: https://stackoverflow.com/a/43810272/1113207 @@ -492,10 +496,6 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # Calling the subclass close hook await self._close_hook() - # Saving exception to raise it later if trying to use the transport - # after it has already closed. 
- self.close_exception = e - if clean_close: log.debug("_close_coro: starting clean_close") try: @@ -503,7 +503,10 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: except Exception as exc: # pragma: no cover log.warning("Ignoring exception in _clean_close: " + repr(exc)) - log.debug("_close_coro: sending exception to listeners") + if log.isEnabledFor(logging.DEBUG): + log.debug( + f"_close_coro: sending exception to {len(self.listeners)} listeners" + ) # Send an exception to all remaining listeners for query_id, listener in self.listeners.items(): @@ -530,7 +533,15 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: exiting") async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) + if log.isEnabledFor(logging.DEBUG): + import inspect + + current_frame = inspect.currentframe() + assert current_frame is not None + caller_frame = current_frame.f_back + assert caller_frame is not None + caller_name = inspect.getframeinfo(caller_frame).function + log.debug(f"_fail from {caller_name}: " + repr(e)) if self.close_task is None: diff --git a/setup.py b/setup.py index a36284b0..aed15440 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ ] install_websockets_requires = [ - "websockets>=10.1,<14", + "websockets>=14.2,<16", ] install_botocore_requires = [ diff --git a/tests/conftest.py b/tests/conftest.py index 70a050d5..c69551b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -197,7 +197,7 @@ def __init__(self, with_ssl: bool = False): async def start(self, handler, extra_serve_args=None): - import websockets.server + import websockets print("Starting server") @@ -209,16 +209,21 @@ async def start(self, handler, extra_serve_args=None): extra_serve_args["ssl"] = ssl_context # Adding dummy response headers - extra_serve_args["extra_headers"] = {"dummy": "test1234"} + extra_headers = {"dummy": "test1234"} + + def 
process_response(connection, request, response): + response.headers.update(extra_headers) + return response # Start a server with a random open port - self.start_server = websockets.server.serve( - handler, "127.0.0.1", 0, **extra_serve_args + self.server = await websockets.serve( + handler, + "127.0.0.1", + 0, + process_response=process_response, + **extra_serve_args, ) - # Wait that the server is started - self.server = await self.start_server - # Get hostname and port hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore assert hostname == "127.0.0.1" @@ -603,24 +608,6 @@ async def graphqlws_server(request): subprotocol = "graphql-transport-ws" - from websockets.server import WebSocketServerProtocol - - class CustomSubprotocol(WebSocketServerProtocol): - def select_subprotocol(self, client_subprotocols, server_subprotocols): - print(f"Client subprotocols: {client_subprotocols!r}") - print(f"Server subprotocols: {server_subprotocols!r}") - - return subprotocol - - def process_subprotocol(self, headers, available_subprotocols): - # Overwriting available subprotocols - available_subprotocols = [subprotocol] - - print(f"headers: {headers!r}") - # print (f"Available subprotocols: {available_subprotocols!r}") - - return super().process_subprotocol(headers, available_subprotocols) - server_handler = get_server_handler(request) try: @@ -628,7 +615,7 @@ def process_subprotocol(self, headers, available_subprotocols): # Starting the server with the fixture param as the handler function await test_server.start( - server_handler, extra_serve_args={"create_protocol": CustomSubprotocol} + server_handler, extra_serve_args={"subprotocols": [subprotocol]} ) yield test_server diff --git a/tests/test_aiohttp_online.py b/tests/test_aiohttp_online.py index 7cacd921..a4f2480c 100644 --- a/tests/test_aiohttp_online.py +++ b/tests/test_aiohttp_online.py @@ -19,10 +19,10 @@ async def test_aiohttp_simple_query(): url = 
"https://countries.trevorblades.com/graphql" # Get transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -60,11 +60,9 @@ async def test_aiohttp_invalid_query(): from gql.transport.aiohttp import AIOHTTPTransport - sample_transport = AIOHTTPTransport( - url="https://countries.trevorblades.com/graphql" - ) + transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -89,12 +87,12 @@ async def test_aiohttp_two_queries_in_parallel_using_two_tasks(): from gql.transport.aiohttp import AIOHTTPTransport - sample_transport = AIOHTTPTransport( + transport = AIOHTTPTransport( url="https://countries.trevorblades.com/graphql", ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 86c502a9..2fb6722c 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -118,10 +118,10 @@ async def test_aiohttp_websocket_server_does_not_send_ack(server, query_str): url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -261,10 +261,10 @@ async def 
test_aiohttp_websocket_server_does_not_ack(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -281,10 +281,10 @@ async def test_aiohttp_websocket_server_closing_directly(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) with pytest.raises(TransportConnectionFailed): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -301,6 +301,15 @@ async def test_aiohttp_websocket_server_closing_after_ack(aiohttp_client_and_ser query = gql("query { hello }") + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed): + await session.execute(query) + + await session.transport.wait_closed() + + print("\n Trying to execute second query.\n") + with pytest.raises(TransportConnectionFailed): await session.execute(query) @@ -323,10 +332,10 @@ async def test_aiohttp_websocket_server_sending_invalid_query_errors(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) # Invalid server message is ignored - async with Client(transport=sample_transport): + async with Client(transport=transport): await asyncio.sleep(2 * MS) @@ -342,9 +351,9 @@ async def test_aiohttp_websocket_non_regression_bug_105(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = 
Client(transport=transport) # Create a coroutine which start the connection with the transport but does nothing async def client_connect(client): diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index a7548cce..52bc27a4 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -5,7 +5,6 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, TransportConnectionFailed, TransportProtocolError, TransportQueryError, @@ -117,7 +116,7 @@ async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): async with Client(transport=transport): @@ -264,10 +263,14 @@ async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( query = gql("query { hello }") + print("\n Trying to execute first query.\n") + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index e8832217..7c000d01 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -11,7 +11,7 @@ from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, PyPy, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper # Marking all tests in this file with the aiohttp 
AND websockets marker pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] @@ -821,7 +821,6 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport - from gql.transport.exceptions import TransportClosed path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" @@ -839,56 +838,62 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( reconnecting=True, retry_connect=False, retry_execute=False ) - # First we make a subscription which will cause a disconnect in the backend - # (count=8) - try: - print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") - async for result in session.subscribe(subscription_with_disconnect): - pass - except TransportConnectionFailed: - pass - - await asyncio.sleep(50 * MS) - - # Then with the same session handle, we make a subscription or an execute - # which will detect that the transport is closed so that the client could - # try to reconnect - generator = None + # First we make a query or subscription which will cause a disconnect + # in the backend (count=8) try: if execute_instead_of_subscribe: - print("\nEXECUTION_2\n") - await session.execute(subscription) + print("\nEXECUTION_1\n") + await session.execute(subscription_with_disconnect) else: - print("\nSUBSCRIPTION_2\n") - generator = session.subscribe(subscription) - async for result in generator: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): pass - except (TransportClosed, TransportConnectionFailed): - if generator: - await generator.aclose() + except TransportConnectionFailed: pass - timeout = 50 + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - if PyPy: - timeout = 500 + # Wait for reconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if 
transport._connected: + print(f"\nConnected again in {i+1} MS") + break - await asyncio.sleep(timeout * MS) + assert transport._connected is True + + # Then after the reconnection, we make a query or a subscription + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + result = await session.execute(subscription) + assert result["number"] == 10 + else: + print("\nSUBSCRIPTION_2\n") + generator = session.subscribe(subscription) + async for result in generator: + number = result["number"] + print(f"Number received: {number}") - # And finally with the same session handle, we make a subscription - # which works correctly - print("\nSUBSCRIPTION_3\n") - generator = session.subscribe(subscription) - async for result in generator: + assert number == count + count -= 1 - number = result["number"] - print(f"Number received: {number}") + await generator.aclose() - assert number == count - count -= 1 + assert count == -1 - await generator.aclose() + # Close the reconnecting session + await client.close_async() - assert count == -1 + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - await client.close_async() + assert transport._connected is False diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index cf91d148..a3087d78 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -8,7 +8,6 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, TransportConnectionFailed, TransportQueryError, TransportServerError, @@ -319,11 +318,11 @@ async def test_aiohttp_websocket_server_closing_after_first_query( await session.execute(query) # Then we do other things - await asyncio.sleep(1000 * MS) + await asyncio.sleep(10 * MS) # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the 
exception - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 8abb3410..94eaed2b 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -9,15 +9,15 @@ def test_appsync_init_with_minimal_args(fake_session_factory): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - sample_transport = AppSyncWebsocketsTransport( + transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory() ) - assert isinstance(sample_transport.auth, AppSyncIAMAuthentication) - assert sample_transport.connect_timeout == 10 - assert sample_transport.close_timeout == 10 - assert sample_transport.ack_timeout == 10 - assert sample_transport.ssl is False - assert sample_transport.connect_args == {} + assert isinstance(transport.auth, AppSyncIAMAuthentication) + assert transport.connect_timeout == 10 + assert transport.close_timeout == 10 + assert transport.ack_timeout == 10 + assert transport.ssl is False + assert transport.connect_args == {} @pytest.mark.botocore @@ -27,11 +27,11 @@ def test_appsync_init_with_no_credentials(caplog, fake_session_factory): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): - sample_transport = AppSyncWebsocketsTransport( + transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory(credentials=None), ) - assert sample_transport.auth is None + assert transport.auth is None expected_error = "Credentials not found" @@ -46,8 +46,8 @@ def test_appsync_init_with_jwt_auth(): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncJWTAuthentication(host=mock_transport_host, jwt="some-jwt") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, 
auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth assert auth.get_headers() == { "host": mock_transport_host, @@ -61,8 +61,8 @@ def test_appsync_init_with_apikey_auth(): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncApiKeyAuthentication(host=mock_transport_host, api_key="some-api-key") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth assert auth.get_headers() == { "host": mock_transport_host, @@ -95,8 +95,8 @@ def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): credentials=fake_credentials_factory(), region_name="us-east-1", ) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth @pytest.mark.botocore @@ -153,7 +153,7 @@ def test_munge_url(fake_signer_factory, fake_request_factory): signer=fake_signer_factory(), request_creator=fake_request_factory, ) - sample_transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) + transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) header_string = ( "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" @@ -164,7 +164,7 @@ def test_munge_url(fake_signer_factory, fake_request_factory): "wss://appsync-realtime-api.aws.example.org/" f"some-other-params?header={header_string}&payload=e30=" ) - assert sample_transport.url == expected_url + assert transport.url == expected_url @pytest.mark.botocore diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index 536b2fe9..168924bc 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -49,9 +49,9 @@ async def 
handler(request): region_name="us-east-1", ) - sample_transport = AIOHTTPTransport(url=url, auth=auth) + transport = AIOHTTPTransport(url=url, auth=auth) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 37cbe460..0be04034 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -139,7 +139,7 @@ async def realtime_appsync_server_template(ws): ) return - path = ws.path + path = ws.request.path print(f"path = {path}") diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index be214134..c256e5dd 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -91,9 +91,9 @@ async def test_async_client_validation(server, subscription_str, client_params): url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=transport, **client_params) async with client as session: @@ -138,9 +138,9 @@ async def test_async_client_validation_invalid_query( url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=transport, **client_params) async with client as session: @@ -171,10 +171,10 @@ async def test_async_client_validation_different_schemas_parameters_forbidden( url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(AssertionError): - async with Client(transport=sample_transport, **client_params): + async with Client(transport=transport, 
**client_params): pass @@ -261,10 +261,10 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=True, ) as session: diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index 2e3514d1..6f30c8da 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -5,7 +5,6 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, TransportConnectionFailed, TransportProtocolError, TransportQueryError, @@ -111,10 +110,10 @@ async def test_graphqlws_server_does_not_send_ack(graphqlws_server, query_str): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -212,10 +211,10 @@ async def test_graphqlws_server_does_not_ack(graphqlws_server): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -231,10 +230,10 @@ async def test_graphqlws_server_closing_directly(graphqlws_server): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportConnectionFailed): - async with Client(transport=sample_transport): + 
async with Client(transport=transport): pass @@ -251,10 +250,32 @@ async def test_graphqlws_server_closing_after_ack(client_and_graphqlws_server): query = gql("query { hello }") - with pytest.raises(TransportConnectionFailed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed) as exc1: await session.execute(query) + exc1_cause = exc1.value.__cause__ + exc1_cause_str = f"{type(exc1_cause).__name__}:{exc1_cause!s}" + + print(f"\n First query Exception cause: {exc1_cause_str}\n") + + assert ( + exc1_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed) as exc2: await session.execute(query) + + exc2_cause = exc2.value.__cause__ + exc2_cause_str = f"{type(exc2_cause).__name__}:{exc2_cause!s}" + + print(f" Second query Exception cause: {exc2_cause_str}\n") + + assert ( + exc2_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 94028d26..b4c6a17b 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -11,7 +11,7 @@ from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, PyPy, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -814,7 +814,6 @@ async def test_graphqlws_subscription_reconnecting_session( graphqlws_server, subscription_str, execute_instead_of_subscribe ): - from gql.transport.exceptions import TransportClosed from gql.transport.websockets import WebsocketsTransport path = "/graphql" @@ -833,56 +832,62 @@ async def 
test_graphqlws_subscription_reconnecting_session( reconnecting=True, retry_connect=False, retry_execute=False ) - # First we make a subscription which will cause a disconnect in the backend - # (count=8) - try: - print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") - async for result in session.subscribe(subscription_with_disconnect): - pass - except TransportConnectionFailed: - pass - - await asyncio.sleep(50 * MS) - - # Then with the same session handle, we make a subscription or an execute - # which will detect that the transport is closed so that the client could - # try to reconnect - generator = None + # First we make a query or subscription which will cause a disconnect + # in the backend (count=8) try: if execute_instead_of_subscribe: - print("\nEXECUTION_2\n") - await session.execute(subscription) + print("\nEXECUTION_1\n") + await session.execute(subscription_with_disconnect) else: - print("\nSUBSCRIPTION_2\n") - generator = session.subscribe(subscription) - async for result in generator: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): pass - except (TransportClosed, TransportConnectionFailed): - if generator: - await generator.aclose() + except TransportConnectionFailed: pass - timeout = 50 + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - if PyPy: - timeout = 500 + # Wait for reconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if transport._connected: + print(f"\nConnected again in {i+1} MS") + break - await asyncio.sleep(timeout * MS) + assert transport._connected is True + + # Then after the reconnection, we make a query or a subscription + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + result = await session.execute(subscription) + assert result["number"] == 10 + else: + print("\nSUBSCRIPTION_2\n") + generator = session.subscribe(subscription) + async for 
result in generator: + number = result["number"] + print(f"Number received: {number}") - # And finally with the same session handle, we make a subscription - # which works correctly - print("\nSUBSCRIPTION_3\n") - generator = session.subscribe(subscription) - async for result in generator: + assert number == count + count -= 1 - number = result["number"] - print(f"Number received: {number}") + await generator.aclose() - assert number == count - count -= 1 + assert count == -1 - await generator.aclose() + # Close the reconnecting session + await client.close_async() - assert count == -1 + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - await client.close_async() + assert transport._connected is False diff --git a/tests/test_http_async_sync.py b/tests/test_http_async_sync.py index 45efd7f5..61dc1809 100644 --- a/tests/test_http_async_sync.py +++ b/tests/test_http_async_sync.py @@ -15,11 +15,11 @@ async def test_async_client_async_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get async transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instantiate client async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) as session: @@ -58,17 +58,17 @@ async def test_async_client_sync_transport(fetch_schema_from_transport): url = "http://countries.trevorblades.com/graphql" # Get sync transport - sample_transport = RequestsHTTPTransport(url=url, use_json=True) + transport = RequestsHTTPTransport(url=url, use_json=True) # Impossible to use a sync transport asynchronously with pytest.raises(AssertionError): async with Client( - transport=sample_transport, + transport=transport, 
fetch_schema_from_transport=fetch_schema_from_transport, ): pass - sample_transport.close() + transport.close() @pytest.mark.aiohttp @@ -82,11 +82,11 @@ def test_sync_client_async_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get async transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instanciate client client = Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) @@ -125,11 +125,11 @@ def test_sync_client_sync_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get sync transport - sample_transport = RequestsHTTPTransport(url=url, use_json=True) + transport = RequestsHTTPTransport(url=url, use_json=True) # Instanciate client client = Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py index 3b08fa18..c6e84368 100644 --- a/tests/test_httpx_online.py +++ b/tests/test_httpx_online.py @@ -19,10 +19,10 @@ async def test_httpx_simple_query(): url = "https://countries.trevorblades.com/graphql" # Get transport - sample_transport = HTTPXAsyncTransport(url=url) + transport = HTTPXAsyncTransport(url=url) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -60,11 +60,9 @@ async def test_httpx_invalid_query(): from gql.transport.httpx import HTTPXAsyncTransport - sample_transport = HTTPXAsyncTransport( - url="https://countries.trevorblades.com/graphql" - ) + transport = 
HTTPXAsyncTransport(url="https://countries.trevorblades.com/graphql") - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -89,12 +87,12 @@ async def test_httpx_two_queries_in_parallel_using_two_tasks(): from gql.transport.httpx import HTTPXAsyncTransport - sample_transport = HTTPXAsyncTransport( + transport = HTTPXAsyncTransport( url="https://countries.trevorblades.com/graphql", ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 09c129b3..b7f11dcb 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -167,13 +167,11 @@ async def test_phoenix_channel_query_protocol_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -197,13 +195,11 @@ async def test_phoenix_channel_query_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportQueryError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as 
session: await session.execute(query) @@ -414,13 +410,11 @@ async def test_phoenix_channel_subscription_protocol_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): await asyncio.sleep(10 * MS) break @@ -444,13 +438,11 @@ async def test_phoenix_channel_server_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportServerError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -476,12 +468,12 @@ async def test_phoenix_channel_unsubscribe_error(server, query_str): # Reduce close_timeout. These tests will wait for an unsubscribe # reply that will never come... 
- sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", url=url, close_timeout=1 ) query = gql(query_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): break @@ -504,13 +496,13 @@ async def test_phoenix_channel_unsubscribe_error_forcing(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", url=url, close_timeout=1 ) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): await session.transport._send_stop_message(2) await asyncio.sleep(10 * MS) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 25ca0f0b..ecda9c38 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -191,14 +191,14 @@ async def test_phoenix_channel_subscription(server, subscription_str, end_count) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, close_timeout=5 ) count = 10 subscription = gql(subscription_str.format(count=count)) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: generator = session.subscribe(subscription) async for result in generator: @@ -240,14 +240,14 @@ async def test_phoenix_channel_subscription_no_break(server, subscription_str): async def testing_stopping_without_break(): - sample_transport = 
PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, close_timeout=(5000 * MS) ) count = 10 subscription = gql(subscription_str.format(count=count)) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for result in session.subscribe(subscription): number = result["countdown"]["number"] print(f"Number received: {number}") @@ -372,12 +372,12 @@ async def test_phoenix_channel_heartbeat(server, subscription_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, heartbeat_interval=0.1 ) subscription = gql(heartbeat_subscription_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: i = 0 generator = session.subscribe(subscription) async for result in generator: diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 08058aea..b6169468 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -8,7 +8,6 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, TransportConnectionFailed, TransportProtocolError, TransportQueryError, @@ -118,10 +117,10 @@ async def test_websocket_server_does_not_send_ack(server, query_str): url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -257,10 +256,10 @@ async def test_websocket_server_does_not_ack(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - 
sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -276,10 +275,10 @@ async def test_websocket_server_closing_directly(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportConnectionFailed): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -296,14 +295,36 @@ async def test_websocket_server_closing_after_ack(client_and_server): query = gql("query { hello }") - with pytest.raises(TransportConnectionFailed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed) as exc1: await session.execute(query) + exc1_cause = exc1.value.__cause__ + exc1_cause_str = f"{type(exc1_cause).__name__}:{exc1_cause!s}" + + print(f"\n First query Exception cause: {exc1_cause_str}\n") + + assert ( + exc1_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed) as exc2: await session.execute(query) + exc2_cause = exc2.value.__cause__ + exc2_cause_str = f"{type(exc2_cause).__name__}:{exc2_cause!s}" + + print(f" Second query Exception cause: {exc2_cause_str}\n") + + assert ( + exc2_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + async def server_sending_invalid_query_errors(ws): await WebSocketServerHelper.send_connection_ack(ws) @@ -323,10 +344,10 @@ async def test_websocket_server_sending_invalid_query_errors(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + 
transport = WebsocketsTransport(url=url) # Invalid server message is ignored - async with Client(transport=sample_transport): + async with Client(transport=transport): await asyncio.sleep(2 * MS) @@ -342,9 +363,9 @@ async def test_websocket_non_regression_bug_105(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) # Create a coroutine which start the connection with the transport but does nothing async def client_connect(client): diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 99ff7334..979bb99b 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -8,7 +8,6 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, TransportConnectionFailed, TransportQueryError, TransportServerError, @@ -112,9 +111,7 @@ async def test_websocket_using_ssl_connection(ws_ssl_server): async with Client(transport=transport) as session: - assert isinstance( - transport.adapter.websocket, websockets.client.WebSocketClientProtocol - ) + assert isinstance(transport.adapter.websocket, websockets.ClientConnection) query1 = gql(query1_str) @@ -290,11 +287,11 @@ async def test_websocket_server_closing_after_first_query(client_and_server, que await session.execute(query) # Then we do other things - await asyncio.sleep(100 * MS) + await asyncio.sleep(10 * MS) # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) @@ -663,7 +660,7 @@ async def test_websocket_adapter_connection_closed(server): # Close adapter connection manually (should not be done) await transport.adapter.close() - with 
pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query1) # Check client is disconnect here @@ -691,5 +688,5 @@ async def test_websocket_transport_closed_in_receive(server): # await transport.adapter.close() transport._connected = False - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query1) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 89acd635..8d2fd152 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -420,11 +420,9 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( if PyPy: keep_alive_timeout = 200 * MS - sample_transport = WebsocketsTransport( - url=url, keep_alive_timeout=keep_alive_timeout - ) + transport = WebsocketsTransport(url=url, keep_alive_timeout=keep_alive_timeout) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -452,9 +450,9 @@ async def test_websocket_subscription_with_keepalive_with_timeout_nok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -480,9 +478,9 @@ def test_websocket_subscription_sync(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -506,9 +504,9 @@ def 
test_websocket_subscription_sync_user_exception(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -537,9 +535,9 @@ def test_websocket_subscription_sync_break(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -578,9 +576,9 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -630,9 +628,9 @@ async def test_websocket_subscription_running_in_thread( def test_code(): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) From eda90ed6317fc3d20a61959e120dc2ac0a69508d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 20 May 2025 12:27:25 +0000 Subject: [PATCH 209/239] Allow graphql-core 3.2.6 on stable branch (#547) * Bump graphql-core on stable to 3.2.6 * Using Ubuntu 24.04 for GitHub actions * Remove obsolete 3.7 Python version from GihHub actions --- 
.github/workflows/deploy.yml | 2 +- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 10 ++++------ setup.py | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index da129836..96ba8e85 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -7,7 +7,7 @@ on: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 39f5cf0c..7e1cacd2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -4,7 +4,7 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 30e8289c..6d961f2b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,11 +8,9 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.8"] - os: [ubuntu-20.04, windows-latest] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.8"] + os: [ubuntu-24.04, windows-latest] exclude: - - os: windows-latest - python-version: "3.7" - os: windows-latest python-version: "3.9" - os: windows-latest @@ -40,7 +38,7 @@ jobs: TOXENV: ${{ matrix.toxenv }} single_extra: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: fail-fast: false matrix: @@ -60,7 +58,7 @@ jobs: run: pytest tests --${{ matrix.dependency }}-only coverage: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v3 diff --git a/setup.py b/setup.py index f34b2e35..f7b96ede 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.2,<3.2.5", + "graphql-core>=3.2,<3.2.7", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", From 2dd1adbbbc6585aa6fef51b767c4d91ffde77d85 
Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 20 May 2025 14:29:51 +0200 Subject: [PATCH 210/239] Bump version number to 3.5.3 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index dae42b1b..ad45aa38 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.5.2" +__version__ = "3.5.3" From 476d133c425c49d4ac2aaeb295258176042abdb8 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 20 May 2025 15:06:38 +0200 Subject: [PATCH 211/239] Bump version number to 4.0.0a0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index cfe6b54e..7870304a 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.6.0b4" +__version__ = "4.0.0a0" From a87d97af79aacf4633ec7b4816c82dc692341676 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 20 May 2025 16:14:22 +0200 Subject: [PATCH 212/239] Upgrade pypi release github action --- .github/workflows/deploy.yml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 69c11d2a..1b489a95 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -15,12 +15,14 @@ jobs: uses: actions/setup-python@v5 with: python-version: 3.12 - - name: Build wheel and source tarball + - name: Install build dependencies run: | - pip install wheel setuptools - python setup.py sdist bdist_wheel + python -m pip install --upgrade pip + pip install build wheel + - name: Build package + run: | + python -m build - name: Publish a Python distribution to PyPI - uses: pypa/gh-action-pypi-publish@v1.1.0 + uses: pypa/gh-action-pypi-publish@release/v1 with: - user: __token__ password: ${{ secrets.pypi_password }} From 9a84f0b4224a1dc05cc41895458427bd7f7df83a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 20 May 2025 16:45:52 
+0000 Subject: [PATCH 213/239] Fix subscription task cancel exception swallow (#548) --- gql/transport/common/base.py | 3 +-- ...iohttp_websocket_graphqlws_subscription.py | 19 ++++++++++++++----- tests/test_aiohttp_websocket_subscription.py | 19 ++++++++++++++----- tests/test_graphqlws_subscription.py | 19 ++++++++++++++----- tests/test_websocket_subscription.py | 19 ++++++++++++++----- 5 files changed, 57 insertions(+), 22 deletions(-) diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index cae8f488..a285ad2c 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -317,8 +317,7 @@ async def subscribe( if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False - if isinstance(e, GeneratorExit): - raise e + raise e finally: log.debug(f"In subscribe finally for query_id {query_id}") diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 7c000d01..22dd1004 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -292,16 +292,24 @@ async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -317,6 +325,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True 
@pytest.mark.asyncio diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 83ae3589..32daf038 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -283,16 +283,24 @@ async def test_aiohttp_websocket_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -308,6 +316,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index b4c6a17b..45e7aba4 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -290,16 +290,24 @@ async def test_graphqlws_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = 
asyncio.ensure_future(task_coro()) @@ -315,6 +323,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 8d2fd152..487b9ba5 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -210,16 +210,24 @@ async def test_websocket_subscription_task_cancel(client_and_server, subscriptio count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -235,6 +243,7 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio From f0fd64db3dc1f54755a454b3c8dd04ace4630d62 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 23 May 2025 11:18:07 +0000 Subject: [PATCH 214/239] Clean up the file upload interface with FileVar class (#549) --- docs/usage/file_upload.rst | 98 +++--- gql/__init__.py | 2 + gql/transport/aiohttp.py | 103 ++++--- gql/transport/file_upload.py | 126 ++++++++ gql/transport/httpx.py | 33 +- gql/transport/requests.py | 33 +- gql/utils.py | 39 +-- setup.py | 7 +- tests/conftest.py | 59 ++++ tests/test_aiohttp.py | 579 +++++++++++++++++++++-------------- tests/test_httpx.py | 404 +++++++++++------------- tests/test_httpx_async.py | 326 ++++++++------------ tests/test_requests.py | 472 
++++++++++++++++------------ 13 files changed, 1289 insertions(+), 992 deletions(-) create mode 100644 gql/transport/file_upload.py diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index 10903585..7793354b 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -14,11 +14,14 @@ Single File In order to upload a single file, you need to: * set the file as a variable value in the mutation -* provide the opened file to the `variable_values` argument of `execute` +* create a :class:`FileVar ` object with your file path +* provide the `FileVar` instance to the `variable_values` argument of `execute` * set the `upload_files` argument to True .. code-block:: python + from gql import client, gql, FileVar + transport = AIOHTTPTransport(url='YOUR_URL') # Or transport = RequestsHTTPTransport(url='YOUR_URL') # Or transport = HTTPXTransport(url='YOUR_URL') @@ -34,32 +37,38 @@ In order to upload a single file, you need to: } ''') - with open("YOUR_FILE_PATH", "rb") as f: - - params = {"file": f} + params = {"file": FileVar("YOUR_FILE_PATH")} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute( + query, variable_values=params, upload_files=True + ) Setting the content-type ^^^^^^^^^^^^^^^^^^^^^^^^ If you need to set a specific Content-Type attribute to a file, -you can set the :code:`content_type` attribute of the file like this: +you can set the :code:`content_type` attribute of :class:`FileVar `: .. 
code-block:: python - with open("YOUR_FILE_PATH", "rb") as f: + # Setting the content-type to a pdf file for example + filevar = FileVar( + "YOUR_FILE_PATH", + content_type="application/pdf", + ) - # Setting the content-type to a pdf file for example - f.content_type = "application/pdf" +Setting the uploaded file name +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - params = {"file": f} +To modify the uploaded filename, use the :code:`filename` attribute of :class:`FileVar `: - result = client.execute( - query, variable_values=params, upload_files=True - ) +.. code-block:: python + + # Setting the content-type to a pdf file for example + filevar = FileVar( + "YOUR_FILE_PATH", + filename="filename1.txt", + ) File list --------- @@ -68,6 +77,8 @@ It is also possible to upload multiple files using a list. .. code-block:: python + from gql import client, gql, FileVar + transport = AIOHTTPTransport(url='YOUR_URL') # Or transport = RequestsHTTPTransport(url='YOUR_URL') # Or transport = HTTPXTransport(url='YOUR_URL') @@ -83,8 +94,8 @@ It is also possible to upload multiple files using a list. } ''') - f1 = open("YOUR_FILE_PATH_1", "rb") - f2 = open("YOUR_FILE_PATH_2", "rb") + f1 = FileVar("YOUR_FILE_PATH_1") + f2 = FileVar("YOUR_FILE_PATH_2") params = {"files": [f1, f2]} @@ -92,9 +103,6 @@ It is also possible to upload multiple files using a list. query, variable_values=params, upload_files=True ) - f1.close() - f2.close() - Streaming --------- @@ -120,18 +128,8 @@ Streaming local files aiohttp allows to upload files using an asynchronous generator. See `Streaming uploads on aiohttp docs`_. - -In order to stream local files, instead of providing opened files to the -`variable_values` argument of `execute`, you need to provide an async generator -which will provide parts of the files. - -You can use `aiofiles`_ -to read the files in chunks and create this asynchronous generator. - -.. 
_Streaming uploads on aiohttp docs: https://docs.aiohttp.org/en/stable/client_quickstart.html#streaming-uploads -.. _aiofiles: https://github.com/Tinche/aiofiles - -Example: +From gql version 4.0, it is possible to activate file streaming simply by +setting the `streaming` argument of :class:`FileVar ` to `True` .. code-block:: python @@ -147,18 +145,38 @@ Example: } ''') + f1 = FileVar( + file_name='YOUR_FILE_PATH', + streaming=True, + ) + + params = {"file": f1} + + result = client.execute( + query, variable_values=params, upload_files=True + ) + +Another option is to use an async generator to provide parts of the file. + +You can use `aiofiles`_ +to read the files in chunks and create this asynchronous generator. + +.. _Streaming uploads on aiohttp docs: https://docs.aiohttp.org/en/stable/client_quickstart.html#streaming-uploads +.. _aiofiles: https://github.com/Tinche/aiofiles + +.. 
code-block:: python + async def file_sender(file_name): async with aiofiles.open(file_name, 'rb') as f: - chunk = await f.read(64*1024) - while chunk: - yield chunk - chunk = await f.read(64*1024) + while chunk := await f.read(64*1024): + yield chunk - params = {"file": file_sender(file_name='YOUR_FILE_PATH')} + f1 = FileVar(file_sender(file_name='YOUR_FILE_PATH')) + params = {"file": f1} result = client.execute( - query, variable_values=params, upload_files=True - ) + query, variable_values=params, upload_files=True + ) Streaming downloaded files ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -200,7 +218,7 @@ Example: } ''') - params = {"file": resp.content} + params = {"file": FileVar(resp.content)} result = client.execute( query, variable_values=params, upload_files=True diff --git a/gql/__init__.py b/gql/__init__.py index 8eaa0b7c..4c9a6aa0 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -11,10 +11,12 @@ from .client import Client from .gql import gql from .graphql_request import GraphQLRequest +from .transport.file_upload import FileVar __all__ = [ "__version__", "gql", "Client", "GraphQLRequest", + "FileVar", ] diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 76b46c35..b2633abb 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -23,7 +23,6 @@ from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy -from ..utils import extract_files from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport from .common.aiohttp_closed_event import create_aiohttp_closed_event @@ -33,6 +32,7 @@ TransportProtocolError, TransportServerError, ) +from .file_upload import FileVar, close_files, extract_files, open_files log = logging.getLogger(__name__) @@ -207,6 +207,10 @@ async def execute( file_classes=self.file_classes, ) + # Opening the files using the FileVar parameters + open_files(list(files.values()), transport_supports_streaming=True) + self.files = files + # Save the 
nulled variable values in the payload payload["variables"] = nulled_variable_values @@ -220,8 +224,8 @@ async def execute( file_map = {str(i): [path] for i, path in enumerate(files)} # Enumerate the file streams - # Will generate something like {'0': <_io.BufferedReader ...>} - file_streams = {str(i): files[path] for i, path in enumerate(files)} + # Will generate something like {'0': FileVar object} + file_vars = {str(i): files[path] for i, path in enumerate(files)} # Add the payload to the operations field operations_str = self.json_serialize(payload) @@ -235,12 +239,15 @@ async def execute( log.debug("file_map %s", file_map_str) data.add_field("map", file_map_str, content_type="application/json") - # Add the extracted files as remaining fields - for k, f in file_streams.items(): - name = getattr(f, "name", k) - content_type = getattr(f, "content_type", None) + for k, file_var in file_vars.items(): + assert isinstance(file_var, FileVar) - data.add_field(k, f, filename=name, content_type=content_type) + data.add_field( + k, + file_var.f, + filename=file_var.filename, + content_type=file_var.content_type, + ) post_args: Dict[str, Any] = {"data": data} @@ -267,51 +274,59 @@ async def execute( if self.session is None: raise TransportClosed("Transport is not connected") - async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: - - # Saving latest response headers in the transport - self.response_headers = resp.headers + try: + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: - async def raise_response_error( - resp: aiohttp.ClientResponse, reason: str - ) -> NoReturn: - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases + # Saving latest response headers in the transport + self.response_headers = resp.headers - try: - # Raise a ClientResponseError if response status is 400 or higher - resp.raise_for_status() - except ClientResponseError as e: - raise 
TransportServerError(str(e), e.status) from e - - result_text = await resp.text() - raise TransportProtocolError( - f"Server did not return a GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + async def raise_response_error( + resp: aiohttp.ClientResponse, reason: str + ) -> NoReturn: + # We raise a TransportServerError if status code is 400 or higher + # We raise a TransportProtocolError in the other cases - try: - result = await resp.json(loads=self.json_deserialize, content_type=None) + try: + # Raise ClientResponseError if response status is 400 or higher + resp.raise_for_status() + except ClientResponseError as e: + raise TransportServerError(str(e), e.status) from e - if log.isEnabledFor(logging.INFO): result_text = await resp.text() - log.info("<<< %s", result_text) + raise TransportProtocolError( + f"Server did not return a GraphQL result: " + f"{reason}: " + f"{result_text}" + ) - except Exception: - await raise_response_error(resp, "Not a JSON answer") + try: + result = await resp.json( + loads=self.json_deserialize, content_type=None + ) - if result is None: - await raise_response_error(resp, "Not a JSON answer") + if log.isEnabledFor(logging.INFO): + result_text = await resp.text() + log.info("<<< %s", result_text) - if "errors" not in result and "data" not in result: - await raise_response_error(resp, 'No "data" or "errors" keys in answer') + except Exception: + await raise_response_error(resp, "Not a JSON answer") - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) + if result is None: + await raise_response_error(resp, "Not a JSON answer") + + if "errors" not in result and "data" not in result: + await raise_response_error( + resp, 'No "data" or "errors" keys in answer' + ) + + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) + finally: + if upload_files: + 
close_files(list(self.files.values())) def subscribe( self, diff --git a/gql/transport/file_upload.py b/gql/transport/file_upload.py new file mode 100644 index 00000000..8673ab60 --- /dev/null +++ b/gql/transport/file_upload.py @@ -0,0 +1,126 @@ +import io +import os +import warnings +from typing import Any, Dict, List, Optional, Tuple, Type + + +class FileVar: + def __init__( + self, + f: Any, # str | io.IOBase | aiohttp.StreamReader | AsyncGenerator + *, + filename: Optional[str] = None, + content_type: Optional[str] = None, + streaming: bool = False, + streaming_block_size: int = 64 * 1024, + ): + self.f = f + self.filename = filename + self.content_type = content_type + self.streaming = streaming + self.streaming_block_size = streaming_block_size + + self._file_opened: bool = False + + def open_file( + self, + transport_supports_streaming: bool = False, + ) -> None: + assert self._file_opened is False + + if self.streaming: + assert ( + transport_supports_streaming + ), "streaming not supported on this transport" + self._make_file_streamer() + else: + if isinstance(self.f, str): + if self.filename is None: + # By default we set the filename to the basename + # of the opened file + self.filename = os.path.basename(self.f) + self.f = open(self.f, "rb") + self._file_opened = True + + def close_file(self) -> None: + if self._file_opened: + assert isinstance(self.f, io.IOBase) + self.f.close() + self._file_opened = False + + def _make_file_streamer(self) -> None: + assert isinstance(self.f, str), "streaming option needs a filepath str" + + import aiofiles + + async def file_sender(file_name): + async with aiofiles.open(file_name, "rb") as f: + while chunk := await f.read(self.streaming_block_size): + yield chunk + + self.f = file_sender(self.f) + + +def open_files( + filevars: List[FileVar], + transport_supports_streaming: bool = False, +) -> None: + + for filevar in filevars: + filevar.open_file(transport_supports_streaming=transport_supports_streaming) + + +def 
close_files(filevars: List[FileVar]) -> None: + for filevar in filevars: + filevar.close_file() + + +FILE_UPLOAD_DOCS = "https://gql.readthedocs.io/en/latest/usage/file_upload.html" + + +def extract_files( + variables: Dict, file_classes: Tuple[Type[Any], ...] +) -> Tuple[Dict, Dict[str, FileVar]]: + files: Dict[str, FileVar] = {} + + def recurse_extract(path, obj): + """ + recursively traverse obj, doing a deepcopy, but + replacing any file-like objects with nulls and + shunting the originals off to the side. + """ + nonlocal files + if isinstance(obj, list): + nulled_list = [] + for key, value in enumerate(obj): + value = recurse_extract(f"{path}.{key}", value) + nulled_list.append(value) + return nulled_list + elif isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = recurse_extract(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + elif isinstance(obj, file_classes): + # extract obj from its parent and put it into files instead. + warnings.warn( + "Not using FileVar for file upload is deprecated. " + f"See {FILE_UPLOAD_DOCS} for details.", + DeprecationWarning, + ) + name = getattr(obj, "name", None) + content_type = getattr(obj, "content_type", None) + files[path] = FileVar(obj, filename=name, content_type=content_type) + return None + elif isinstance(obj, FileVar): + # extract obj from its parent and put it into files instead. + files[path] = obj + return None + else: + # base case: pass through unchanged + return obj + + nulled_variables = recurse_extract("variables", variables) + + return nulled_variables, files diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 4c5d33d0..eb15ac57 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -12,13 +12,11 @@ Tuple, Type, Union, - cast, ) import httpx from graphql import DocumentNode, ExecutionResult, print_ast -from ..utils import extract_files from . 
import AsyncTransport, Transport from .exceptions import ( TransportAlreadyConnected, @@ -26,6 +24,7 @@ TransportProtocolError, TransportServerError, ) +from .file_upload import close_files, extract_files, open_files log = logging.getLogger(__name__) @@ -104,6 +103,10 @@ def _prepare_file_uploads( file_classes=self.file_classes, ) + # Opening the files using the FileVar parameters + open_files(list(files.values())) + self.files = files + # Save the nulled variable values in the payload payload["variables"] = nulled_variable_values @@ -112,7 +115,7 @@ def _prepare_file_uploads( file_map: Dict[str, List[str]] = {} file_streams: Dict[str, Tuple[str, ...]] = {} - for i, (path, f) in enumerate(files.items()): + for i, (path, file_var) in enumerate(files.items()): key = str(i) # Generate the file map @@ -121,16 +124,12 @@ def _prepare_file_uploads( # Will generate something like {"0": ["variables.file"]} file_map[key] = [path] - # Generate the file streams - # Will generate something like - # {"0": ("variables.file", <_io.BufferedReader ...>)} - name = cast(str, getattr(f, "name", key)) - content_type = getattr(f, "content_type", None) + name = key if file_var.filename is None else file_var.filename - if content_type is None: - file_streams[key] = (name, f) + if file_var.content_type is None: + file_streams[key] = (name, file_var.f) else: - file_streams[key] = (name, f, content_type) + file_streams[key] = (name, file_var.f, file_var.content_type) # Add the payload to the operations field operations_str = self.json_serialize(payload) @@ -232,7 +231,11 @@ def execute( # type: ignore upload_files, ) - response = self.client.post(self.url, **post_args) + try: + response = self.client.post(self.url, **post_args) + finally: + if upload_files: + close_files(list(self.files.values())) return self._prepare_result(response) @@ -295,7 +298,11 @@ async def execute( upload_files, ) - response = await self.client.post(self.url, **post_args) + try: + response = await 
self.client.post(self.url, **post_args) + finally: + if upload_files: + close_files(list(self.files.values())) return self._prepare_result(response) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 44f8a362..5fb7e827 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -25,13 +25,13 @@ from gql.transport import Transport from ..graphql_request import GraphQLRequest -from ..utils import extract_files from .exceptions import ( TransportAlreadyConnected, TransportClosed, TransportProtocolError, TransportServerError, ) +from .file_upload import FileVar, close_files, extract_files, open_files log = logging.getLogger(__name__) @@ -190,6 +190,10 @@ def execute( # type: ignore file_classes=self.file_classes, ) + # Opening the files using the FileVar parameters + open_files(list(files.values())) + self.files = files + # Save the nulled variable values in the payload payload["variables"] = nulled_variable_values @@ -204,8 +208,8 @@ def execute( # type: ignore file_map = {str(i): [path] for i, path in enumerate(files)} # Enumerate the file streams - # Will generate something like {'0': <_io.BufferedReader ...>} - file_streams = {str(i): files[path] for i, path in enumerate(files)} + # Will generate something like {'0': FileVar object} + file_vars = {str(i): files[path] for i, path in enumerate(files)} # Add the file map field file_map_str = self.json_serialize(file_map) @@ -214,14 +218,14 @@ def execute( # type: ignore fields = {"operations": operations_str, "map": file_map_str} # Add the extracted files as remaining fields - for k, f in file_streams.items(): - name = getattr(f, "name", k) - content_type = getattr(f, "content_type", None) + for k, file_var in file_vars.items(): + assert isinstance(file_var, FileVar) + name = k if file_var.filename is None else file_var.filename - if content_type is None: - fields[k] = (name, f) + if file_var.content_type is None: + fields[k] = (name, file_var.f) else: - fields[k] = (name, f, 
content_type) + fields[k] = (name, file_var.f, file_var.content_type) # Prepare requests http to send multipart-encoded data data = MultipartEncoder(fields=fields) @@ -254,9 +258,14 @@ def execute( # type: ignore post_args.update(extra_args) # Using the created session to perform requests - response = self.session.request( - self.method, self.url, **post_args # type: ignore - ) + try: + response = self.session.request( + self.method, self.url, **post_args # type: ignore + ) + finally: + if upload_files: + close_files(list(self.files.values())) + self.response_headers = response.headers def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: diff --git a/gql/utils.py b/gql/utils.py index 6a7d0791..f7f0f5a7 100644 --- a/gql/utils.py +++ b/gql/utils.py @@ -1,6 +1,6 @@ """Utilities to manipulate several python objects.""" -from typing import Any, Dict, List, Tuple, Type +from typing import List # From this response in Stackoverflow @@ -12,43 +12,6 @@ def to_camel_case(snake_str): return components[0] + "".join(x.title() if x else "_" for x in components[1:]) -def extract_files( - variables: Dict, file_classes: Tuple[Type[Any], ...] -) -> Tuple[Dict, Dict]: - files = {} - - def recurse_extract(path, obj): - """ - recursively traverse obj, doing a deepcopy, but - replacing any file-like objects with nulls and - shunting the originals off to the side. - """ - nonlocal files - if isinstance(obj, list): - nulled_list = [] - for key, value in enumerate(obj): - value = recurse_extract(f"{path}.{key}", value) - nulled_list.append(value) - return nulled_list - elif isinstance(obj, dict): - nulled_dict = {} - for key, value in obj.items(): - value = recurse_extract(f"{path}.{key}", value) - nulled_dict[key] = value - return nulled_dict - elif isinstance(obj, file_classes): - # extract obj from its parent and put it into files instead. 
- files[path] = obj - return None - else: - # base case: pass through unchanged - return obj - - nulled_variables = recurse_extract("variables", variables) - - return nulled_variables, files - - def str_first_element(errors: List) -> str: try: first_error = errors[0] diff --git a/setup.py b/setup.py index aed15440..706a80c3 100644 --- a/setup.py +++ b/setup.py @@ -58,8 +58,12 @@ "botocore>=1.21,<2", ] +install_aiofiles_requires = [ + "aiofiles", +] + install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_httpx_requires + install_websockets_requires + install_botocore_requires + install_aiohttp_requires + install_requests_requires + install_httpx_requires + install_websockets_requires + install_botocore_requires + install_aiofiles_requires ) # Get version from __version__.py file @@ -107,6 +111,7 @@ "httpx": install_httpx_requires, "websockets": install_websockets_requires, "botocore": install_botocore_requires, + "aiofiles": install_aiofiles_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index c69551b0..cef561f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -763,3 +763,62 @@ def strip_braces_spaces(s): strip_back = re.sub(r"([^\s]) }", r"\1}", strip_front) return strip_back + + +def make_upload_handler( + nb_files=1, + filenames=None, + request_headers=None, + file_headers=None, + binary=False, + expected_contents=None, + expected_operations=None, + expected_map=None, + server_answer='{"data":{"success":true}}', +): + assert expected_contents is not None + assert expected_operations is not None + assert expected_map is not None + + async def single_upload_handler(request): + from aiohttp import web + + reader = await request.multipart() + + if request_headers is not None: + for k, v in request_headers.items(): + assert request.headers[k] == v + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert 
strip_braces_spaces(field_0_text) == expected_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == expected_map + + for i in range(nb_files): + field = await reader.next() + assert field.name == str(i) + if filenames is not None: + assert field.filename == filenames[i] + + if binary: + field_content = await field.read() + assert field_content == expected_contents[i] + else: + field_text = await field.text() + assert field_text == expected_contents[i] + + if file_headers is not None: + for k, v in file_headers[i].items(): + assert field.headers[k] == v + + final_field = await reader.next() + assert final_field is None + + return web.Response(text=server_answer, content_type="application/json") + + return single_upload_handler diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 04417c4e..fe36585e 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,10 +1,12 @@ import io import json +import os +import warnings from typing import Mapping import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -17,7 +19,7 @@ from .conftest import ( TemporaryFile, get_localhost_ssl_context_client, - strip_braces_spaces, + make_upload_handler, ) query1_str = """ @@ -600,8 +602,6 @@ def test_code(): await run_sync_test(server, test_code) -file_upload_server_answer = '{"data":{"success":true}}' - file_upload_mutation_1 = """ mutation($file: Upload!) 
{ uploadFile(input:{other_var:$other_var, file:$file}) { @@ -624,33 +624,6 @@ def test_code(): """ -async def single_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response(text=file_upload_server_answer, content_type="application/json") - - @pytest.mark.asyncio async def test_aiohttp_file_upload(aiohttp_server): from aiohttp import web @@ -658,7 +631,15 @@ async def test_aiohttp_file_upload(aiohttp_server): from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -673,48 +654,45 @@ async def test_aiohttp_file_upload(aiohttp_server): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: params = {"file": f, "other_var": 42} # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute( + query, variable_values=params, upload_files=True + ) success = result["success"] - assert success + # Using an 
opened file inside a FileVar object + with open(file_path, "rb") as f: -async def single_upload_handler_with_content_type(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + params = {"file": FileVar(f), "other_var": 42} - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + result = await session.execute( + query, variable_values=params, upload_files=True + ) - # Verifying the content_type - assert field_2.headers["Content-Type"] == "application/pdf" + success = result["success"] + assert success - field_3 = await reader.next() - assert field_3 is None + # Using an filename string inside a FileVar object + params = {"file": FileVar(file_path), "other_var": 42} + result = await session.execute( + query, variable_values=params, upload_files=True + ) - return web.Response(text=file_upload_server_answer, content_type="application/json") + success = result["success"] + assert success @pytest.mark.asyncio @@ -724,7 +702,16 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler_with_content_type) + app.router.add_route( + "POST", + "/", + make_upload_handler( + file_headers=[{"Content-Type": "application/pdf"}], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await 
aiohttp_server(app) url = server.make_url("/") @@ -739,6 +726,7 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: # Setting the content_type @@ -746,83 +734,185 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): params = {"file": f, "other_var": 42} - # Execute query asynchronously + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + assert success + + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: + + params = { + "file": FileVar( + f, + content_type="application/pdf", + ), + "other_var": 42, + } + result = await session.execute( query, variable_values=params, upload_files=True ) success = result["success"] + assert success + + # Using an filename string inside a FileVar object + params = { + "file": FileVar( + file_path, + content_type="application/pdf", + ), + "other_var": 42, + } + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] assert success @pytest.mark.asyncio -async def test_aiohttp_file_upload_without_session(aiohttp_server, run_sync_test): +async def test_aiohttp_file_upload_default_filename_is_basename(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) - server = await aiohttp_server(app) - url = server.make_url("/") + with TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + file_basename = os.path.basename(file_path) + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=[file_basename], + expected_map=file_upload_mutation_1_map, + 
expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) - def test_code(): - transport = AIOHTTPTransport(url=url, timeout=10) + url = server.make_url("/") - with TemporaryFile(file_1_content) as test_file: + transport = AIOHTTPTransport(url=url, timeout=10) - client = Client(transport=transport) + async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) - file_path = test_file.filename + params = { + "file": FileVar( + file_path, + ), + "other_var": 42, + } - with open(file_path, "rb") as f: + result = await session.execute( + query, variable_values=params, upload_files=True + ) - params = {"file": f, "other_var": 42} + success = result["success"] + assert success - result = client.execute( - query, variable_values=params, upload_files=True - ) - success = result["success"] +@pytest.mark.asyncio +async def test_aiohttp_file_upload_with_filename(aiohttp_server): + from aiohttp import web - assert success + from gql.transport.aiohttp import AIOHTTPTransport - await run_sync_test(server, test_code) + app = web.Application() + + with TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=["filename1.txt"], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) + url = server.make_url("/") -# This is a sample binary file content containing all possible byte values -binary_file_content = bytes(range(0, 256)) + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) -async def binary_upload_handler(request): + params = { + "file": FileVar( + file_path, + filename="filename1.txt", + ), + "other_var": 42, + } + result = await session.execute( + query, 
variable_values=params, upload_files=True + ) + + success = result["success"] + assert success + + +@pytest.mark.asyncio +async def test_aiohttp_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web - reader = await request.multipart() + from gql.transport.aiohttp import AIOHTTPTransport + + app = web.Application() + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = AIOHTTPTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file: + + client = Client(transport=transport) - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + query = gql(file_upload_mutation_1) - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + file_path = test_file.filename - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content + params = {"file": FileVar(file_path), "other_var": 42} - field_3 = await reader.next() - assert field_3 is None + result = client.execute(query, variable_values=params, upload_files=True) - return web.Response(text=file_upload_server_answer, content_type="application/json") + success = result["success"] + assert success + + await run_sync_test(server, test_code) @pytest.mark.asyncio @@ -831,8 +921,20 @@ async def test_aiohttp_binary_file_upload(aiohttp_server): from gql.transport.aiohttp import AIOHTTPTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + app 
= web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -847,14 +949,12 @@ async def test_aiohttp_binary_file_upload(aiohttp_server): file_path = test_file.filename - with open(file_path, "rb") as f: - - params = {"file": f, "other_var": 42} + params = {"file": FileVar(file_path), "other_var": 42} - # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) success = result["success"] @@ -867,13 +967,25 @@ async def test_aiohttp_stream_reader_upload(aiohttp_server): from gql.transport.aiohttp import AIOHTTPTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + async def binary_data_handler(request): return web.Response( body=binary_file_content, content_type="binary/octet-stream" ) app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) app.router.add_route("GET", "/binary_data", binary_data_handler) server = await aiohttp_server(app) @@ -883,19 +995,36 @@ async def binary_data_handler(request): transport = AIOHTTPTransport(url=url, timeout=10) + # Not using FileVar async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) async with ClientSession() as client: async with client.get(binary_data_url) as resp: params 
= {"file": resp.content, "other_var": 42} - # Execute query asynchronously + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + assert success + + # Using FileVar + async with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) + async with ClientSession() as client: + async with client.get(binary_data_url) as resp: + params = {"file": FileVar(resp.content), "other_var": 42} + result = await session.execute( query, variable_values=params, upload_files=True ) success = result["success"] - assert success @@ -906,30 +1035,59 @@ async def test_aiohttp_async_generator_upload(aiohttp_server): from gql.transport.aiohttp import AIOHTTPTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = server.make_url("/") transport = AIOHTTPTransport(url=url, timeout=10) + query = gql(file_upload_mutation_1) + with TemporaryFile(binary_file_content) as test_file: + file_path = test_file.filename + + async def file_sender(file_name): + async with aiofiles.open(file_name, "rb") as f: + chunk = await f.read(64 * 1024) + while chunk: + yield chunk + chunk = await f.read(64 * 1024) + + # Not using FileVar async with Client(transport=transport) as session: - query = gql(file_upload_mutation_1) + params = {"file": file_sender(file_path), "other_var": 42} - file_path = test_file.filename + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file 
upload is deprecated", + ): + result = await session.execute( + query, variable_values=params, upload_files=True + ) - async def file_sender(file_name): - async with aiofiles.open(file_name, "rb") as f: - chunk = await f.read(64 * 1024) - while chunk: - yield chunk - chunk = await f.read(64 * 1024) + success = result["success"] + assert success - params = {"file": file_sender(file_path), "other_var": 42} + # Using FileVar + async with Client(transport=transport) as session: + + params = {"file": FileVar(file_sender(file_path)), "other_var": 42} # Execute query asynchronously result = await session.execute( @@ -937,30 +1095,23 @@ async def file_sender(file_name): ) success = result["success"] - assert success + # Using FileVar with new streaming support + async with Client(transport=transport) as session: -file_upload_mutation_2 = """ - mutation($file1: Upload!, $file2: Upload!) { - uploadFile(input:{file1:$file, file2:$file}) { - success - } - } -""" - -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) + params = { + "file": FileVar(file_path, streaming=True), + "other_var": 42, + } -file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) -file_2_content = """ -This is a second test file -This file will also be sent in the GraphQL mutation -""" + success = result["success"] + assert success @pytest.mark.asyncio @@ -969,39 +1120,38 @@ async def test_aiohttp_file_upload_two_files(aiohttp_server): from gql.transport.aiohttp import AIOHTTPTransport - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) { + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } + """ - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -1018,82 +1168,60 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - params = { - "file1": f1, - "file2": f2, + "file1": FileVar(file_path_1), + "file2": FileVar(file_path_2), } result = await session.execute( query, variable_values=params, upload_files=True ) - f1.close() - f2.close() - success = result["success"] assert success -file_upload_mutation_3 = """ - mutation($files: [Upload!]!) { - uploadFiles(input:{files:$files}) { - success - } - } -""" - -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles(' - "input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) - -file_upload_mutation_3_map = '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' - - @pytest.mark.asyncio async def test_aiohttp_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web from gql.transport.aiohttp import AIOHTTPTransport - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + file_upload_mutation_3 = """ + mutation($files: [Upload!]!) { + uploadFiles(input:{files:$files}) { + success + } + } + """ - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_3_map = ( + '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + ) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -1110,19 +1238,18 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = {"files": [f1, f2]} + params = { + "files": [ + FileVar(file_path_1), + FileVar(file_path_2), + ], + } # Execute query asynchronously result = await session.execute( query, variable_values=params, upload_files=True ) - f1.close() - f2.close() - success = result["success"] assert success diff --git a/tests/test_httpx.py b/tests/test_httpx.py index d129f022..9558e137 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -1,8 +1,9 @@ +import os from typing import Any, Dict, Mapping import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -14,7 +15,7 @@ from .conftest import ( TemporaryFile, get_localhost_ssl_context_client, - strip_braces_spaces, + make_upload_handler, ) # Marking all tests in this file with the httpx marker @@ -516,8 +517,6 @@ def test_code(): await run_sync_test(server, test_code) -file_upload_server_answer 
= '{"data":{"success":true}}' - file_upload_mutation_1 = """ mutation($file: Upload!) { uploadFile(input:{other_var:$other_var, file:$file}) { @@ -547,35 +546,16 @@ async def test_httpx_file_upload(aiohttp_server, run_sync_test): from gql.transport.httpx import HTTPXTransport - async def single_upload_handler(request): - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -589,15 +569,41 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: params = {"file": f, "other_var": 42} + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: + + params = {"file": FileVar(f), "other_var": 42} execution_result = session._execute( query, 
variable_values=params, upload_files=True ) assert execution_result.data["success"] + # Using an filename string inside a FileVar object + params = { + "file": FileVar(file_path), + "other_var": 42, + } + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + await run_sync_test(server, test_code) @@ -608,38 +614,17 @@ async def test_httpx_file_upload_with_content_type(aiohttp_server, run_sync_test from gql.transport.httpx import HTTPXTransport - async def single_upload_handler(request): - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - # Verifying the content_type - assert field_2.headers["Content-Type"] == "application/pdf" - - field_3 = await reader.next() - assert field_3 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + file_headers=[{"Content-Type": "application/pdf"}], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -653,59 +638,104 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: # Setting the content_type f.content_type = "application/pdf" # type: ignore 
params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) assert execution_result.data["success"] + # Using FileVar + params = { + "file": FileVar(file_path, content_type="application/pdf"), + "other_var": 42, + } + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_additional_headers(aiohttp_server, run_sync_test): +async def test_httpx_file_upload_default_filename_is_basename( + aiohttp_server, run_sync_test +): from aiohttp import web from gql.transport.httpx import HTTPXTransport - async def single_upload_handler(request): - from aiohttp import web + app = web.Application() - assert request.headers["X-Auth"] == "foobar" + with TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + file_basename = os.path.basename(file_path) + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=[file_basename], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) - reader = await request.multipart() + url = str(server.make_url("/")) - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + def test_code(): + transport = HTTPXTransport(url=url) - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + with 
Client(transport=transport) as session: + query = gql(file_upload_mutation_1) - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + # Using FileVar + params = { + "file": FileVar(file_path), + "other_var": 42, + } + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) - field_3 = await reader.next() - assert field_3 is None + assert execution_result.data["success"] - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + await run_sync_test(server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_additional_headers(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + request_headers={"X-Auth": "foobar"}, + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -719,14 +749,12 @@ def test_code(): file_path = test_file.filename - with open(file_path, "rb") as f: - - params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + params = {"file": FileVar(file_path), "other_var": 42} + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) - assert execution_result.data["success"] + assert execution_result.data["success"] await run_sync_test(server, test_code) @@ -741,36 +769,17 @@ async def test_httpx_binary_file_upload(aiohttp_server, run_sync_test): # This is a sample binary file content containing all possible byte values binary_file_content = bytes(range(0, 256)) - async def 
binary_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -785,26 +794,17 @@ def test_code(): file_path = test_file.filename - with open(file_path, "rb") as f: + params = {"file": FileVar(file_path), "other_var": 42} - params = {"file": f, "other_var": 42} - - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) - assert execution_result.data["success"] + assert execution_result.data["success"] await run_sync_test(server, test_code) -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_file_upload_two_files(aiohttp_server, run_sync_test): @@ -820,6 +820,12 @@ async def test_httpx_file_upload_two_files(aiohttp_server, run_sync_test): } """ + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' file_2_content = """ @@ -827,39 +833,17 @@ async def test_httpx_file_upload_two_files(aiohttp_server, run_sync_test): This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, 
file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -877,12 +861,9 @@ def test_code(): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - params = { - "file1": f1, - "file2": f2, + "file1": FileVar(file_path_1), + "file2": FileVar(file_path_2), } execution_result = session._execute( @@ -891,19 +872,9 @@ def test_code(): assert execution_result.data["success"] - f1.close() - f2.close() - await run_sync_test(server, test_code) -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' - "(input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(aiohttp_server, run_sync_test): @@ -919,6 +890,12 @@ async def test_httpx_file_upload_list_of_two_files(aiohttp_server, run_sync_test } """ + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) + file_upload_mutation_3_map = ( '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' ) @@ -928,39 +905,17 @@ async def test_httpx_file_upload_list_of_two_files(aiohttp_server, run_sync_test This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -977,10 +932,12 @@ def test_code(): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = {"files": [f1, f2]} + params = { + "files": [ + FileVar(file_path_1), + FileVar(file_path_2), + ], + } execution_result = session._execute( query, variable_values=params, upload_files=True @@ -988,9 +945,6 @@ def 
test_code(): assert execution_result.data["success"] - f1.close() - f2.close() - await run_sync_test(server, test_code) diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 49ea6a24..ddacbc14 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -17,7 +17,7 @@ from .conftest import ( TemporaryFile, get_localhost_ssl_context_client, - strip_braces_spaces, + make_upload_handler, ) query1_str = """ @@ -613,8 +613,6 @@ def test_code(): await run_sync_test(server, test_code) -file_upload_server_answer = '{"data":{"success":true}}' - file_upload_mutation_1 = """ mutation($file: Upload!) { uploadFile(input:{other_var:$other_var, file:$file}) { @@ -637,33 +635,6 @@ def test_code(): """ -async def single_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response(text=file_upload_server_answer, content_type="application/json") - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_file_upload(aiohttp_server): @@ -672,7 +643,15 @@ async def test_httpx_file_upload(aiohttp_server): from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + 
"POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -687,17 +666,45 @@ async def test_httpx_file_upload(aiohttp_server): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: params = {"file": f, "other_var": 42} + # Execute query asynchronously + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + assert success + + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: + + params = {"file": FileVar(f), "other_var": 42} + # Execute query asynchronously result = await session.execute( query, variable_values=params, upload_files=True ) success = result["success"] + assert success + + # Using an filename string inside a FileVar object + params = {"file": FileVar(file_path), "other_var": 42} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + success = result["success"] assert success @@ -709,7 +716,15 @@ async def test_httpx_file_upload_without_session(aiohttp_server, run_sync_test): from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -725,52 +740,17 @@ def test_code(): file_path = test_file.filename - with open(file_path, "rb") as f: + params = {"file": FileVar(file_path), "other_var": 42} - params = 
{"file": f, "other_var": 42} + result = client.execute(query, variable_values=params, upload_files=True) - result = client.execute( - query, variable_values=params, upload_files=True - ) - - success = result["success"] + success = result["success"] - assert success + assert success await run_sync_test(server, test_code) -# This is a sample binary file content containing all possible byte values -binary_file_content = bytes(range(0, 256)) - - -async def binary_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response(text=file_upload_server_answer, content_type="application/json") - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_binary_file_upload(aiohttp_server): @@ -778,8 +758,20 @@ async def test_httpx_binary_file_upload(aiohttp_server): from gql.transport.httpx import HTTPXAsyncTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -794,42 +786,18 @@ async def 
test_httpx_binary_file_upload(aiohttp_server): file_path = test_file.filename - with open(file_path, "rb") as f: - - params = {"file": f, "other_var": 42} + params = {"file": FileVar(file_path), "other_var": 42} - # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) success = result["success"] assert success -file_upload_mutation_2 = """ - mutation($file1: Upload!, $file2: Upload!) { - uploadFile(input:{file1:$file, file2:$file}) { - success - } - } -""" - -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) - -file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' - -file_2_content = """ -This is a second test file -This file will also be sent in the GraphQL mutation -""" - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_file_upload_two_files(aiohttp_server): @@ -837,39 +805,38 @@ async def test_httpx_file_upload_two_files(aiohttp_server): from gql.transport.httpx import HTTPXAsyncTransport - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) 
{ + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } + """ - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -886,43 +853,19 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - params = { - "file1": f1, - "file2": f2, + "file1": FileVar(file_path_1), + "file2": FileVar(file_path_2), } result = await session.execute( query, variable_values=params, upload_files=True ) - f1.close() - f2.close() - success = result["success"] - assert success -file_upload_mutation_3 = """ - mutation($files: [Upload!]!) { - uploadFiles(input:{files:$files}) { - success - } - } -""" - -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles(' - "input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) - -file_upload_mutation_3_map = '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_file_upload_list_of_two_files(aiohttp_server): @@ -930,39 +873,40 @@ async def test_httpx_file_upload_list_of_two_files(aiohttp_server): from gql.transport.httpx import HTTPXAsyncTransport - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + file_upload_mutation_3 = """ + mutation($files: [Upload!]!) { + uploadFiles(input:{files:$files}) { + success + } + } + """ - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_3_map = ( + '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + ) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -979,21 +923,19 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = {"files": [f1, f2]} + params = { + "files": [ + FileVar(file_path_1), + FileVar(file_path_2), + ], + } # Execute query asynchronously result = await session.execute( query, variable_values=params, upload_files=True ) - f1.close() - f2.close() - success = result["success"] - assert success diff --git a/tests/test_requests.py b/tests/test_requests.py index 9c0334bd..c184e230 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,8 +1,10 @@ +import os +import warnings from typing import Any, Dict, Mapping import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -14,7 +16,7 @@ from .conftest import ( TemporaryFile, get_localhost_ssl_context_client, - strip_braces_spaces, + make_upload_handler, ) # Marking all tests in this file with the requests marker @@ -86,8 +88,6 @@ def test_code(): @pytest.mark.asyncio 
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_requests_query_https(ssl_aiohttp_server, run_sync_test, verify_https): - import warnings - from aiohttp import web from gql.transport.requests import RequestsHTTPTransport @@ -519,8 +519,6 @@ def test_code(): await run_sync_test(server, test_code) -file_upload_server_answer = '{"data":{"success":true}}' - file_upload_mutation_1 = """ mutation($file: Upload!) { uploadFile(input:{other_var:$other_var, file:$file}) { @@ -550,35 +548,16 @@ async def test_requests_file_upload(aiohttp_server, run_sync_test): from gql.transport.requests import RequestsHTTPTransport - async def single_upload_handler(request): - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -592,15 +571,41 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, 
upload_files=True - ) + + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: + + params = {"file": FileVar(f), "other_var": 42} + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) assert execution_result.data["success"] + # Using an filename string inside a FileVar object + params = {"file": FileVar(file_path), "other_var": 42} + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + await run_sync_test(server, test_code) @@ -611,38 +616,17 @@ async def test_requests_file_upload_with_content_type(aiohttp_server, run_sync_t from gql.transport.requests import RequestsHTTPTransport - async def single_upload_handler(request): - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - # Verifying the content_type - assert field_2.headers["Content-Type"] == "application/pdf" - - field_3 = await reader.next() - assert field_3 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", 
single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + file_headers=[{"Content-Type": "application/pdf"}], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -656,12 +640,30 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: # Setting the content_type f.content_type = "application/pdf" # type: ignore params = {"file": f, "other_var": 42} + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: + + params = { + "file": FileVar(f, content_type="application/pdf"), + "other_var": 42, + } execution_result = session._execute( query, variable_values=params, upload_files=True ) @@ -673,48 +675,78 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload_additional_headers(aiohttp_server, run_sync_test): +async def test_requests_file_upload_default_filename_is_basename( + aiohttp_server, run_sync_test +): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport - async def single_upload_handler(request): - from aiohttp import web + app = web.Application() - assert request.headers["X-Auth"] == "foobar" + with TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + file_basename = os.path.basename(file_path) + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=[file_basename], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await 
aiohttp_server(app) - reader = await request.multipart() + url = server.make_url("/") - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + def test_code(): - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + transport = RequestsHTTPTransport(url=url) - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) - field_3 = await reader.next() - assert field_3 is None + params = { + "file": FileVar(file_path), + "other_var": 42, + } + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + await run_sync_test(server, test_code) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload_with_filename(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.requests import RequestsHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=["filename1.txt"], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") def test_code(): - transport = RequestsHTTPTransport(url=url, headers={"X-Auth": "foobar"}) + + transport = RequestsHTTPTransport(url=url) with TemporaryFile(file_1_content) as test_file: with Client(transport=transport) as session: @@ -724,7 +756,10 @@ def test_code(): with 
open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + params = { + "file": FileVar(f, filename="filename1.txt"), + "other_var": 42, + } execution_result = session._execute( query, variable_values=params, upload_files=True ) @@ -736,44 +771,72 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_binary_file_upload(aiohttp_server, run_sync_test): +async def test_requests_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web from gql.transport.requests import RequestsHTTPTransport - # This is a sample binary file content containing all possible byte values - binary_file_content = bytes(range(0, 256)) + app = web.Application() + app.router.add_route( + "POST", + "/", + make_upload_handler( + request_headers={"X-Auth": "foobar"}, + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) - async def binary_upload_handler(request): + url = server.make_url("/") - from aiohttp import web + def test_code(): + transport = RequestsHTTPTransport(url=url, headers={"X-Auth": "foobar"}) - reader = await request.multipart() + with TemporaryFile(file_1_content) as test_file: + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + file_path = test_file.filename - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + with open(file_path, "rb") as f: - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content + params = {"file": f, "other_var": 42} + with pytest.warns( + 
DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) - field_3 = await reader.next() - assert field_3 is None + assert execution_result.data["success"] + + await run_sync_test(server, test_code) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_binary_file_upload(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.requests import RequestsHTTPTransport + + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -792,22 +855,19 @@ def test_code(): params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, variable_values=params, upload_files=True + ) assert execution_result.data["success"] await run_sync_test(server, test_code) -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_file_upload_two_files(aiohttp_server, run_sync_test): @@ -823,6 +883,12 @@ async def test_requests_file_upload_two_files(aiohttp_server, run_sync_test): } """ + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' file_2_content = """ @@ -830,39 +896,17 @@ async def test_requests_file_upload_two_files(aiohttp_server, run_sync_test): This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, 
file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -877,19 +921,45 @@ def test_code(): query = gql(file_upload_mutation_2) + # Old method file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename f1 = open(file_path_1, "rb") f2 = open(file_path_2, "rb") - params = { + params_1 = { "file1": f1, "file2": f2, } + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, variable_values=params_1, upload_files=True + ) + + assert execution_result.data["success"] + + f1.close() + f2.close() + + # Using FileVar + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params_2 = { + "file1": FileVar(f1), + "file2": FileVar(f2), + } + execution_result = session._execute( - query, variable_values=params, upload_files=True + query, variable_values=params_2, upload_files=True ) assert execution_result.data["success"] @@ -900,13 +970,6 @@ def test_code(): await run_sync_test(server, test_code) -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' - "(input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) - - @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_file_upload_list_of_two_files(aiohttp_server, run_sync_test): @@ -922,6 +985,12 @@ async def test_requests_file_upload_list_of_two_files(aiohttp_server, run_sync_t } """ + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) + file_upload_mutation_3_map = ( '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' ) @@ -931,39 +1000,17 @@ async def test_requests_file_upload_list_of_two_files(aiohttp_server, run_sync_t This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -977,6 +1024,7 @@ def test_code(): query = gql(file_upload_mutation_3) + # Old method file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename @@ -985,8 +1033,30 @@ def test_code(): params = {"files": [f1, f2]} + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session._execute( + query, 
variable_values=params, upload_files=True + ) + + assert execution_result.data["success"] + + f1.close() + f2.close() + + # Using FileVar + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params_2 = {"files": [FileVar(f1), FileVar(f2)]} + execution_result = session._execute( - query, variable_values=params, upload_files=True + query, variable_values=params_2, upload_files=True ) assert execution_result.data["success"] From 58cd38738d4abb24719c7a9537733aeb787ce608 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 25 May 2025 14:12:49 +0000 Subject: [PATCH 215/239] Implementation of execute_batch for async transports (#550) --- gql/client.py | 266 +++++++++++++++-- gql/transport/aiohttp.py | 344 ++++++++++++++-------- gql/transport/async_transport.py | 21 +- gql/transport/common/batch.py | 76 +++++ gql/transport/httpx.py | 176 +++++++++--- gql/transport/requests.py | 54 +--- tests/custom_scalars/test_money.py | 26 ++ tests/test_aiohttp.py | 9 +- tests/test_aiohttp_batch.py | 335 ++++++++++++++++++++++ tests/test_client.py | 13 - tests/test_httpx.py | 1 + tests/test_httpx_async.py | 2 +- tests/test_httpx_batch.py | 440 +++++++++++++++++++++++++++++ tests/test_requests_batch.py | 7 +- 14 files changed, 1516 insertions(+), 254 deletions(-) create mode 100644 gql/transport/common/batch.py create mode 100644 tests/test_aiohttp_batch.py create mode 100644 tests/test_httpx_batch.py diff --git a/gql/client.py b/gql/client.py index 99cd6e46..a4e80dcb 100644 --- a/gql/client.py +++ b/gql/client.py @@ -184,6 +184,24 @@ def _build_schema_from_introspection( self.introspection = cast(IntrospectionQuery, execution_result.data) self.schema = build_client_schema(self.introspection) + @staticmethod + def _get_event_loop() -> asyncio.AbstractEventLoop: + """Get the current asyncio event loop. + + Or create a new event loop if there isn't one (in a new Thread). 
+ """ + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop + @overload def execute_sync( self, @@ -358,6 +376,58 @@ async def execute_async( **kwargs, ) + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False] = ..., + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover + + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover + + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... 
# pragma: no cover + + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """:meta private:""" + async with self as session: + return await session.execute_batch( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + @overload def execute( self, @@ -430,17 +500,7 @@ def execute( """ if isinstance(self.transport, AsyncTransport): - # Get the current asyncio event loop - # Or create a new event loop if there isn't one (in a new Thread) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = self._get_event_loop() assert not loop.is_running(), ( "Cannot run client.execute(query) if an asyncio loop is running." @@ -537,7 +597,24 @@ def execute_batch( """ if isinstance(self.transport, AsyncTransport): - raise NotImplementedError("Batching is not implemented for async yet.") + loop = self._get_event_loop() + + assert not loop.is_running(), ( + "Cannot run client.execute_batch(query) if an asyncio loop is running." + " Use 'await client.execute_batch(query)' instead." + ) + + data = loop.run_until_complete( + self.execute_batch_async( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + ) + + return data else: # Sync transports return self.execute_batch_sync( @@ -675,17 +752,12 @@ def subscribe( We need an async transport for this functionality. 
""" - # Get the current asyncio event loop - # Or create a new event loop if there isn't one (in a new Thread) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = self._get_event_loop() + + assert not loop.is_running(), ( + "Cannot run client.subscribe(query) if an asyncio loop is running." + " Use 'await client.subscribe_async(query)' instead." + ) async_generator: Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] @@ -699,11 +771,6 @@ def subscribe( **kwargs, ) - assert not loop.is_running(), ( - "Cannot run client.subscribe(query) if an asyncio loop is running." - " Use 'await client.subscribe_async(query)' instead." - ) - try: while True: # Note: we need to create a task here in order to be able to close @@ -1626,6 +1693,149 @@ async def execute( return result.data + async def _execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + validate_document: Optional[bool] = True, + **kwargs: Any, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch, using + the async transport, returning a list of ExecutionResult objects. + + :param requests: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will deserialize the result. + By default use the parse_results argument of the client. + :param validate_document: Whether we still need to validate the document. 
+ + The extra arguments are passed to the transport execute_batch method.""" + + # Validate document + if self.client.schema: + + if validate_document: + for req in requests: + self.client.validate(req.document) + + # Parse variable values for custom scalars if requested + if serialize_variables or ( + serialize_variables is None and self.client.serialize_variables + ): + requests = [ + ( + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req + ) + for req in requests + ] + + results = await self.transport.execute_batch(requests, **kwargs) + + # Unserialize the result if requested + if self.client.schema: + if parse_result or (parse_result is None and self.client.parse_results): + for result in results: + result.data = parse_result_fn( + self.client.schema, + req.document, + result.data, + operation_name=req.operation_name, + ) + + return results + + @overload + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False] = ..., + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover + + @overload + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover + + @overload + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... 
# pragma: no cover + + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """Execute multiple GraphQL requests in a batch, using + the async transport. This method sends the requests to the server all at once. + + Raises a TransportQueryError if an error has been returned in any + ExecutionResult. + + :param requests: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will deserialize the result. + By default use the parse_results argument of the client. + :param get_execution_result: return the full ExecutionResult instance instead of + only the "data" field. Necessary if you want to get the "extensions" field. + + The extra arguments are passed to the transport execute method.""" + + # Validate and execute on the transport + results = await self._execute_batch( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + **kwargs, + ) + + for result in results: + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError( + str_first_element(result.errors), + errors=result.errors, + data=result.data, + extensions=result.extensions, + ) + + assert ( + result.data is not None + ), "Transport returned an ExecutionResult without data or errors" + + if get_execution_result: + return results + + return cast(List[Dict[str, Any]], [result.data for result in results]) + async def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. 
diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index b2633abb..9535eef4 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -8,7 +8,7 @@ AsyncGenerator, Callable, Dict, - NoReturn, + List, Optional, Tuple, Type, @@ -23,9 +23,11 @@ from graphql import DocumentNode, ExecutionResult, print_ast from multidict import CIMultiDictProxy +from ..graphql_request import GraphQLRequest from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport from .common.aiohttp_closed_event import create_aiohttp_closed_event +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -162,172 +164,274 @@ async def close(self) -> None: self.session = None - async def execute( + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: + query_str = print_ast(req.document) + payload: Dict[str, Any] = {"query": query_str} + + if req.operation_name: + payload["operationName"] = req.operation_name + + if req.variable_values: + payload["variables"] = req.variable_values + + return payload + + def _prepare_batch_request( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + reqs: List[GraphQLRequest], extra_args: Optional[Dict[str, Any]] = None, - upload_files: bool = False, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - This uses the aiohttp library to perform a HTTP POST request asynchronously - to the remote server. + ) -> Dict[str, Any]: - Don't call this coroutine directly on the transport, instead use - :code:`execute` on a client or a session. 
+ payload = [self._build_payload(req) for req in reqs] - :param document: the parsed GraphQL request - :param variable_values: An optional Dict of variable values - :param operation_name: An optional Operation name for the request - :param extra_args: additional arguments to send to the aiohttp post method - :param upload_files: Set to True if you want to put files in the variable values - :returns: an ExecutionResult object. - """ + post_args = {"json": payload} - query_str = print_ast(document) + # Log the payload + if log.isEnabledFor(logging.INFO): + log.info(">>> %s", self.json_serialize(post_args["json"])) - payload: Dict[str, Any] = { - "query": query_str, - } + # Pass post_args to aiohttp post method + if extra_args: + post_args.update(extra_args) + + return post_args + + def _prepare_request( + self, + req: GraphQLRequest, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> Dict[str, Any]: - if operation_name: - payload["operationName"] = operation_name + payload = self._build_payload(req) if upload_files: + post_args = self._prepare_file_uploads(req, payload) + else: + post_args = {"json": payload} - # If the upload_files flag is set, then we need variable_values - assert variable_values is not None + # Log the payload + if log.isEnabledFor(logging.INFO): + log.info(">>> %s", self.json_serialize(payload)) + + # Pass post_args to aiohttp post method + if extra_args: + post_args.update(extra_args) - # If we upload files, we will extract the files present in the - # variable_values dict and replace them by null values - nulled_variable_values, files = extract_files( - variables=variable_values, - file_classes=self.file_classes, + # Add headers for AppSync if requested + if isinstance(self.auth, AppSyncAuthentication): + post_args["headers"] = self.auth.get_headers( + self.json_serialize(payload), + {"content-type": "application/json"}, ) - # Opening the files using the FileVar parameters - open_files(list(files.values()), 
transport_supports_streaming=True) - self.files = files + return post_args + + def _prepare_file_uploads( + self, req: GraphQLRequest, payload: Dict[str, Any] + ) -> Dict[str, Any]: + + # If the upload_files flag is set, then we need variable_values + variable_values = req.variable_values + assert variable_values is not None - # Save the nulled variable values in the payload - payload["variables"] = nulled_variable_values + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values + nulled_variable_values, files = extract_files( + variables=variable_values, + file_classes=self.file_classes, + ) - # Prepare aiohttp to send multipart-encoded data - data = aiohttp.FormData() + # Opening the files using the FileVar parameters + open_files(list(files.values()), transport_supports_streaming=True) + self.files = files - # Generate the file map - # path is nested in a list because the spec allows multiple pointers - # to the same file. But we don't support that. - # Will generate something like {"0": ["variables.file"]} - file_map = {str(i): [path] for i, path in enumerate(files)} + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values - # Enumerate the file streams - # Will generate something like {'0': FileVar object} - file_vars = {str(i): files[path] for i, path in enumerate(files)} + # Prepare aiohttp to send multipart-encoded data + data = aiohttp.FormData() + + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. 
+ # Will generate something like {"0": ["variables.file"]} + file_map = {str(i): [path] for i, path in enumerate(files)} + + # Enumerate the file streams + # Will generate something like {'0': FileVar object} + file_vars = {str(i): files[path] for i, path in enumerate(files)} + + # Add the payload to the operations field + operations_str = self.json_serialize(payload) + log.debug("operations %s", operations_str) + data.add_field("operations", operations_str, content_type="application/json") + + # Add the file map field + file_map_str = self.json_serialize(file_map) + log.debug("file_map %s", file_map_str) + data.add_field("map", file_map_str, content_type="application/json") + + for k, file_var in file_vars.items(): + assert isinstance(file_var, FileVar) - # Add the payload to the operations field - operations_str = self.json_serialize(payload) - log.debug("operations %s", operations_str) data.add_field( - "operations", operations_str, content_type="application/json" + k, + file_var.f, + filename=file_var.filename, + content_type=file_var.content_type, ) - # Add the file map field - file_map_str = self.json_serialize(file_map) - log.debug("file_map %s", file_map_str) - data.add_field("map", file_map_str, content_type="application/json") + post_args: Dict[str, Any] = {"data": data} - for k, file_var in file_vars.items(): - assert isinstance(file_var, FileVar) + return post_args - data.add_field( - k, - file_var.f, - filename=file_var.filename, - content_type=file_var.content_type, - ) + async def raise_response_error( + self, + resp: aiohttp.ClientResponse, + reason: str, + ) -> None: + # We raise a TransportServerError if status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + try: + # Raise ClientResponseError if response status is 400 or higher + resp.raise_for_status() + except ClientResponseError as e: + raise TransportServerError(str(e), e.status) from e - post_args: Dict[str, Any] = {"data": data} + result_text = await 
resp.text() + self._raise_invalid_result(result_text, reason) - else: - if variable_values: - payload["variables"] = variable_values + async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: + + # Saving latest response headers in the transport + self.response_headers = response.headers + + try: + result = await response.json(loads=self.json_deserialize, content_type=None) if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(payload)) + result_text = await response.text() + log.info("<<< %s", result_text) - post_args = {"json": payload} + except Exception: + await self.raise_response_error(response, "Not a JSON answer") - # Pass post_args to aiohttp post method - if extra_args: - post_args.update(extra_args) + if result is None: + await self.raise_response_error(response, "Not a JSON answer") - # Add headers for AppSync if requested - if isinstance(self.auth, AppSyncAuthentication): - post_args["headers"] = self.auth.get_headers( - self.json_serialize(payload), - {"content-type": "application/json"}, + return result + + async def _prepare_result( + self, response: aiohttp.ClientResponse + ) -> ExecutionResult: + + result = await self._get_json_result(response) + + if "errors" not in result and "data" not in result: + await self.raise_response_error( + response, 'No "data" or "errors" keys in answer' ) - if self.session is None: - raise TransportClosed("Transport is not connected") + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) - try: - async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + async def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: aiohttp.ClientResponse, + ) -> List[ExecutionResult]: - # Saving latest response headers in the transport - self.response_headers = resp.headers + answers = await self._get_json_result(response) - async def raise_response_error( - resp: 
aiohttp.ClientResponse, reason: str - ) -> NoReturn: - # We raise a TransportServerError if status code is 400 or higher - # We raise a TransportProtocolError in the other cases + return get_batch_execution_result_list(reqs, answers) - try: - # Raise ClientResponseError if response status is 400 or higher - resp.raise_for_status() - except ClientResponseError as e: - raise TransportServerError(str(e), e.status) from e + def _raise_invalid_result(self, result_text: str, reason: str) -> None: + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " + f"{reason}: " + f"{result_text}" + ) - result_text = await resp.text() - raise TransportProtocolError( - f"Server did not return a GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + This uses the aiohttp library to perform a HTTP POST request asynchronously + to the remote server. - try: - result = await resp.json( - loads=self.json_deserialize, content_type=None - ) + Don't call this coroutine directly on the transport, instead use + :code:`execute` on a client or a session. - if log.isEnabledFor(logging.INFO): - result_text = await resp.text() - log.info("<<< %s", result_text) + :param document: the parsed GraphQL request + :param variable_values: An optional Dict of variable values + :param operation_name: An optional Operation name for the request + :param extra_args: additional arguments to send to the aiohttp post method + :param upload_files: Set to True if you want to put files in the variable values + :returns: an ExecutionResult object. 
+ """ - except Exception: - await raise_response_error(resp, "Not a JSON answer") + req = GraphQLRequest( + document=document, + variable_values=variable_values, + operation_name=operation_name, + ) - if result is None: - await raise_response_error(resp, "Not a JSON answer") + post_args = self._prepare_request( + req, + extra_args, + upload_files, + ) - if "errors" not in result and "data" not in result: - await raise_response_error( - resp, 'No "data" or "errors" keys in answer' - ) + if self.session is None: + raise TransportClosed("Transport is not connected") - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) + try: + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + return await self._prepare_result(resp) finally: if upload_files: close_files(list(self.files.values())) + async def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the aiohttp post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + post_args = self._prepare_batch_request( + reqs, + extra_args, + ) + + if self.session is None: + raise TransportClosed("Transport is not connected") + + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + return await self._prepare_batch_result(reqs, resp) + def subscribe( self, document: DocumentNode, diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 4cecc9f9..243746e6 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,8 +1,10 @@ import abc -from typing import Any, AsyncGenerator, Dict, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional from graphql import DocumentNode, ExecutionResult +from ..graphql_request import GraphQLRequest + class AsyncTransport(abc.ABC): @abc.abstractmethod @@ -32,6 +34,23 @@ async def execute( "Any AsyncTransport subclass must implement execute method" ) # pragma: no cover + async def execute_batch( + self, + reqs: List[GraphQLRequest], + *args: Any, + **kwargs: Any, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Execute the provided requests for either a remote or local GraphQL Schema. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. 
+ :return: a list of ExecutionResult objects + """ + raise NotImplementedError( + "This Transport has not implemented the execute_batch method" + ) # pragma: no cover + @abc.abstractmethod def subscribe( self, diff --git a/gql/transport/common/batch.py b/gql/transport/common/batch.py new file mode 100644 index 00000000..4feadee6 --- /dev/null +++ b/gql/transport/common/batch.py @@ -0,0 +1,76 @@ +from typing import ( + Any, + Dict, + List, +) + +from graphql import ExecutionResult + +from ...graphql_request import GraphQLRequest +from ..exceptions import ( + TransportProtocolError, +) + + +def _raise_protocol_error(result_text: str, reason: str) -> None: + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " f"{reason}: " f"{result_text}" + ) + + +def _validate_answer_is_a_list(results: Any) -> None: + if not isinstance(results, list): + _raise_protocol_error( + str(results), + "Answer is not a list", + ) + + +def _validate_data_and_errors_keys_in_answers(results: List[Dict[str, Any]]) -> None: + for result in results: + if "errors" not in result and "data" not in result: + _raise_protocol_error( + str(results), + 'No "data" or "errors" keys in answer', + ) + + +def _validate_every_answer_is_a_dict(results: List[Dict[str, Any]]) -> None: + for result in results: + if not isinstance(result, dict): + _raise_protocol_error(str(results), "Not every answer is dict") + + +def _validate_num_of_answers_same_as_requests( + reqs: List[GraphQLRequest], + results: List[Dict[str, Any]], +) -> None: + if len(reqs) != len(results): + _raise_protocol_error( + str(results), + ( + "Invalid number of answers: " + f"{len(results)} answers received for {len(reqs)} requests" + ), + ) + + +def _answer_to_execution_result(result: Dict[str, Any]) -> ExecutionResult: + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) + + +def get_batch_execution_result_list( + reqs: List[GraphQLRequest], 
+ answers: List, +) -> List[ExecutionResult]: + + _validate_answer_is_a_list(answers) + _validate_num_of_answers_same_as_requests(reqs, answers) + _validate_every_answer_is_a_dict(answers) + _validate_data_and_errors_keys_in_answers(answers) + + return [_answer_to_execution_result(answer) for answer in answers] diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index eb15ac57..406c0523 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -17,7 +17,9 @@ import httpx from graphql import DocumentNode, ExecutionResult, print_ast +from ..graphql_request import GraphQLRequest from . import AsyncTransport, Transport +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -55,32 +57,30 @@ def __init__( self.json_deserialize = json_deserialize self.kwargs = kwargs + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: + query_str = print_ast(req.document) + payload: Dict[str, Any] = {"query": query_str} + + if req.operation_name: + payload["operationName"] = req.operation_name + + if req.variable_values: + payload["variables"] = req.variable_values + + return payload + def _prepare_request( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + req: GraphQLRequest, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> Dict[str, Any]: - query_str = print_ast(document) - - payload: Dict[str, Any] = { - "query": query_str, - } - if operation_name: - payload["operationName"] = operation_name + payload = self._build_payload(req) if upload_files: - # If the upload_files flag is set, then we need variable_values - assert variable_values is not None - - post_args = self._prepare_file_uploads(variable_values, payload) + post_args = self._prepare_file_uploads(req, payload) else: - if variable_values: - payload["variables"] = variable_values - post_args = {"json": payload} # Log 
the payload @@ -93,9 +93,37 @@ def _prepare_request( return post_args + def _prepare_batch_request( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + + payload = [self._build_payload(req) for req in reqs] + + post_args = {"json": payload} + + # Log the payload + if log.isEnabledFor(logging.INFO): + log.debug(">>> %s", self.json_serialize(payload)) + + # Pass post_args to aiohttp post method + if extra_args: + post_args.update(extra_args) + + return post_args + def _prepare_file_uploads( - self, variable_values: Dict[str, Any], payload: Dict[str, Any] + self, + request: GraphQLRequest, + payload: Dict[str, Any], ) -> Dict[str, Any]: + + variable_values = request.variable_values + + # If the upload_files flag is set, then we need variable_values + assert variable_values is not None + # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( @@ -143,8 +171,9 @@ def _prepare_file_uploads( return {"data": data, "files": file_streams} - def _prepare_result(self, response: httpx.Response) -> ExecutionResult: - # Save latest response headers in transport + def _get_json_result(self, response: httpx.Response) -> Any: + + # Saving latest response headers in the transport self.response_headers = response.headers if log.isEnabledFor(logging.DEBUG): @@ -152,10 +181,15 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: try: result: Dict[str, Any] = self.json_deserialize(response.content) - except Exception: self._raise_response_error(response, "Not a JSON answer") + return result + + def _prepare_result(self, response: httpx.Response) -> ExecutionResult: + + result = self._get_json_result(response) + if "errors" not in result and "data" not in result: self._raise_response_error(response, 'No "data" or "errors" keys in answer') @@ -165,6 +199,16 @@ def _prepare_result(self, response: 
httpx.Response) -> ExecutionResult: extensions=result.get("extensions"), ) + def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: httpx.Response, + ) -> List[ExecutionResult]: + + answers = self._get_json_result(response) + + return get_batch_execution_result_list(reqs, answers) + def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn: # We raise a TransportServerError if the status code is 400 or higher # We raise a TransportProtocolError in the other cases @@ -223,10 +267,14 @@ def execute( # type: ignore if not self.client: raise TransportClosed("Transport is not connected") + request = GraphQLRequest( + document=document, + variable_values=variable_values, + operation_name=operation_name, + ) + post_args = self._prepare_request( - document, - variable_values, - operation_name, + request, extra_args, upload_files, ) @@ -239,6 +287,36 @@ def execute( # type: ignore return self._prepare_result(response) + def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the aiohttp post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_batch_request( + reqs, + extra_args, + ) + + response = self.client.post(self.url, **post_args) + + return self._prepare_batch_result(reqs, response) + def close(self): """Closing the transport by closing the inner session""" if self.client: @@ -290,10 +368,14 @@ async def execute( if not self.client: raise TransportClosed("Transport is not connected") + request = GraphQLRequest( + document=document, + variable_values=variable_values, + operation_name=operation_name, + ) + post_args = self._prepare_request( - document, - variable_values, - operation_name, + request, extra_args, upload_files, ) @@ -306,11 +388,35 @@ async def execute( return self._prepare_result(response) - async def close(self): - """Closing the transport by closing the inner session""" - if self.client: - await self.client.aclose() - self.client = None + async def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the aiohttp post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_batch_request( + reqs, + extra_args, + ) + + response = await self.client.post(self.url, **post_args) + + return self._prepare_batch_result(reqs, response) def subscribe( self, @@ -323,3 +429,9 @@ def subscribe( :meta private: """ raise NotImplementedError("The HTTP transport does not support subscriptions") + + async def close(self): + """Closing the transport by closing the inner session""" + if self.client: + await self.client.aclose() + self.client = None diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 5fb7e827..d84ba9d3 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -25,6 +25,7 @@ from gql.transport import Transport from ..graphql_request import GraphQLRequest +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -307,7 +308,7 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: extensions=result.get("extensions"), ) - def execute_batch( # type: ignore + def execute_batch( self, reqs: List[GraphQLRequest], timeout: Optional[int] = None, @@ -340,52 +341,7 @@ def execute_batch( # type: ignore answers = self._extract_response(response) - self._validate_answer_is_a_list(answers) - self._validate_num_of_answers_same_as_requests(reqs, answers) - self._validate_every_answer_is_a_dict(answers) - self._validate_data_and_errors_keys_in_answers(answers) - - return [self._answer_to_execution_result(answer) for answer in answers] - - def _answer_to_execution_result(self, result: Dict[str, Any]) -> ExecutionResult: - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) - - def _validate_answer_is_a_list(self, results: Any) -> None: - if not isinstance(results, list): - self._raise_invalid_result( - str(results), - "Answer is not a list", 
- ) - - def _validate_data_and_errors_keys_in_answers( - self, results: List[Dict[str, Any]] - ) -> None: - for result in results: - if "errors" not in result and "data" not in result: - self._raise_invalid_result( - str(results), - 'No "data" or "errors" keys in answer', - ) - - def _validate_every_answer_is_a_dict(self, results: List[Dict[str, Any]]) -> None: - for result in results: - if not isinstance(result, dict): - self._raise_invalid_result(str(results), "Not every answer is dict") - - def _validate_num_of_answers_same_as_requests( - self, - reqs: List[GraphQLRequest], - results: List[Dict[str, Any]], - ) -> None: - if len(reqs) != len(results): - self._raise_invalid_result( - str(results), - "Invalid answer length", - ) + return get_batch_execution_result_list(reqs, answers) def _raise_invalid_result(self, result_text: str, reason: str) -> None: raise TransportProtocolError( @@ -427,7 +383,7 @@ def _build_batch_post_args( } data_key = "json" if self.use_json else "data" - post_args[data_key] = [self._build_data(req) for req in reqs] + post_args[data_key] = [self._build_payload(req) for req in reqs] # Log the payload if log.isEnabledFor(logging.INFO): @@ -442,7 +398,7 @@ def _build_batch_post_args( return post_args - def _build_data(self, req: GraphQLRequest) -> Dict[str, Any]: + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: query_str = print_ast(req.document) payload: Dict[str, Any] = {"query": query_str} diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 39f1a1cb..8b4a99f4 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -784,6 +784,32 @@ def test_code(): await run_sync_test(server, test_code) +@pytest.mark.asyncio +@pytest.mark.aiohttp +async def test_custom_scalar_serialize_variables_async_transport(aiohttp_server): + transport = await make_money_transport(aiohttp_server) + + async with Client( + schema=schema, transport=transport, parse_results=True + 
) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + variable_values = {"money": Money(10, "DM")} + + results = await session.execute_batch( + [ + GraphQLRequest(document=query, variable_values=variable_values), + GraphQLRequest(document=query, variable_values=variable_values), + ], + serialize_variables=True, + ) + + print(f"result = {results!r}") + assert results[0]["toEuros"] == 5 + assert results[1]["toEuros"] == 5 + + def test_serialize_value_with_invalid_type(): with pytest.raises(GraphQLError) as exc_info: diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index fe36585e..0642e536 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -295,27 +295,28 @@ async def handler(request): { "response": "{}", "expected_exception": ( - "Server did not return a GraphQL result: " + "Server did not return a valid GraphQL result: " 'No "data" or "errors" keys in answer: {}' ), }, { "response": "qlsjfqsdlkj", "expected_exception": ( - "Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj" + "Server did not return a valid GraphQL result: " + "Not a JSON answer: qlsjfqsdlkj" ), }, { "response": '{"not_data_or_errors": 35}', "expected_exception": ( - "Server did not return a GraphQL result: " + "Server did not return a valid GraphQL result: " 'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}' ), }, { "response": "", "expected_exception": ( - "Server did not return a GraphQL result: Not a JSON answer: " + "Server did not return a valid GraphQL result: Not a JSON answer: " ), }, ] diff --git a/tests/test_aiohttp_batch.py b/tests/test_aiohttp_batch.py new file mode 100644 index 00000000..f04f05e4 --- /dev/null +++ b/tests/test_aiohttp_batch.py @@ -0,0 +1,335 @@ +from typing import Mapping + +import pytest + +from gql import Client, GraphQLRequest, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +# Marking all tests in 
this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}]' +) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Execute query asynchronously + results = await session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query_without_session(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = 
server.make_url("/") + + def test_code(): + transport = AIOHTTPTransport(url=url, timeout=10) + + client = Client(transport=transport) + + query = [GraphQLRequest(document=gql(query1_str))] + + results = client.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' + + +@pytest.mark.asyncio +async def test_aiohttp_batch_error_code(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportQueryError): + await session.execute_batch(query) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', + "[{}]", + "[qlsjfqsdlkj]", + '[{"not_data_or_errors": 35}]', + "[]", + "[1]", +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_aiohttp_batch_invalid_protocol(aiohttp_server, response): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await 
aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportProtocolError): + await session.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_aiohttp_batch_cannot_execute_if_not_connected( + aiohttp_server, run_sync_test +): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportClosed): + await transport.execute_batch(query) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_extra_args(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + # passing extra arguments to aiohttp.ClientSession + from aiohttp import DummyCookieJar + + jar = DummyCookieJar() + transport = AIOHTTPTransport( + url=url, timeout=10, client_session_args={"version": "1.1", "cookie_jar": jar} + ) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Passing extra arguments to the post method of aiohttp + results = await session.execute_batch( + query, extra_args={"allow_redirects": False} + ) + + result = 
results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +query1_server_answer_with_extensions_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}]" +) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query_with_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url) + + query = [GraphQLRequest(document=gql(query1_str))] + + async with Client(transport=transport) as session: + + execution_results = await session.execute_batch( + query, get_execution_result=True + ) + + assert execution_results[0].extensions["key1"] == "val1" + + +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_aiohttp_batch_online_manual(): + + from gql.transport.aiohttp import AIOHTTPTransport + + client = Client( + transport=AIOHTTPTransport(url=ONLINE_URL, timeout=10), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + ) + + async with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = await session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" diff --git a/tests/test_client.py b/tests/test_client.py index 8669b4a3..55993a9e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -54,19 +54,6 @@ def execute( ) -@pytest.mark.aiohttp -def test_request_async_execute_batch_not_implemented_yet(): - from gql.transport.aiohttp import AIOHTTPTransport - - transport = AIOHTTPTransport(url="http://localhost/") - client = Client(transport=transport) - - with pytest.raises(NotImplementedError) as exc_info: - client.execute_batch([GraphQLRequest(document=gql("{dummy}"))]) - - assert "Batching is not implemented for async yet." 
== str(exc_info.value) - - @pytest.mark.requests @mock.patch("urllib3.connection.HTTPConnection._new_conn") def test_retries_on_transport(execute_mock): diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 9558e137..0991355a 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -387,6 +387,7 @@ def test_code(): "{}", "qlsjfqsdlkj", '{"not_data_or_errors": 35}', + "", ] diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index ddacbc14..87f1675a 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -457,7 +457,7 @@ async def handler(request): query = gql(query1_str) - # Passing extra arguments to the post method of aiohttp + # Passing extra arguments to the post method result = await session.execute(query, extra_args={"follow_redirects": True}) continents = result["continents"] diff --git a/tests/test_httpx_batch.py b/tests/test_httpx_batch.py new file mode 100644 index 00000000..9e5b9b93 --- /dev/null +++ b/tests/test_httpx_batch.py @@ -0,0 +1,440 @@ +from typing import Mapping + +import pytest + +from gql import Client, GraphQLRequest, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +# Marking all tests in this file with the httpx marker +pytestmark = pytest.mark.httpx + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}]' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_query(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + 
content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Execute query asynchronously + results = await session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_sync_batch_query(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXTransport(url=url, timeout=10) + + def test_code(): + with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + results = session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def 
test_httpx_async_batch_query_without_session(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXAsyncTransport(url=url, timeout=10) + + client = Client(transport=transport) + + query = [GraphQLRequest(document=gql(query1_str))] + + results = client.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_error_code(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportQueryError): + await session.execute_batch(query) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', + "[{}]", + "[qlsjfqsdlkj]", + 
'[{"not_data_or_errors": 35}]', + "[]", + "[1]", +] + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_httpx_async_batch_invalid_protocol(aiohttp_server, response): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportProtocolError): + await session.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_cannot_execute_if_not_connected(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportClosed): + await transport.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_sync_batch_cannot_execute_if_not_connected(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await 
aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXTransport(url=url, timeout=10) + + query = [GraphQLRequest(document=gql(query1_str))] + + with pytest.raises(TransportClosed): + transport.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_extra_args(aiohttp_server): + import httpx + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + # passing extra arguments to httpx.AsyncClient + inner_transport = httpx.AsyncHTTPTransport(retries=2) + transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=inner_transport) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(document=gql(query1_str))] + + # Passing extra arguments to the post method + results = await session.execute_batch( + query, extra_args={"follow_redirects": True} + ) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +query1_server_answer_with_extensions_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}]" +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_query_with_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + 
text=query1_server_answer_with_extensions_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + query = [GraphQLRequest(document=gql(query1_str))] + + async with Client(transport=transport) as session: + + execution_results = await session.execute_batch( + query, get_execution_result=True + ) + + assert execution_results[0].extensions["key1"] == "val1" + + +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_httpx_batch_online_async_manual(): + + from gql.transport.httpx import HTTPXAsyncTransport + + client = Client( + transport=HTTPXAsyncTransport(url=ONLINE_URL), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) { + continent(code: $continent_code) { + name + } + } + """ + ) + + async with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = await session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_httpx_batch_online_sync_manual(): + + from gql.transport.httpx import HTTPXTransport + + client = Client( + transport=HTTPXTransport(url=ONLINE_URL), + ) + + query = gql( + """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + ) + + with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 4b9e09b8..38850d56 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -545,14 +545,11 @@ def test_code(): await run_sync_test(server, test_code) -ONLINE_URL = "https://countries.trevorblades.com/" - -skip_reason = "backend does not support batching anymore..." +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto(): from threading import Thread @@ -619,7 +616,6 @@ def get_continent_name(session, continent_code): @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto_execute_future(): from gql.transport.requests import RequestsHTTPTransport @@ -657,7 +653,6 @@ def test_requests_sync_batch_auto_execute_future(): @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_manual(): from gql.transport.requests import RequestsHTTPTransport From ef492685b921c757947578f852668699ab4ced73 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 25 May 2025 20:57:21 +0000 Subject: [PATCH 216/239] change transports prototype using GraphQLRequest (#551) --- gql/client.py | 224 +++++++++----------- gql/graphql_request.py | 15 +- gql/transport/aiohttp.py | 50 ++--- gql/transport/appsync_websockets.py | 19 +- 
gql/transport/async_transport.py | 14 +- gql/transport/common/base.py | 23 +- gql/transport/httpx.py | 60 ++---- gql/transport/local_schema.py | 38 +++- gql/transport/phoenix_channel_websockets.py | 11 +- gql/transport/requests.py | 41 +--- gql/transport/transport.py | 11 +- gql/transport/websockets_protocol.py | 13 +- tests/starwars/test_subscription.py | 6 +- tests/test_aiohttp.py | 8 +- tests/test_client.py | 4 +- tests/test_httpx.py | 44 ++-- tests/test_httpx_async.py | 8 +- tests/test_requests.py | 56 ++--- 18 files changed, 289 insertions(+), 356 deletions(-) diff --git a/gql/client.py b/gql/client.py index a4e80dcb..4e269a2a 100644 --- a/gql/client.py +++ b/gql/client.py @@ -40,7 +40,6 @@ from .transport.transport import Transport from .utilities import build_client_schema, get_introspection_query_ast from .utilities import parse_result as parse_result_fn -from .utilities import serialize_variable_values from .utils import str_first_element log = logging.getLogger(__name__) @@ -68,6 +67,7 @@ class Client: def __init__( self, + *, schema: Optional[Union[str, GraphQLSchema]] = None, introspection: Optional[IntrospectionQuery] = None, transport: Optional[Union[Transport, AsyncTransport]] = None, @@ -206,11 +206,11 @@ def _get_event_loop() -> asyncio.AbstractEventLoop: def execute_sync( self, document: DocumentNode, + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... 
# pragma: no cover @@ -219,11 +219,11 @@ def execute_sync( def execute_sync( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -232,11 +232,11 @@ def execute_sync( def execute_sync( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -244,6 +244,7 @@ def execute_sync( def execute_sync( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -319,11 +320,11 @@ def execute_batch_sync( async def execute_async( self, document: DocumentNode, + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -332,11 +333,11 @@ async def execute_async( async def execute_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... 
# pragma: no cover @@ -345,11 +346,11 @@ async def execute_async( async def execute_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -357,6 +358,7 @@ async def execute_async( async def execute_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -432,11 +434,11 @@ async def execute_batch_async( def execute( self, document: DocumentNode, + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -445,11 +447,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -458,11 +460,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... 
# pragma: no cover @@ -470,6 +472,7 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -629,11 +632,11 @@ def execute_batch( def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @@ -642,11 +645,11 @@ def subscribe_async( def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @@ -655,11 +658,11 @@ def subscribe_async( def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[ @@ -669,6 +672,7 @@ def subscribe_async( async def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -697,11 +701,11 @@ async def subscribe_async( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Generator[Dict[str, Any], None, None]: ... 
# pragma: no cover @@ -710,11 +714,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> Generator[ExecutionResult, None, None]: ... # pragma: no cover @@ -723,11 +727,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[ @@ -737,11 +741,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - *, get_execution_result: bool = False, **kwargs: Any, ) -> Union[ @@ -925,19 +929,17 @@ def __init__(self, client: Client): def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, ) -> ExecutionResult: - """Execute the provided document AST synchronously using + """Execute the provided request synchronously using the sync transport, returning an ExecutionResult object. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. 
@@ -948,34 +950,22 @@ def _execute( # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request.document) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) if self.client.batching_enabled: - request = GraphQLRequest( - document, - variable_values=variable_values, - operation_name=operation_name, - ) future_result = self._execute_future(request) result = future_result.result() else: result = self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, **kwargs, ) @@ -984,9 +974,9 @@ def _execute( if parse_result or (parse_result is None and self.client.parse_results): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) return result @@ -995,11 +985,11 @@ def _execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -1008,11 +998,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... 
# pragma: no cover @@ -1021,11 +1011,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -1033,6 +1023,7 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -1059,11 +1050,16 @@ def execute( The extra arguments are passed to the transport execute method.""" - # Validate and execute on the transport - result = self._execute( - document, + # Make GraphQLRequest object + request = GraphQLRequest( + document=document, variable_values=variable_values, operation_name=operation_name, + ) + + # Validate and execute on the transport + result = self._execute( + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1337,7 +1333,9 @@ def fetch_schema(self) -> None: introspection_query = get_introspection_query_ast( **self.client.introspection_args ) - execution_result = self.transport.execute(introspection_query) + execution_result = self.transport.execute( + GraphQLRequest(document=introspection_query) + ) self.client._build_schema_from_introspection(execution_result) @@ -1360,23 +1358,21 @@ def __init__(self, client: Client): async def _subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: - """Coroutine to subscribe asynchronously to the provided document AST + """Coroutine to subscribe asynchronously to the provided request asynchronously using the 
async transport, returning an async generator producing ExecutionResult objects. * Validate the query with the schema if provided. * Serialize the variable_values if requested. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1387,26 +1383,19 @@ async def _subscribe( # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request.document) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) # Subscribe to the transport inner_generator: AsyncGenerator[ExecutionResult, None] = ( self.transport.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, **kwargs, ) ) @@ -1423,9 +1412,9 @@ async def _subscribe( ): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) yield result @@ -1437,11 +1426,11 @@ async def _subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> 
AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @@ -1450,11 +1439,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @@ -1463,11 +1452,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[ @@ -1477,6 +1466,7 @@ def subscribe( async def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -1505,10 +1495,15 @@ async def subscribe( The extra arguments are passed to the transport subscribe method.""" - inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( - document, + # Make GraphQLRequest object + request = GraphQLRequest( + document=document, variable_values=variable_values, operation_name=operation_name, + ) + + inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1536,22 +1531,20 @@ async def subscribe( async def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, ) -> ExecutionResult: - """Coroutine to execute the provided document AST asynchronously using + """Coroutine to execute the provided request asynchronously using the async transport, 
 returning an ExecutionResult object. * Validate the query with the schema if provided. * Serialize the variable_values if requested. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1562,26 +1555,19 @@ async def _execute( # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request.document) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) # Execute the query with the transport with a timeout with fail_after(self.client.execute_timeout): result = await self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, **kwargs, ) @@ -1590,9 +1576,9 @@ async def _execute( if parse_result or (parse_result is None and self.client.parse_results): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) return result @@ -1601,11 +1587,11 @@ async def execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: 
Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -1614,11 +1600,11 @@ async def execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -1627,11 +1613,11 @@ async def execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -1639,6 +1625,7 @@ async def execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -1665,11 +1652,16 @@ async def execute( The extra arguments are passed to the transport execute method.""" - # Validate and execute on the transport - result = await self._execute( - document, + # Make GraphQLRequest object + request = GraphQLRequest( + document=document, variable_values=variable_values, operation_name=operation_name, + ) + + # Validate and execute on the transport + result = await self._execute( + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1844,7 +1836,9 @@ async def fetch_schema(self) -> None: introspection_query = get_introspection_query_ast( **self.client.introspection_args ) - execution_result = await self.transport.execute(introspection_query) + execution_result = await self.transport.execute( + GraphQLRequest(introspection_query) + ) self.client._build_schema_from_introspection(execution_result) @@ -1869,6 +1863,7 @@ class 
ReconnectingAsyncClientSession(AsyncClientSession): def __init__( self, client: Client, + *, retry_connect: Union[bool, _Decorator] = True, retry_execute: Union[bool, _Decorator] = True, ): @@ -1961,9 +1956,8 @@ async def stop_connecting_task(self): async def _execute_once( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, @@ -1974,9 +1968,7 @@ async def _execute_once( try: answer = await super()._execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1989,9 +1981,8 @@ async def _execute_once( async def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, @@ -2002,9 +1993,7 @@ async def _execute( """ return await self._execute_with_retries( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -2012,9 +2001,8 @@ async def _execute( async def _subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, @@ -2024,9 +2012,7 @@ async def _subscribe( """ inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, diff --git 
a/gql/graphql_request.py b/gql/graphql_request.py index b0c68f5c..7289a8f9 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional -from graphql import DocumentNode, GraphQLSchema +from graphql import DocumentNode, GraphQLSchema, print_ast from .utilities import serialize_variable_values @@ -35,3 +35,16 @@ def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": ), operation_name=self.operation_name, ) + + @property + def payload(self) -> Dict[str, Any]: + query_str = print_ast(self.document) + payload: Dict[str, Any] = {"query": query_str} + + if self.operation_name: + payload["operationName"] = self.operation_name + + if self.variable_values: + payload["variables"] = self.variable_values + + return payload diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 9535eef4..0a677af3 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -20,7 +20,7 @@ from aiohttp.client_reqrep import Fingerprint from aiohttp.helpers import BasicAuth from aiohttp.typedefs import LooseCookies, LooseHeaders -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from multidict import CIMultiDictProxy from ..graphql_request import GraphQLRequest @@ -164,25 +164,13 @@ async def close(self) -> None: self.session = None - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload - def _prepare_batch_request( self, reqs: List[GraphQLRequest], extra_args: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - payload = [self._build_payload(req) for req in reqs] + payload = [req.payload for req in reqs] post_args = {"json": payload} @@ 
-198,15 +186,15 @@ def _prepare_batch_request( def _prepare_request( self, - req: GraphQLRequest, + request: GraphQLRequest, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> Dict[str, Any]: - payload = self._build_payload(req) + payload = request.payload if upload_files: - post_args = self._prepare_file_uploads(req, payload) + post_args = self._prepare_file_uploads(request, payload) else: post_args = {"json": payload} @@ -228,11 +216,11 @@ def _prepare_request( return post_args def _prepare_file_uploads( - self, req: GraphQLRequest, payload: Dict[str, Any] + self, request: GraphQLRequest, payload: Dict[str, Any] ) -> Dict[str, Any]: # If the upload_files flag is set, then we need variable_values - variable_values = req.variable_values + variable_values = request.variable_values assert variable_values is not None # If we upload files, we will extract the files present in the @@ -359,13 +347,12 @@ def _raise_invalid_result(self, result_text: str, reason: str) -> None: async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server + """Execute the provided request against the configured remote server using the current session. This uses the aiohttp library to perform a HTTP POST request asynchronously to the remote server. @@ -373,22 +360,15 @@ async def execute( Don't call this coroutine directly on the transport, instead use :code:`execute` on a client or a session. - :param document: the parsed GraphQL request - :param variable_values: An optional Dict of variable values - :param operation_name: An optional Operation name for the request + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. 
:param extra_args: additional arguments to send to the aiohttp post method :param upload_files: Set to True if you want to put files in the variable values :returns: an ExecutionResult object. """ - req = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - post_args = self._prepare_request( - req, + request, extra_args, upload_files, ) @@ -434,9 +414,7 @@ async def execute_batch( def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Subscribe is not supported on HTTP. diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index a6a7d180..e2ab4f96 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -4,8 +4,9 @@ from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult +from ..graphql_request import GraphQLRequest from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication from .common.adapters.websockets import WebSocketsAdapter from .common.base import SubscriptionTransportBase @@ -150,22 +151,14 @@ def _parse_answer( async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: query_id = self.next_query_id self.next_query_id += 1 - data: Dict = {"query": print_ast(document)} - - if variable_values: - data["variables"] = variable_values - - if operation_name: - data["operationName"] = operation_name + data: Dict[str, Any] = request.payload serialized_data = json.dumps(data, separators=(",", ":")) @@ -203,9 +196,7 @@ async def _send_query( async def execute( self, - document: DocumentNode, - variable_values: 
Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: """This method is not available. diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 243746e6..526c97ba 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,7 +1,7 @@ import abc -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, List -from graphql import DocumentNode, ExecutionResult +from graphql import ExecutionResult from ..graphql_request import GraphQLRequest @@ -24,11 +24,9 @@ async def close(self): @abc.abstractmethod async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: - """Execute the provided document AST for either a remote or local GraphQL + """Execute the provided request for either a remote or local GraphQL Schema.""" raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" @@ -54,9 +52,7 @@ async def execute_batch( @abc.abstractmethod def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using an async generator diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index a285ad2c..f2070fe1 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -5,8 +5,9 @@ from contextlib import suppress from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union -from graphql import DocumentNode, ExecutionResult +from graphql import ExecutionResult +from ...graphql_request import GraphQLRequest from ..async_transport import AsyncTransport from ..exceptions import ( TransportAlreadyConnected, @@ -158,9 +159,7 @@ 
async def _receive(self) -> str: @abstractmethod async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: raise NotImplementedError # pragma: no cover @@ -267,9 +266,8 @@ async def _handle_answer( async def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, send_stop: Optional[bool] = True, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using a python async generator. @@ -281,7 +279,7 @@ async def subscribe( # Send the query and receive the id query_id: int = await self._send_query( - document, variable_values, operation_name + request, ) # Create a queue to receive the answers for this query_id @@ -325,11 +323,9 @@ async def subscribe( async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server + """Execute the provided request against the configured remote server using the current session. Send a query but close the async generator as soon as we have the first answer. @@ -339,7 +335,8 @@ async def execute( first_result = None generator = self.subscribe( - document, variable_values, operation_name, send_stop=False + request, + send_stop=False, ) async for result in generator: diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 406c0523..f3416c24 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -15,7 +15,7 @@ ) import httpx -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from ..graphql_request import GraphQLRequest from . 
import AsyncTransport, Transport @@ -57,18 +57,6 @@ def __init__( self.json_deserialize = json_deserialize self.kwargs = kwargs - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload - def _prepare_request( self, req: GraphQLRequest, @@ -76,7 +64,7 @@ def _prepare_request( upload_files: bool = False, ) -> Dict[str, Any]: - payload = self._build_payload(req) + payload = req.payload if upload_files: post_args = self._prepare_file_uploads(req, payload) @@ -99,7 +87,7 @@ def _prepare_batch_request( extra_args: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - payload = [self._build_payload(req) for req in reqs] + payload = [req.payload for req in reqs] post_args = {"json": payload} @@ -243,21 +231,18 @@ def connect(self): def execute( # type: ignore self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the httpx library to perform a HTTP POST request to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. 
:param extra_args: additional arguments to send to the httpx post method :param upload_files: Set to True if you want to put files in the variable values :return: The result of execution. @@ -267,12 +252,6 @@ def execute( # type: ignore if not self.client: raise TransportClosed("Transport is not connected") - request = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - post_args = self._prepare_request( request, extra_args, @@ -343,22 +322,19 @@ async def connect(self): async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the httpx library to perform a HTTP POST request asynchronously to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param extra_args: additional arguments to send to the httpx post method :param upload_files: Set to True if you want to put files in the variable values :return: The result of execution. 
@@ -368,12 +344,6 @@ async def execute( if not self.client: raise TransportClosed("Transport is not connected") - request = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - post_args = self._prepare_request( request, extra_args, @@ -420,9 +390,7 @@ async def execute_batch( def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Subscribe is not supported on HTTP. diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 19760ad6..f87854e2 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -2,10 +2,12 @@ from inspect import isawaitable from typing import Any, AsyncGenerator, Awaitable, cast -from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe +from graphql import ExecutionResult, GraphQLSchema, execute, subscribe from gql.transport import AsyncTransport +from ..graphql_request import GraphQLRequest + class LocalSchemaTransport(AsyncTransport): """A transport for executing GraphQL queries against a local schema.""" @@ -30,13 +32,24 @@ async def close(self): async def execute( self, - document: DocumentNode, + request: GraphQLRequest, *args: Any, **kwargs: Any, ) -> ExecutionResult: - """Execute the provided document AST for on a local GraphQL Schema.""" - - result_or_awaitable = execute(self.schema, document, *args, **kwargs) + """Execute the provided request on a local GraphQL Schema.""" + + inner_kwargs = { "variable_values": request.variable_values, "operation_name": request.operation_name, **kwargs, } + + result_or_awaitable = execute( self.schema, request.document, *args, **inner_kwargs, ) execution_result: ExecutionResult @@ -57,7 +70,7 @@ async def _await_if_necessary(obj): async def subscribe( self, - document: DocumentNode, + request: 
GraphQLRequest, *args: Any, **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: @@ -66,8 +79,19 @@ async def subscribe( The results are sent as an ExecutionResult object """ + inner_kwargs = { + "variable_values": request.variable_values, + "operation_name": request.operation_name, + **kwargs, + } + subscribe_result = await self._await_if_necessary( - subscribe(self.schema, document, *args, **kwargs) + subscribe( + self.schema, + request.document, + *args, + **inner_kwargs, + ) ) if isinstance(subscribe_result, ExecutionResult): diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 8a975b73..8e7455e2 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -3,8 +3,9 @@ import logging from typing import Any, Dict, Optional, Tuple, Union -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult, print_ast +from ..graphql_request import GraphQLRequest from .common.adapters.websockets import WebSocketsAdapter from .common.base import SubscriptionTransportBase from .exceptions import ( @@ -182,9 +183,7 @@ async def _connection_terminate(self): async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: """Send a query to the provided websocket connection. 
@@ -201,8 +200,8 @@ async def _send_query( "topic": self.channel_name, "event": "doc", "payload": { - "query": print_ast(document), - "variables": variable_values or {}, + "query": print_ast(request.document), + "variables": request.variable_values or {}, }, "ref": query_id, } diff --git a/gql/transport/requests.py b/gql/transport/requests.py index d84ba9d3..2087bbd0 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -15,7 +15,7 @@ ) import requests -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar @@ -139,22 +139,18 @@ def connect(self): def execute( # type: ignore self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, timeout: Optional[int] = None, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the requests library to perform a HTTP POST request to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param timeout: Specifies a default timeout for requests (Default: None). 
:param extra_args: additional arguments to send to the requests post method :param upload_files: Set to True if you want to put files in the variable values @@ -166,11 +162,7 @@ def execute( # type: ignore if not self.session: raise TransportClosed("Transport is not connected") - query_str = print_ast(document) - payload: Dict[str, Any] = {"query": query_str} - - if operation_name: - payload["operationName"] = operation_name + payload = request.payload post_args: Dict[str, Any] = { "headers": self.headers, @@ -182,12 +174,12 @@ def execute( # type: ignore if upload_files: # If the upload_files flag is set, then we need variable_values - assert variable_values is not None + assert request.variable_values is not None # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( - variables=variable_values, + variables=request.variable_values, file_classes=self.file_classes, ) @@ -241,9 +233,6 @@ def execute( # type: ignore post_args["headers"]["Content-Type"] = data.content_type else: - if variable_values: - payload["variables"] = variable_values - data_key = "json" if self.use_json else "data" post_args[data_key] = payload @@ -383,7 +372,7 @@ def _build_batch_post_args( } data_key = "json" if self.use_json else "data" - post_args[data_key] = [self._build_payload(req) for req in reqs] + post_args[data_key] = [req.payload for req in reqs] # Log the payload if log.isEnabledFor(logging.INFO): @@ -398,18 +387,6 @@ def _build_batch_post_args( return post_args - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload - def close(self): """Closing the transport by closing the inner session""" if self.session: diff 
--git a/gql/transport/transport.py b/gql/transport/transport.py index 49d0aa34..7a72f9a6 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,7 +1,7 @@ import abc from typing import Any, List -from graphql import DocumentNode, ExecutionResult +from graphql import ExecutionResult from ..graphql_request import GraphQLRequest @@ -9,13 +9,16 @@ class Transport(abc.ABC): @abc.abstractmethod def execute( - self, document: DocumentNode, *args: Any, **kwargs: Any + self, + request: GraphQLRequest, + *args: Any, + **kwargs: Any, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST for either a remote or local GraphQL Schema. + Execute the provided request for either a remote or local GraphQL Schema. - :param document: GraphQL query as AST Node or Document object. + :param request: GraphQL request as a GraphQLRequest object. :return: ExecutionResult """ raise NotImplementedError( diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index 61a4bb85..3b66a0cb 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -4,8 +4,9 @@ from contextlib import suppress from typing import Any, Dict, List, Optional, Tuple, Union -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult +from ..graphql_request import GraphQLRequest from .common.adapters.connection import AdapterConnection from .common.base import SubscriptionTransportBase from .exceptions import ( @@ -224,9 +225,7 @@ async def _send_connection_terminate_message(self) -> None: async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: """Send a query to the provided websocket connection. 
@@ -238,11 +237,7 @@ async def _send_query( query_id = self.next_query_id self.next_query_id += 1 - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name + payload: Dict[str, Any] = request.payload query_type = "start" diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 0f412acc..bbaafd5c 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -3,7 +3,7 @@ import pytest from graphql import ExecutionResult, GraphQLError, subscribe -from gql import Client, gql +from gql import Client, GraphQLRequest, gql from .fixtures import reviews from .schema import StarWarsSchema @@ -93,7 +93,9 @@ async def test_subscription_support_using_client_invalid_field(): results = [ result async for result in await await_if_coroutine( - session.transport.subscribe(subs, variable_values=params) + session.transport.subscribe( + GraphQLRequest(subs, variable_values=params) + ) ) ] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 0642e536..24f82c9d 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -6,7 +6,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -421,7 +421,7 @@ async def handler(request): query = gql(query1_str) with pytest.raises(TransportClosed): - await transport.execute(query) + await transport.execute(GraphQLRequest(query)) @pytest.mark.asyncio @@ -533,7 +533,9 @@ async def handler(request): query = gql(query2_str) # Execute query asynchronously - result = await session.execute(query, params, operation_name="getEurope") + result = await session.execute( + query, variable_values=params, operation_name="getEurope" + ) continent = result["continent"] diff --git 
a/tests/test_client.py b/tests/test_client.py index 55993a9e..3412059e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ from unittest import mock import pytest -from graphql import DocumentNode, ExecutionResult, build_ast_schema, parse +from graphql import ExecutionResult, build_ast_schema, parse from gql import Client, GraphQLRequest, gql from gql.transport import Transport @@ -40,7 +40,7 @@ class RandomTransport(Transport): class RandomTransport2(Transport): def execute( self, - document: DocumentNode, + request: GraphQLRequest, *args: Any, **kwargs: Any, ) -> ExecutionResult: diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 0991355a..b944391f 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -3,7 +3,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -470,7 +470,7 @@ def test_code(): query = gql(query1_str) with pytest.raises(TransportClosed): - transport.execute(query) + transport.execute(GraphQLRequest(query)) await run_sync_test(server, test_code) @@ -578,32 +578,32 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: params = {"file": FileVar(f), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an filename string inside a FileVar object params = { "file": FileVar(file_path), "other_var": 42, } - execution_result = session._execute( + execution_result = 
session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -650,22 +650,22 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using FileVar params = { "file": FileVar(file_path, content_type="application/pdf"), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -710,11 +710,11 @@ def test_code(): "file": FileVar(file_path), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -751,11 +751,11 @@ def test_code(): file_path = test_file.filename params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -797,11 +797,11 @@ def test_code(): params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -867,11 +867,11 @@ def test_code(): "file2": FileVar(file_path_2), } - execution_result = 
session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -940,11 +940,11 @@ def test_code(): ], } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 87f1675a..56c65873 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -429,7 +429,7 @@ async def handler(request): query = gql(query1_str) with pytest.raises(TransportClosed): - await transport.execute(query) + await transport.execute(GraphQLRequest(query)) @pytest.mark.aiohttp @@ -541,7 +541,9 @@ async def handler(request): query = gql(query2_str) # Execute query asynchronously - result = await session.execute(query, params, operation_name="getEurope") + result = await session.execute( + query, variable_values=params, operation_name="getEurope" + ) continent = result["continent"] diff --git a/tests/test_requests.py b/tests/test_requests.py index c184e230..ff6a5651 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -471,7 +471,7 @@ def test_code(): query = gql(query1_str) with pytest.raises(TransportClosed): - transport.execute(query) + transport.execute(GraphQLRequest(query)) await run_sync_test(server, test_code) @@ -580,11 
+580,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: @@ -592,19 +592,19 @@ def test_code(): params = {"file": FileVar(f), "other_var": 42} with warnings.catch_warnings(): warnings.simplefilter("error") # Turn warnings into errors - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an filename string inside a FileVar object params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -651,11 +651,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: @@ -664,11 +664,11 @@ def test_code(): "file": FileVar(f, content_type="application/pdf"), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -713,11 +713,11 @@ def test_code(): "file": FileVar(file_path), "other_var": 42, } - execution_result = 
session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -760,11 +760,11 @@ def test_code(): "file": FileVar(f, filename="filename1.txt"), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -807,11 +807,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -859,11 +859,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -937,11 +937,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params_1, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() @@ -958,11 +958,11 @@ def test_code(): "file2": FileVar(f2), } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params_2, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() @@ -1037,11 +1037,11 @@ def test_code(): 
DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() @@ -1055,11 +1055,11 @@ def test_code(): params_2 = {"files": [FileVar(f1), FileVar(f2)]} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params_2, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() From 2d2110046e65e46ddbeb915c5b2c7c79b338fdeb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 25 May 2025 22:01:48 +0000 Subject: [PATCH 217/239] Set logging level to DEBUG for all transports (#552) --- docs/advanced/logging.rst | 14 ++------------ gql/transport/aiohttp.py | 12 ++++++------ gql/transport/common/base.py | 4 ++-- gql/transport/httpx.py | 2 +- gql/transport/requests.py | 16 ++++++++-------- 5 files changed, 19 insertions(+), 29 deletions(-) diff --git a/docs/advanced/logging.rst b/docs/advanced/logging.rst index 02fdf3fd..f75c5f32 100644 --- a/docs/advanced/logging.rst +++ b/docs/advanced/logging.rst @@ -4,14 +4,7 @@ Logging GQL uses the python `logging`_ module. In order to debug a problem, you can enable logging to see the messages exchanged between the client and the server. -To do that, set the loglevel at **INFO** at the beginning of your code: - -.. code-block:: python - - import logging - logging.basicConfig(level=logging.INFO) - -For even more logs, you can set the loglevel at **DEBUG**: +To do that, set the loglevel at **DEBUG** at the beginning of your code: .. code-block:: python @@ -21,10 +14,7 @@ For even more logs, you can set the loglevel at **DEBUG**: Disabling logs -------------- -By default, the logs for the transports are quite verbose. 
- -On the **INFO** level, all the messages between the frontend and the backend are logged which can -be difficult to read especially when it fetches the schema from the transport. +On the **DEBUG** log level, the logs for the transports are quite verbose. It is possible to disable the logs only for a specific gql transport by setting a higher log level for this transport (**WARNING** for example) so that the other logs of your program are not affected. diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 0a677af3..2c0d8fa7 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -175,8 +175,8 @@ def _prepare_batch_request( post_args = {"json": payload} # Log the payload - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(post_args["json"])) + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(post_args["json"])) # Pass post_args to aiohttp post method if extra_args: @@ -199,8 +199,8 @@ def _prepare_request( post_args = {"json": payload} # Log the payload - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(payload)) + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(payload)) # Pass post_args to aiohttp post method if extra_args: @@ -299,9 +299,9 @@ async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: try: result = await response.json(loads=self.json_deserialize, content_type=None) - if log.isEnabledFor(logging.INFO): + if log.isEnabledFor(logging.DEBUG): result_text = await response.text() - log.info("<<< %s", result_text) + log.debug("<<< %s", result_text) except Exception: await self.raise_response_error(response, "Not a JSON answer") diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index f2070fe1..734c393b 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -136,7 +136,7 @@ async def _send(self, message: str) -> None: try: # Can raise 
TransportConnectionFailed await self.adapter.send(message) - log.info(">>> %s", message) + log.debug(">>> %s", message) except TransportConnectionFailed as e: await self._fail(e, clean_close=False) raise e @@ -152,7 +152,7 @@ async def _receive(self) -> str: # Can raise TransportConnectionFailed or TransportProtocolError answer: str = await self.adapter.receive() - log.info("<<< %s", answer) + log.debug("<<< %s", answer) return answer diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index f3416c24..76324cd7 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -92,7 +92,7 @@ def _prepare_batch_request( post_args = {"json": payload} # Log the payload - if log.isEnabledFor(logging.INFO): + if log.isEnabledFor(logging.DEBUG): log.debug(">>> %s", self.json_serialize(payload)) # Pass post_args to aiohttp post method diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 2087bbd0..7be288d2 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -237,8 +237,8 @@ def execute( # type: ignore post_args[data_key] = payload # Log the payload - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(payload)) + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(payload)) # Pass kwargs to requests post method post_args.update(self.kwargs) @@ -282,8 +282,8 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: else: result = self.json_deserialize(response.text) - if log.isEnabledFor(logging.INFO): - log.info("<<< %s", response.text) + if log.isEnabledFor(logging.DEBUG): + log.debug("<<< %s", response.text) except Exception: raise_response_error(response, "Not a JSON answer") @@ -344,8 +344,8 @@ def _extract_response(self, response: requests.Response) -> Any: response.raise_for_status() result = response.json() - if log.isEnabledFor(logging.INFO): - log.info("<<< %s", response.text) + if log.isEnabledFor(logging.DEBUG): + log.debug("<<< %s", 
response.text) except requests.HTTPError as e: raise TransportServerError( @@ -375,8 +375,8 @@ def _build_batch_post_args( post_args[data_key] = [req.payload for req in reqs] # Log the payload - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(post_args[data_key])) + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(post_args[data_key])) # Pass kwargs to requests post method post_args.update(self.kwargs) From 77a3a40b7dc17b46106f7889455d07b01e1e3cbd Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 26 May 2025 09:23:09 +0000 Subject: [PATCH 218/239] introspection now requests deprecated input fields by default (#553) --- docs/gql-cli/intro.rst | 12 +++++++----- docs/usage/validation.rst | 2 +- gql/cli.py | 4 ++-- gql/utilities/get_introspection_query_ast.py | 2 +- tests/starwars/test_introspection.py | 7 ++++--- tests/test_transport.py | 3 +++ tests/test_transport_batch.py | 3 +++ 7 files changed, 21 insertions(+), 12 deletions(-) diff --git a/docs/gql-cli/intro.rst b/docs/gql-cli/intro.rst index f88b60a1..c3237093 100644 --- a/docs/gql-cli/intro.rst +++ b/docs/gql-cli/intro.rst @@ -79,12 +79,14 @@ Print the GraphQL schema in a file $ gql-cli https://countries.trevorblades.com/graphql --print-schema > schema.graphql -.. note:: - - By default, deprecated input fields are not requested from the backend. - You can add :code:`--schema-download input_value_deprecation:true` to request them. - .. note:: You can add :code:`--schema-download descriptions:false` to request a compact schema without comments. + +.. warning:: + + By default, from gql version 4.0, deprecated input fields are requested from the backend. + It is possible that some old backends do not support this feature. In that case + you can add :code:`--schema-download input_value_deprecation:false` to go back + to the previous behavior. 
diff --git a/docs/usage/validation.rst b/docs/usage/validation.rst index f9711f31..18b1cda1 100644 --- a/docs/usage/validation.rst +++ b/docs/usage/validation.rst @@ -24,7 +24,7 @@ The schema can be provided as a String (which is usually stored in a .graphql fi .. note:: You can download a schema from a server by using :ref:`gql-cli ` - :code:`$ gql-cli https://SERVER_URL/graphql --print-schema --schema-download input_value_deprecation:true > schema.graphql` + :code:`$ gql-cli https://SERVER_URL/graphql --print-schema > schema.graphql` OR can be created using python classes: diff --git a/gql/cli.py b/gql/cli.py index 9ae92e83..37be3656 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -132,12 +132,12 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: By default, it will: - request field descriptions - - not request deprecated input fields + - request deprecated input fields Possible options: - descriptions:false for a compact schema without comments - - input_value_deprecation:true to download deprecated input fields + - input_value_deprecation:false to omit deprecated input fields - specified_by_url:true - schema_description:true - directive_is_repeatable:true""" diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index 975ccc83..4d6a243f 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -10,7 +10,7 @@ def get_introspection_query_ast( specified_by_url: bool = False, directive_is_repeatable: bool = False, schema_description: bool = False, - input_value_deprecation: bool = False, + input_value_deprecation: bool = True, type_recursion_level: int = 7, ) -> DocumentNode: """Get a query for introspection as a document using the DSL module. 
diff --git a/tests/starwars/test_introspection.py b/tests/starwars/test_introspection.py index 0d8369c0..9e5ff4aa 100644 --- a/tests/starwars/test_introspection.py +++ b/tests/starwars/test_introspection.py @@ -19,6 +19,9 @@ async def test_starwars_introspection_args(aiohttp_server): async with Client( transport=transport, fetch_schema_from_transport=True, + introspection_args={ + "input_value_deprecation": False, + }, ) as session: schema_str = print_schema(session.client.schema) @@ -35,6 +38,7 @@ async def test_starwars_introspection_args(aiohttp_server): fetch_schema_from_transport=True, introspection_args={ "descriptions": False, + "input_value_deprecation": False, }, ) as session: @@ -50,9 +54,6 @@ async def test_starwars_introspection_args(aiohttp_server): async with Client( transport=transport, fetch_schema_from_transport=True, - introspection_args={ - "input_value_deprecation": True, - }, ) as session: schema_str = print_schema(session.client.schema) diff --git a/tests/test_transport.py b/tests/test_transport.py index e554955a..87b31eb1 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -43,6 +43,9 @@ def client(): url=URL, cookies={"csrftoken": csrf}, headers={"x-csrftoken": csrf} ), fetch_schema_from_transport=True, + introspection_args={ + "input_value_deprecation": False, + }, ) diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index 7c108ec3..0b2a3158 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -43,6 +43,9 @@ def client(): url=URL, cookies={"csrftoken": csrf}, headers={"x-csrftoken": csrf} ), fetch_schema_from_transport=True, + introspection_args={ + "input_value_deprecation": False, + }, ) From 7fcb5b6f1a4f517b1f97c50f66599ea5943366e7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 27 May 2025 16:57:45 +0000 Subject: [PATCH 219/239] Implementation of automatic batching for async (#554) --- README.md | 1 + docs/advanced/batching_requests.rst | 96 ++++++++++++++ 
docs/advanced/index.rst | 1 + gql/client.py | 177 +++++++++++++++++++++++--- gql/graphql_request.py | 39 ++++-- gql/transport/aiohttp.py | 48 ++++--- gql/transport/httpx.py | 29 ++++- gql/transport/requests.py | 70 +++++----- tests/test_aiohttp_batch.py | 190 ++++++++++++++++++++++++++++ tests/test_graphql_request.py | 12 +- 10 files changed, 573 insertions(+), 90 deletions(-) create mode 100644 docs/advanced/batching_requests.rst diff --git a/README.md b/README.md index cbc53af6..e79a63d2 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ The complete documentation for GQL can be found at * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) * Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html) +* Supports [Batching requests](https://gql.readthedocs.io/en/latest/advanced/batching_requests.html) * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries or download schemas from the command line * [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically diff --git a/docs/advanced/batching_requests.rst b/docs/advanced/batching_requests.rst new file mode 100644 index 00000000..a71d4ffc --- /dev/null +++ b/docs/advanced/batching_requests.rst @@ -0,0 +1,96 @@ +.. 
_batching_requests: + +Batching requests +================= + +If you need to send multiple GraphQL queries to a backend, +and if the backend supports batch requests, +then you might want to send those requests in a batch instead of +making multiple execution requests. + +.. warning:: + - Some backends do not support batch requests + - File uploads and subscriptions are not supported with batch requests + +Batching requests manually +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To execute a batch of requests manually: + +- First Make a list of :class:`GraphQLRequest ` objects, containing: + * your GraphQL query + * Optional variable_values + * Optional operation_name + +.. code-block:: python + + request1 = GraphQLRequest(""" + query getContinents { + continents { + code + name + } + } + """ + ) + + request2 = GraphQLRequest(""" + query getContinentName ($code: ID!) { + continent (code: $code) { + name + } + } + """, + variable_values={ + "code": "AF", + }, + ) + + requests = [request1, request2] + +- Then use one of the `execute_batch` methods, either on Client, + or in a sync or async session + +**Sync**: + +.. code-block:: python + + transport = RequestsHTTPTransport(url=url) + # Or transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + results = session.execute_batch(requests) + + result1 = results[0] + result2 = results[1] + +**Async**: + +.. code-block:: python + + transport = AIOHTTPTransport(url=url) + # Or transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + results = await session.execute_batch(requests) + + result1 = results[0] + result2 = results[1] + +.. note:: + If any request in the batch returns an error, then a TransportQueryError will be raised + with the first error found. 
+ +Automatic Batching of requests +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If your code execute multiple requests independently in a short time +(either from different threads in sync code, or from different asyncio tasks in async code), +then you can use gql automatic batching of request functionality. + +You define a :code:`batching_interval` in your :class:`Client ` +and each time a new execution request is received through an `execute` method, +we will wait that interval (in seconds) for other requests to arrive +before sending all the requests received in that interval in a single batch. diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst index baae9276..ef14defd 100644 --- a/docs/advanced/index.rst +++ b/docs/advanced/index.rst @@ -6,6 +6,7 @@ Advanced async_advanced_usage async_permanent_session + batching_requests logging error_handling local_schema diff --git a/gql/client.py b/gql/client.py index 4e269a2a..a0e07056 100644 --- a/gql/client.py +++ b/gql/client.py @@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs): if reconnecting: self.session = ReconnectingAsyncClientSession(client=self, **kwargs) - await self.session.start_connecting_task() else: - try: - await self.transport.connect() - except Exception as e: - await self.transport.close() - raise e self.session = AsyncClientSession(client=self) + await self.session.connect() + # Get schema from transport if needed try: if self.fetch_schema_from_transport and not self.schema: @@ -846,7 +842,7 @@ async def connect_async(self, reconnecting=False, **kwargs): # we don't know what type of exception is thrown here because it # depends on the underlying transport; we just make sure that the # transport is closed and re-raise the exception - await self.transport.close() + await self.session.close() raise return self.session @@ -854,10 +850,7 @@ async def connect_async(self, reconnecting=False, **kwargs): async def close_async(self): """Close the async transport and stop the 
optional reconnecting task.""" - if isinstance(self.session, ReconnectingAsyncClientSession): - await self.session.stop_connecting_task() - - await self.transport.close() + await self.session.close() async def __aenter__(self): return await self.connect_async() @@ -1564,12 +1557,17 @@ async def _execute( ): request = request.serialize_variable_values(self.client.schema) - # Execute the query with the transport with a timeout - with fail_after(self.client.execute_timeout): - result = await self.transport.execute( - request, - **kwargs, - ) + # Check if batching is enabled + if self.client.batching_enabled: + future_result = await self._execute_future(request) + result = await future_result + else: + # Execute the query with the transport with a timeout + with fail_after(self.client.execute_timeout): + result = await self.transport.execute( + request, + **kwargs, + ) # Unserialize the result if requested if self.client.schema: @@ -1828,6 +1826,134 @@ async def execute_batch( return cast(List[Dict[str, Any]], [result.data for result in results]) + async def _batch_loop(self) -> None: + """Main loop of the task used to wait for requests + to execute them in a batch""" + + stop_loop = False + + while not stop_loop: + # First wait for a first request in from the batch queue + requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = [] + + # Wait for the first request + request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = ( + await self.batch_queue.get() + ) + + if request_and_future is None: + # None is our sentinel value to stop the loop + break + + requests_and_futures.append(request_and_future) + + # Then wait the requested batch interval except if we already + # have the maximum number of requests in the queue + if self.batch_queue.qsize() < self.client.batch_max - 1: + # Wait for the batch interval + await asyncio.sleep(self.client.batch_interval) + + # Then get the requests which had been made during that wait interval + for _ in 
range(self.client.batch_max - 1): + try: + # Use get_nowait since we don't want to wait here + request_and_future = self.batch_queue.get_nowait() + + if request_and_future is None: + # Sentinel value - stop after processing current batch + stop_loop = True + break + + requests_and_futures.append(request_and_future) + + except asyncio.QueueEmpty: + # No more requests in queue, that's fine + break + + # Extract requests and futures + requests = [request for request, _ in requests_and_futures] + futures = [future for _, future in requests_and_futures] + + # Execute the batch + try: + results: List[ExecutionResult] = await self._execute_batch( + requests, + serialize_variables=False, # already done + parse_result=False, # will be done later + validate_document=False, # already validated + ) + + # Set the result for each future + for result, future in zip(results, futures): + if not future.cancelled(): + future.set_result(result) + + except Exception as exc: + # If batch execution fails, propagate the error to all futures + for future in futures: + if not future.cancelled(): + future.set_exception(exc) + + # Signal that the task has stopped + self._batch_task_stopped_event.set() + + async def _execute_future( + self, + request: GraphQLRequest, + ) -> asyncio.Future: + """If batching is enabled, this method will put a request in the batching queue + instead of executing it directly so that the requests could be put in a batch. 
+ """ + + assert hasattr(self, "batch_queue"), "Batching is not enabled" + assert not self._batch_task_stop_requested, "Batching task has been stopped" + + future: asyncio.Future = asyncio.Future() + await self.batch_queue.put((request, future)) + + return future + + async def _batch_init(self): + """Initialize the batch task loop if batching is enabled.""" + if self.client.batching_enabled: + self.batch_queue: asyncio.Queue = asyncio.Queue() + self._batch_task_stop_requested = False + self._batch_task_stopped_event = asyncio.Event() + self._batch_task = asyncio.create_task(self._batch_loop()) + + async def _batch_cleanup(self): + """Cleanup the batching task if batching is enabled.""" + if hasattr(self, "_batch_task_stopped_event"): + # Send a None in the queue to indicate that the batching task must stop + # after having processed the remaining requests in the queue + self._batch_task_stop_requested = True + await self.batch_queue.put(None) + + # Wait for the task to process remaining requests and stop + await self._batch_task_stopped_event.wait() + + async def connect(self): + """Connect the transport and initialize the batch task loop if batching + is enabled.""" + + await self._batch_init() + + try: + await self.transport.connect() + except Exception as e: + await self.transport.close() + raise e + + async def close(self): + """Close the transport and cleanup the batching task if batching is enabled. + + Will wait until all the remaining requests in the batch processing queue + have been executed. + """ + await self._batch_cleanup() + + await self.transport.close() + async def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. 
@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self): self._connect_task.cancel() self._connect_task = None + async def connect(self): + """Start the connect task and initialize the batch task loop if batching + is enabled.""" + + await self._batch_init() + + await self.start_connecting_task() + + async def close(self): + """Stop the connect task and cleanup the batching task + if batching is enabled.""" + await self._batch_cleanup() + + await self.stop_connecting_task() + + await self.transport.close() + async def _execute_once( self, request: GraphQLRequest, diff --git a/gql/graphql_request.py b/gql/graphql_request.py index 7289a8f9..29a34717 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,26 +1,38 @@ -from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from graphql import DocumentNode, GraphQLSchema, print_ast +from .gql import gql from .utilities import serialize_variable_values -@dataclass(frozen=True) class GraphQLRequest: """GraphQL Request to be executed.""" - document: DocumentNode - """GraphQL query as AST Node object.""" + def __init__( + self, + document: Union[DocumentNode, str], + *, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ): + """ + Initialize a GraphQL request. - variable_values: Optional[Dict[str, Any]] = None - """Dictionary of input parameters (Default: None).""" + Args: + document: GraphQL query as AST Node object or as a string. + If string, it will be converted to DocumentNode using gql(). + variable_values: Dictionary of input parameters (Default: None). + operation_name: Name of the operation that shall be executed. + Only required in multi-operation documents (Default: None). + """ + if isinstance(document, str): + self.document = gql(document) + else: + self.document = document - operation_name: Optional[str] = None - """ - Name of the operation that shall be executed. 
- Only required in multi-operation documents (Default: None). - """ + self.variable_values = variable_values + self.operation_name = operation_name def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": assert self.variable_values @@ -48,3 +60,6 @@ def payload(self) -> Dict[str, Any]: payload["variables"] = self.variable_values return payload + + def __str__(self): + return str(self.payload) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 2c0d8fa7..61d01fb4 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -274,22 +274,35 @@ def _prepare_file_uploads( return post_args - async def raise_response_error( - self, + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( resp: aiohttp.ClientResponse, - reason: str, ) -> None: - # We raise a TransportServerError if status code is 400 or higher - # We raise a TransportProtocolError in the other cases - + # If the status is >400, + # then we need to raise a TransportServerError try: # Raise ClientResponseError if response status is 400 or higher resp.raise_for_status() except ClientResponseError as e: raise TransportServerError(str(e), e.status) from e + @classmethod + async def _raise_response_error( + cls, + resp: aiohttp.ClientResponse, + reason: str, + ) -> None: + # We raise a TransportServerError if status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + cls._raise_transport_server_error_if_status_more_than_400(resp) + result_text = await resp.text() - self._raise_invalid_result(result_text, reason) + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " + f"{reason}: " + f"{result_text}" + ) async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: @@ -304,10 +317,10 @@ async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: log.debug("<<< %s", result_text) except Exception: - await self.raise_response_error(response, "Not a JSON 
answer") + await self._raise_response_error(response, "Not a JSON answer") if result is None: - await self.raise_response_error(response, "Not a JSON answer") + await self._raise_response_error(response, "Not a JSON answer") return result @@ -318,7 +331,7 @@ async def _prepare_result( result = await self._get_json_result(response) if "errors" not in result and "data" not in result: - await self.raise_response_error( + await self._raise_response_error( response, 'No "data" or "errors" keys in answer' ) @@ -336,14 +349,13 @@ async def _prepare_batch_result( answers = await self._get_json_result(response) - return get_batch_execution_result_list(reqs, answers) - - def _raise_invalid_result(self, result_text: str, reason: str) -> None: - raise TransportProtocolError( - f"Server did not return a valid GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + try: + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise async def execute( self, diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 76324cd7..afb1360c 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -195,18 +195,33 @@ def _prepare_batch_result( answers = self._get_json_result(response) - return get_batch_execution_result_list(reqs, answers) - - def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn: - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases - try: - # Raise a HTTPError if response status is 400 or higher + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a 
TransportProtocolError + raise + + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( + response: httpx.Response, + ) -> None: + # If the status is >400, + # then we need to raise a TransportServerError + try: + # Raise a HTTPStatusError if response status is 400 or higher response.raise_for_status() except httpx.HTTPStatusError as e: raise TransportServerError(str(e), e.response.status_code) from e + @classmethod + def _raise_response_error(cls, response: httpx.Response, reason: str) -> NoReturn: + # We raise a TransportServerError if the status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + cls._raise_transport_server_error_if_status_more_than_400(response) + raise TransportProtocolError( f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}" ) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 7be288d2..16d07025 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -258,24 +258,6 @@ def execute( # type: ignore self.response_headers = response.headers - def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases - - try: - # Raise a HTTPError if response status is 400 or higher - resp.raise_for_status() - except requests.HTTPError as e: - status_code = e.response.status_code if e.response is not None else None - raise TransportServerError(str(e), status_code) from e - - result_text = resp.text - raise TransportProtocolError( - f"Server did not return a GraphQL result: " - f"{reason}: " - f"{result_text}" - ) - try: if self.json_deserialize == json.loads: result = response.json() @@ -286,10 +268,10 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: log.debug("<<< %s", response.text) except Exception: - raise_response_error(response, "Not a JSON answer") + 
self._raise_response_error(response, "Not a JSON answer") if "errors" not in result and "data" not in result: - raise_response_error(response, 'No "data" or "errors" keys in answer') + self._raise_response_error(response, 'No "data" or "errors" keys in answer') return ExecutionResult( errors=result.get("errors"), @@ -297,6 +279,31 @@ def raise_response_error(resp: requests.Response, reason: str) -> NoReturn: extensions=result.get("extensions"), ) + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( + response: requests.Response, + ) -> None: + # If the status is >400, + # then we need to raise a TransportServerError + try: + # Raise a HTTPError if response status is 400 or higher + response.raise_for_status() + except requests.HTTPError as e: + status_code = e.response.status_code if e.response is not None else None + raise TransportServerError(str(e), status_code) from e + + @classmethod + def _raise_response_error(cls, resp: requests.Response, reason: str) -> NoReturn: + # We raise a TransportServerError if the status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + cls._raise_transport_server_error_if_status_more_than_400(resp) + + result_text = resp.text + raise TransportProtocolError( + f"Server did not return a GraphQL result: " f"{reason}: " f"{result_text}" + ) + def execute_batch( self, reqs: List[GraphQLRequest], @@ -330,30 +337,23 @@ def execute_batch( answers = self._extract_response(response) - return get_batch_execution_result_list(reqs, answers) - - def _raise_invalid_result(self, result_text: str, reason: str) -> None: - raise TransportProtocolError( - f"Server did not return a valid GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + try: + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a 
TransportProtocolError + raise def _extract_response(self, response: requests.Response) -> Any: try: - response.raise_for_status() result = response.json() if log.isEnabledFor(logging.DEBUG): log.debug("<<< %s", response.text) - except requests.HTTPError as e: - raise TransportServerError( - str(e), e.response.status_code if e.response is not None else None - ) from e - except Exception: - self._raise_invalid_result(str(response.text), "Not a JSON answer") + self._raise_response_error(response, "Not a JSON answer") return result diff --git a/tests/test_aiohttp_batch.py b/tests/test_aiohttp_batch.py index f04f05e4..e3407a4d 100644 --- a/tests/test_aiohttp_batch.py +++ b/tests/test_aiohttp_batch.py @@ -1,3 +1,4 @@ +import asyncio from typing import Mapping import pytest @@ -7,6 +8,7 @@ TransportClosed, TransportProtocolError, TransportQueryError, + TransportServerError, ) # Marking all tests in this file with the aiohttp marker @@ -29,6 +31,21 @@ '{"code":"SA","name":"South America"}]}}]' ) +query1_server_answer_twice_list = ( + "[" + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}' + "]" +) + @pytest.mark.asyncio async def test_aiohttp_batch_query(aiohttp_server): @@ -72,6 +89,179 @@ async def handler(request): assert transport.response_headers["dummy"] == "test1234" +@pytest.mark.asyncio +async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def 
handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, # 10ms batch interval + ) as session: + + query = gql(query1_str) + + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.asyncio +async def test_aiohttp_batch_auto_two_requests(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + async def test_coroutine(): + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Create two concurrent tasks that will be batched together + tasks = [] + for _ in range(2): + task = asyncio.create_task(test_coroutine()) + tasks.append(task) + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +@pytest.mark.asyncio +async def 
test_aiohttp_batch_auto_two_requests_close_session_directly(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.1, + ) as session: + + async def test_coroutine(): + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Create two concurrent tasks that will be batched together + tasks = [] + for _ in range(2): + task = asyncio.create_task(test_coroutine()) + tasks.append(task) + + await asyncio.sleep(0.01) + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_error_code_401(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, # 10ms batch interval + ) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "401, message='Unauthorized'" in str(exc_info.value) + + @pytest.mark.asyncio 
async def test_aiohttp_batch_query_without_session(aiohttp_server, run_sync_test): from aiohttp import web diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py index 4c9e7d76..346dc00e 100644 --- a/tests/test_graphql_request.py +++ b/tests/test_graphql_request.py @@ -20,7 +20,7 @@ from gql import GraphQLRequest, gql -from .conftest import MS +from .conftest import MS, strip_braces_spaces # Marking all tests in this file with the aiohttp marker pytestmark = pytest.mark.aiohttp @@ -200,3 +200,13 @@ def test_serialize_variables_using_money_example(): req = req.serialize_variable_values(schema) assert req.variable_values == {"money": {"amount": 10, "currency": "DM"}} + + +def test_graphql_request_using_string_instead_of_document(): + request = GraphQLRequest("{balance}") + + expected_payload = "{'query': '{\\n balance\\n}'}" + + print(request) + + assert str(request) == strip_braces_spaces(expected_payload) From b221c0e60d7d02e8d32b8f1e4f113e86d74f1c51 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 27 May 2025 19:15:44 +0000 Subject: [PATCH 220/239] Remove MIT license classifier (#555) --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 706a80c3..3db1c9f8 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,6 @@ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", - "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.9", From b3789ef85f728e7b0dbc60e7159d9eb2cb71476a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 28 May 2025 09:39:37 +0000 Subject: [PATCH 221/239] Using GraphQLRequest instead of DocumentNode for gql, execute, subscribe methods (#556) --- docs/advanced/batching_requests.rst | 2 +- .../code_examples/appsync/mutation_api_key.py | 4 +- docs/code_examples/appsync/mutation_iam.py | 4 +- 
docs/code_examples/console_async.py | 6 +- docs/code_examples/fastapi_async.py | 5 +- .../reconnecting_mutation_http.py | 4 +- .../code_examples/reconnecting_mutation_ws.py | 4 +- docs/usage/custom_scalars_and_enums.rst | 23 +- docs/usage/file_upload.rst | 36 +-- docs/usage/variables.rst | 10 +- gql/client.py | 236 ++++++------------ gql/dsl.py | 10 +- gql/gql.py | 21 +- gql/graphql_request.py | 98 ++++++-- gql/utilities/get_introspection_query_ast.py | 2 +- tests/custom_scalars/test_datetime.py | 26 +- tests/custom_scalars/test_enum_colors.py | 17 +- tests/custom_scalars/test_json.py | 6 +- tests/custom_scalars/test_money.py | 77 +++--- tests/custom_scalars/test_parse_results.py | 5 +- .../test_dsl_directives.py | 4 +- tests/starwars/test_dsl.py | 110 +++++--- tests/starwars/test_parse_results.py | 20 +- tests/starwars/test_query.py | 30 +-- tests/starwars/test_subscription.py | 14 +- tests/test_aiohttp.py | 227 +++++++++++------ tests/test_aiohttp_batch.py | 20 +- ...iohttp_websocket_graphqlws_subscription.py | 5 +- tests/test_aiohttp_websocket_subscription.py | 13 +- tests/test_appsync_websockets.py | 9 +- tests/test_async_client_validation.py | 16 +- tests/test_client.py | 8 +- tests/test_graphql_request.py | 32 ++- tests/test_graphqlws_subscription.py | 5 +- tests/test_httpx.py | 64 ++--- tests/test_httpx_async.py | 62 ++--- tests/test_httpx_batch.py | 32 ++- tests/test_requests.py | 82 +++--- tests/test_requests_batch.py | 30 +-- tests/test_transport.py | 6 +- tests/test_transport_batch.py | 16 +- tests/test_websocket_subscription.py | 5 +- tests/test_websockets_adapter.py | 4 +- 43 files changed, 685 insertions(+), 725 deletions(-) diff --git a/docs/advanced/batching_requests.rst b/docs/advanced/batching_requests.rst index a71d4ffc..7c9fc9b6 100644 --- a/docs/advanced/batching_requests.rst +++ b/docs/advanced/batching_requests.rst @@ -24,7 +24,7 @@ To execute a batch of requests manually: .. 
code-block:: python - request1 = GraphQLRequest(""" + request1 = gql(""" query getContinents { continents { code diff --git a/docs/code_examples/appsync/mutation_api_key.py b/docs/code_examples/appsync/mutation_api_key.py index 634e2439..47067aca 100644 --- a/docs/code_examples/appsync/mutation_api_key.py +++ b/docs/code_examples/appsync/mutation_api_key.py @@ -46,9 +46,9 @@ async def main(): }""" ) - variable_values = {"message": "Hello world!"} + query.variable_values = {"message": "Hello world!"} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(result) diff --git a/docs/code_examples/appsync/mutation_iam.py b/docs/code_examples/appsync/mutation_iam.py index 3cc04a5a..efe9889b 100644 --- a/docs/code_examples/appsync/mutation_iam.py +++ b/docs/code_examples/appsync/mutation_iam.py @@ -45,9 +45,9 @@ async def main(): }""" ) - variable_values = {"message": "Hello world!"} + query.variable_values = {"message": "Hello world!"} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(result) diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py index 6c0b86d0..69c71bce 100644 --- a/docs/code_examples/console_async.py +++ b/docs/code_examples/console_async.py @@ -35,13 +35,11 @@ async def close(self): await self._client.close_async() async def get_continent_name(self, code): - params = {"code": code} + self.get_continent_name_query.variable_values = {"code": code} assert self._session is not None - answer = await self._session.execute( - self.get_continent_name_query, variable_values=params - ) + answer = await self._session.execute(self.get_continent_name_query) return answer.get("continent").get("name") # type: ignore diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index f4a5c14b..0b174fe5 100644 --- a/docs/code_examples/fastapi_async.py +++ 
b/docs/code_examples/fastapi_async.py @@ -93,9 +93,8 @@ async def get_continent(continent_code): try: assert isinstance(client.session, ReconnectingAsyncClientSession) - result = await client.session.execute( - query, variable_values={"code": continent_code} - ) + query.variable_values = {"code": continent_code} + result = await client.session.execute(query) except Exception as e: log.debug(f"get_continent Error: {e}") raise HTTPException(status_code=503, detail="GraphQL backend unavailable") diff --git a/docs/code_examples/reconnecting_mutation_http.py b/docs/code_examples/reconnecting_mutation_http.py index f4329c8b..5deb5063 100644 --- a/docs/code_examples/reconnecting_mutation_http.py +++ b/docs/code_examples/reconnecting_mutation_http.py @@ -33,10 +33,10 @@ async def main(): # Execute single query query = gql("mutation ($message: String!) {sendMessage(message: $message)}") - params = {"message": f"test {num}"} + query.variable_values = {"message": f"test {num}"} try: - result = await session.execute(query, variable_values=params) + result = await session.execute(query) print(result) except Exception as e: print(f"Received exception {e}") diff --git a/docs/code_examples/reconnecting_mutation_ws.py b/docs/code_examples/reconnecting_mutation_ws.py index 7d7c8f8a..d7e7cfe2 100644 --- a/docs/code_examples/reconnecting_mutation_ws.py +++ b/docs/code_examples/reconnecting_mutation_ws.py @@ -33,10 +33,10 @@ async def main(): # Execute single query query = gql("mutation ($message: String!) 
{sendMessage(message: $message)}") - params = {"message": f"test {num}"} + query.variable_values = {"message": f"test {num}"} try: - result = await session.execute(query, variable_values=params) + result = await session.execute(query) print(result) except Exception as e: print(f"Received exception {e}") diff --git a/docs/usage/custom_scalars_and_enums.rst b/docs/usage/custom_scalars_and_enums.rst index fc9008d8..f85b583a 100644 --- a/docs/usage/custom_scalars_and_enums.rst +++ b/docs/usage/custom_scalars_and_enums.rst @@ -203,11 +203,11 @@ In a variable query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = { + query.variable_values = { "time": "2021-11-12T11:58:13.461161", } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) - enum: @@ -220,11 +220,11 @@ In a variable }""" ) - variable_values = { + query.variable_values = { "color": 'RED', } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) Automatically ^^^^^^^^^^^^^ @@ -256,12 +256,10 @@ Examples: query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") # the argument for time is a datetime instance - variable_values = {"time": datetime.now()} + query.variable_values = {"time": datetime.now()} # we execute the query with serialize_variables set to True - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) - enums: @@ -285,14 +283,12 @@ Examples: ) # the argument for time is an instance of our Enum type - variable_values = { + query.variable_values = { "color": Color.RED, } # we execute the query with serialize_variables set to True - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) Parsing output -------------- @@ -319,11 
+315,10 @@ Same example as above, with result parsing enabled: query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = {"time": datetime.now()} + query.variable_values = {"time": datetime.now()} result = await session.execute( query, - variable_values=variable_values, serialize_variables=True, parse_result=True, ) diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index 7793354b..09d51742 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -15,7 +15,7 @@ In order to upload a single file, you need to: * set the file as a variable value in the mutation * create a :class:`FileVar ` object with your file path -* provide the `FileVar` instance to the `variable_values` argument of `execute` +* provide the `FileVar` instance to the `variable_values` attribute of your query * set the `upload_files` argument to True .. code-block:: python @@ -37,11 +37,9 @@ In order to upload a single file, you need to: } ''') - params = {"file": FileVar("YOUR_FILE_PATH")} + query.variable_values = {"file": FileVar("YOUR_FILE_PATH")} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) Setting the content-type ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -97,11 +95,9 @@ It is also possible to upload multiple files using a list. 
f1 = FileVar("YOUR_FILE_PATH_1") f2 = FileVar("YOUR_FILE_PATH_2") - params = {"files": [f1, f2]} + query.variable_values = {"files": [f1, f2]} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) Streaming @@ -150,11 +146,9 @@ setting the `streaming` argument of :class:`FileVar ` to `True` streaming=True, ) - params = {"file": f1} + query.variable_values = {"file": f1} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) Another option is to use an async generator to provide parts of the file. @@ -172,11 +166,9 @@ to read the files in chunks and create this asynchronous generator. yield chunk f1 = FileVar(file_sender(file_name='YOUR_FILE_PATH')) - params = {"file": f1} + query.variable_values = {"file": f1} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) Streaming downloaded files ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -193,7 +185,7 @@ In order to do that, you need to: * get the response from an aiohttp request and then get the StreamReader instance from `resp.content` -* provide the StreamReader instance to the `variable_values` argument of `execute` +* provide the StreamReader instance to the `variable_values` attribute of your query Example: @@ -204,7 +196,7 @@ Example: async with http_client.get('YOUR_DOWNLOAD_URL') as resp: # We now have a StreamReader instance in resp.content - # and we provide it to the variable_values argument of execute + # and we provide it to the variable_values attribute of the query transport = AIOHTTPTransport(url='YOUR_GRAPHQL_URL') @@ -218,8 +210,6 @@ Example: } ''') - params = {"file": FileVar(resp.content)} + query.variable_values = {"file": FileVar(resp.content)} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) 
diff --git a/docs/usage/variables.rst b/docs/usage/variables.rst index 81924c6e..1eddd042 100644 --- a/docs/usage/variables.rst +++ b/docs/usage/variables.rst @@ -2,7 +2,7 @@ Using variables =============== It is possible to provide variable values with your query by providing a Dict to -the variable_values argument of the `execute` or the `subscribe` methods. +the variable_values attribute of your query. The variable values will be sent alongside the query in the transport message (there is no local substitution). @@ -19,14 +19,14 @@ The variable values will be sent alongside the query in the transport message """ ) - params = {"code": "EU"} + query.variable_values = {"code": "EU"} # Get name of continent with code "EU" - result = client.execute(query, variable_values=params) + result = client.execute(query) print(result) - params = {"code": "AF"} + query.variable_values = {"code": "AF"} # Get name of continent with code "AF" - result = client.execute(query, variable_values=params) + result = client.execute(query) print(result) diff --git a/gql/client.py b/gql/client.py index a0e07056..e17a0b7c 100644 --- a/gql/client.py +++ b/gql/client.py @@ -24,7 +24,6 @@ import backoff from anyio import fail_after from graphql import ( - DocumentNode, ExecutionResult, GraphQLSchema, IntrospectionQuery, @@ -33,7 +32,7 @@ validate, ) -from .graphql_request import GraphQLRequest +from .graphql_request import GraphQLRequest, support_deprecated_request from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportConnectionFailed, TransportQueryError from .transport.local_schema import LocalSchemaTransport @@ -155,13 +154,13 @@ def __init__( def batching_enabled(self) -> bool: return self.batch_interval != 0 - def validate(self, document: DocumentNode) -> None: + def validate(self, request: GraphQLRequest) -> None: """:meta private:""" assert ( self.schema ), "Cannot validate the document locally, you need to pass a schema." 
- validation_errors = validate(self.schema, document) + validation_errors = validate(self.schema, request.document) if validation_errors: raise validation_errors[0] @@ -205,10 +204,8 @@ def _get_event_loop() -> asyncio.AbstractEventLoop: @overload def execute_sync( self, - document: DocumentNode, - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -218,10 +215,8 @@ def execute_sync( @overload def execute_sync( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[True], @@ -231,10 +226,8 @@ def execute_sync( @overload def execute_sync( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -243,10 +236,8 @@ def execute_sync( def execute_sync( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -255,9 +246,7 @@ def execute_sync( """:meta private:""" with self as session: return session.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -319,10 +308,8 @@ def execute_batch_sync( 
@overload async def execute_async( self, - document: DocumentNode, - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -332,10 +319,8 @@ async def execute_async( @overload async def execute_async( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[True], @@ -345,10 +330,8 @@ async def execute_async( @overload async def execute_async( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -357,10 +340,8 @@ async def execute_async( async def execute_async( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -369,9 +350,7 @@ async def execute_async( """:meta private:""" async with self as session: return await session.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -433,10 +412,8 @@ async def execute_batch_async( @overload def execute( self, - document: DocumentNode, - *, # 
https://github.com/python/mypy/issues/7333#issuecomment-788255229 - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -446,10 +423,8 @@ def execute( @overload def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[True], @@ -459,10 +434,8 @@ def execute( @overload def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -471,16 +444,14 @@ def execute( def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: - """Execute the provided document AST against the remote server using + """Execute the provided request against the remote server using the transport provided during init. This function **WILL BLOCK** until the result is received from the server. 
@@ -512,9 +483,7 @@ def execute( data = loop.run_until_complete( self.execute_async( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -526,9 +495,7 @@ def execute( else: # Sync transports return self.execute_sync( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -631,10 +598,8 @@ def execute_batch( @overload def subscribe_async( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -644,10 +609,8 @@ def subscribe_async( @overload def subscribe_async( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[True], @@ -657,10 +620,8 @@ def subscribe_async( @overload def subscribe_async( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -671,10 +632,8 @@ def subscribe_async( async def subscribe_async( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -685,9 +644,7 @@ async def subscribe_async( """:meta private:""" 
async with self as session: generator = session.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -700,10 +657,8 @@ async def subscribe_async( @overload def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -713,10 +668,8 @@ def subscribe( @overload def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[True], @@ -726,10 +679,8 @@ def subscribe( @overload def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -740,10 +691,8 @@ def subscribe( def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -766,9 +715,7 @@ def subscribe( async_generator: Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ] = self.subscribe_async( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -941,9 +888,13 @@ def _execute( The extra arguments 
are passed to the transport execute method.""" + # Still supporting for now old method of providing + # variable_values and operation_name + request = support_deprecated_request(request, kwargs) + # Validate document if self.client.schema: - self.client.validate(request.document) + self.client.validate(request) # Parse variable values for custom scalars if requested if request.variable_values is not None: @@ -977,10 +928,8 @@ def _execute( @overload def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -990,10 +939,8 @@ def execute( @overload def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[True], @@ -1003,10 +950,8 @@ def execute( @overload def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -1015,24 +960,20 @@ def execute( def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: - """Execute the provided document AST synchronously using + """Execute the provided request synchronously using the sync transport. Raises a TransportQueryError if an error has been returned in the ExecutionResult. 
- :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL query as :class:`GraphQLRequest `. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1043,13 +984,6 @@ def execute( The extra arguments are passed to the transport execute method.""" - # Make GraphQLRequest object - request = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - # Validate and execute on the transport result = self._execute( request, @@ -1103,7 +1037,7 @@ def _execute_batch( if validate_document: for req in requests: - self.client.validate(req.document) + self.client.validate(req) # Parse variable values for custom scalars if requested if serialize_variables or ( @@ -1326,9 +1260,7 @@ def fetch_schema(self) -> None: introspection_query = get_introspection_query_ast( **self.client.introspection_args ) - execution_result = self.transport.execute( - GraphQLRequest(document=introspection_query) - ) + execution_result = self.transport.execute(GraphQLRequest(introspection_query)) self.client._build_schema_from_introspection(execution_result) @@ -1374,9 +1306,13 @@ async def _subscribe( The extra arguments are passed to the transport subscribe method.""" + # Still supporting for now old method of providing + # variable_values and operation_name + request = support_deprecated_request(request, kwargs) + # Validate document if self.client.schema: - self.client.validate(request.document) + self.client.validate(request) # Parse variable values for custom scalars if requested if request.variable_values is not None: @@ -1418,10 +1354,8 @@ async def _subscribe( @overload def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: 
Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -1431,10 +1365,8 @@ def subscribe( @overload def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[True], @@ -1444,10 +1376,8 @@ def subscribe( @overload def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -1458,10 +1388,8 @@ def subscribe( async def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, @@ -1469,15 +1397,13 @@ async def subscribe( ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ]: - """Coroutine to subscribe asynchronously to the provided document AST + """Coroutine to subscribe asynchronously to the provided request asynchronously using the async transport. Raises a TransportQueryError if an error has been returned in the ExecutionResult. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL query as :class:`GraphQLRequest `. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. 
@@ -1488,13 +1414,6 @@ async def subscribe( The extra arguments are passed to the transport subscribe method.""" - # Make GraphQLRequest object - request = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( request, serialize_variables=serialize_variables, @@ -1536,8 +1455,8 @@ async def _execute( * Validate the query with the schema if provided. * Serialize the variable_values if requested. - :param request: graphql request as a - :class:`graphqlrequest ` object. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1546,9 +1465,13 @@ async def _execute( The extra arguments are passed to the transport execute method.""" + # Still supporting for now old method of providing + # variable_values and operation_name + request = support_deprecated_request(request, kwargs) + # Validate document if self.client.schema: - self.client.validate(request.document) + self.client.validate(request) # Parse variable values for custom scalars if requested if request.variable_values is not None: @@ -1584,10 +1507,8 @@ async def _execute( @overload async def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: Literal[False] = ..., @@ -1597,10 +1518,8 @@ async def execute( @overload async def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: 
Literal[True], @@ -1610,10 +1529,8 @@ async def execute( @overload async def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., get_execution_result: bool, @@ -1622,24 +1539,20 @@ async def execute( async def execute( self, - document: DocumentNode, + request: GraphQLRequest, *, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: - """Coroutine to execute the provided document AST asynchronously using + """Coroutine to execute the provided request asynchronously using the async transport. Raises a TransportQueryError if an error has been returned in the ExecutionResult. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL query as :class:`GraphQLRequest `. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. 
@@ -1650,13 +1563,6 @@ async def execute( The extra arguments are passed to the transport execute method.""" - # Make GraphQLRequest object - request = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - # Validate and execute on the transport result = await self._execute( request, @@ -1710,7 +1616,7 @@ async def _execute_batch( if validate_document: for req in requests: - self.client.validate(req.document) + self.client.validate(req) # Parse variable values for custom scalars if requested if serialize_variables or ( diff --git a/gql/dsl.py b/gql/dsl.py index e5b5131e..1a8716c2 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -64,6 +64,7 @@ ) from graphql.pyutils import inspect +from .graphql_request import GraphQLRequest from .utils import to_camel_case log = logging.getLogger(__name__) @@ -214,7 +215,7 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: def dsl_gql( *operations: "DSLExecutable", **operations_with_name: "DSLExecutable" -) -> DocumentNode: +) -> GraphQLRequest: r"""Given arguments instances of :class:`DSLExecutable` containing GraphQL operations or fragments, generate a Document which can be executed later in a @@ -231,7 +232,8 @@ def dsl_gql( :param \**operations_with_name: the GraphQL operations with an operation name :type \**operations_with_name: DSLQuery, DSLMutation, DSLSubscription - :return: a Document which can be later executed or subscribed by a + :return: a :class:`GraphQLRequest ` + which can be later executed or subscribed by a :class:`Client `, by an :class:`async session ` or by a :class:`sync session ` @@ -259,10 +261,12 @@ def dsl_gql( f"Received: {type(operation)}." ) - return DocumentNode( + document = DocumentNode( definitions=[operation.executable_ast for operation in all_operations] ) + return GraphQLRequest(document) + class DSLSchema: """The DSLSchema is the root of the DSL code. 
diff --git a/gql/gql.py b/gql/gql.py index e9705947..f4cd3aea 100644 --- a/gql/gql.py +++ b/gql/gql.py @@ -1,24 +1,17 @@ -from __future__ import annotations +from .graphql_request import GraphQLRequest -from graphql import DocumentNode, Source, parse - -def gql(request_string: str | Source) -> DocumentNode: - """Given a string containing a GraphQL request, parse it into a Document. +def gql(request_string: str) -> GraphQLRequest: + """Given a string containing a GraphQL request, + parse it into a Document and put it into a GraphQLRequest object :param request_string: the GraphQL request as a String - :type request_string: str | Source - :return: a Document which can be later executed or subscribed by a + :return: a :class:`GraphQLRequest ` + which can be later executed or subscribed by a :class:`Client `, by an :class:`async session ` or by a :class:`sync session ` :raises graphql.error.GraphQLError: if a syntax error is encountered. """ - if isinstance(request_string, Source): - source = request_string - elif isinstance(request_string, str): - source = Source(request_string, "GraphQL request") - else: - raise TypeError("Request must be passed as a string or Source object.") - return parse(source) + return GraphQLRequest(request_string) diff --git a/gql/graphql_request.py b/gql/graphql_request.py index 29a34717..fe3523a9 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,9 +1,7 @@ +import warnings from typing import Any, Dict, Optional, Union -from graphql import DocumentNode, GraphQLSchema, print_ast - -from .gql import gql -from .utilities import serialize_variable_values +from graphql import DocumentNode, GraphQLSchema, Source, parse, print_ast class GraphQLRequest: @@ -11,7 +9,7 @@ class GraphQLRequest: def __init__( self, - document: Union[DocumentNode, str], + request: Union[DocumentNode, "GraphQLRequest", str], *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, @@ -19,26 +17,46 @@ def __init__( """ 
Initialize a GraphQL request. - Args: - document: GraphQL query as AST Node object or as a string. - If string, it will be converted to DocumentNode using gql(). - variable_values: Dictionary of input parameters (Default: None). - operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as DocumentNode object or as a string. + If string, it will be converted to DocumentNode. + :param variable_values: Dictionary of input parameters (Default: None). + :param operation_name: Name of the operation that shall be executed. + Only required in multi-operation documents (Default: None). + + :return: a :class:`GraphQLRequest ` + which can be later executed or subscribed by a + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` + :raises graphql.error.GraphQLError: if a syntax error is encountered. + """ - if isinstance(document, str): - self.document = gql(document) - else: - self.document = document + if isinstance(request, str): + source = Source(request, "GraphQL request") + self.document = parse(source) + elif isinstance(request, DocumentNode): + self.document = request + elif not isinstance(request, GraphQLRequest): + raise TypeError(f"Unexpected type for GraphQLRequest: {type(request)}") - self.variable_values = variable_values - self.operation_name = operation_name + if isinstance(request, GraphQLRequest): + self.document = request.document + if variable_values is None: + variable_values = request.variable_values + if operation_name is None: + operation_name = request.operation_name + + self.variable_values: Optional[Dict[str, Any]] = variable_values + self.operation_name: Optional[str] = operation_name def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": + + from .utilities.serialize_variable_values import serialize_variable_values + assert self.variable_values return GraphQLRequest( - document=self.document, 
+ self.document, variable_values=serialize_variable_values( schema=schema, document=self.document, @@ -63,3 +81,47 @@ def payload(self) -> Dict[str, Any]: def __str__(self): return str(self.payload) + + +def support_deprecated_request( + request: Union[GraphQLRequest, DocumentNode], + kwargs: Dict, +) -> GraphQLRequest: + """This methods is there temporarily to convert the old style of calling + execute and subscribe methods with a DocumentNode, + variable_values and operation_name arguments. + """ + + if isinstance(request, DocumentNode): + warnings.warn( + ( + "Using a DocumentNode is deprecated. Please use a " + "GraphQLRequest instead." + ), + DeprecationWarning, + stacklevel=2, + ) + request = GraphQLRequest(request) + + if not isinstance(request, GraphQLRequest): + raise TypeError("request should be a GraphQLRequest object") + + variable_values = kwargs.pop("variable_values", None) + operation_name = kwargs.pop("operation_name", None) + + if variable_values or operation_name: + warnings.warn( + ( + "Using variable_values and operation_name arguments of " + "execute and subscribe methods is deprecated. 
Instead, " + "please use the variable_values and operation_name properties " + "of GraphQLRequest" + ), + DeprecationWarning, + stacklevel=2, + ) + + request.variable_values = variable_values + request.operation_name = operation_name + + return request diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index 4d6a243f..0422a225 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -139,4 +139,4 @@ def get_introspection_query_ast( dsl_query = dsl_gql(query, fragment_FullType, fragment_InputValue, fragment_TypeRef) - return dsl_query + return dsl_query.document diff --git a/tests/custom_scalars/test_datetime.py b/tests/custom_scalars/test_datetime.py index 5a36669c..4d9589f1 100644 --- a/tests/custom_scalars/test_datetime.py +++ b/tests/custom_scalars/test_datetime.py @@ -117,11 +117,11 @@ def test_shift_days(): query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = { + query.variable_values = { "time": now, } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) print(result) @@ -151,11 +151,11 @@ def test_shift_days_serialized_manually_in_variables(): query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = { + query.variable_values = { "time": "2021-11-12T11:58:13.461161", } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) print(result) @@ -171,13 +171,11 @@ def test_latest(): query = gql("query latest($times: [Datetime!]!) 
{latest(times: $times)}") - variable_values = { + query.variable_values = { "times": [now, in_five_days], } - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) @@ -194,11 +192,9 @@ def test_seconds(): "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" ) - variable_values = {"interval": {"start": now, "end": in_five_days}} + query.variable_values = {"interval": {"start": now, "end": in_five_days}} - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) @@ -214,11 +210,9 @@ def test_seconds_omit_optional_start_argument(): "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" ) - variable_values = {"interval": {"end": in_five_days}} + query.variable_values = {"interval": {"end": in_five_days}} - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 3526d548..ff893571 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -165,11 +165,11 @@ def test_opposite_color_variable_serialized_manually(): }""" ) - variable_values = { + query.variable_values = { "color": "RED", } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) print(result) @@ -190,13 +190,11 @@ def test_opposite_color_variable_serialized_by_gql(): }""" ) - variable_values = { + query.variable_values = { "color": RED, } - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) @@ -328,13 +326,12 @@ def 
test_parse_results_with_operation_type(): """ ) - variable_values = { + query.variable_values = { "color": "RED", } + query.operation_name = "GetOppositeColor" - result = client.execute( - query, variable_values=variable_values, operation_name="GetOppositeColor" - ) + result = client.execute(query) print(result) diff --git a/tests/custom_scalars/test_json.py b/tests/custom_scalars/test_json.py index d3eae3b8..903dfa6d 100644 --- a/tests/custom_scalars/test_json.py +++ b/tests/custom_scalars/test_json.py @@ -166,7 +166,7 @@ def test_json_value_input_in_ast_with_variables(): }""" ) - variable_values = { + query.variable_values = { "name": "Barbara", "level": 1, "is_connected": False, @@ -174,9 +174,7 @@ def test_json_value_input_in_ast_with_variables(): "friends": ["Alex", "John"], } - result = client.execute( - query, variable_values=variable_values, root_value=root_value - ) + result = client.execute(query, root_value=root_value) print(result) diff --git a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 8b4a99f4..55a6577a 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -20,7 +20,7 @@ ) from graphql.utilities import value_from_ast_untyped -from gql import Client, GraphQLRequest, gql +from gql import Client, gql from gql.transport.exceptions import TransportQueryError from gql.utilities import serialize_value, update_schema_scalar, update_schema_scalars @@ -275,11 +275,9 @@ def test_custom_scalar_in_input_variable_values(): money_value = {"amount": 10, "currency": "DM"} - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} - result = client.execute( - query, variable_values=variable_values, root_value=root_value - ) + result = client.execute(query, root_value=root_value) assert result["toEuros"] == 5 @@ -292,11 +290,10 @@ def test_custom_scalar_in_input_variable_values_serialized(): money_value = Money(10, "DM") - variable_values = {"money": money_value} + 
query.variable_values = {"money": money_value} result = client.execute( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, ) @@ -312,14 +309,13 @@ def test_custom_scalar_in_input_variable_values_serialized_with_operation_name() money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} + query.operation_name = "myquery" result = client.execute( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, - operation_name="myquery", ) assert result["toEuros"] == 5 @@ -342,12 +338,11 @@ def test_serialize_variable_values_exception_multiple_ops_without_operation_name money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} with pytest.raises(GraphQLError) as exc_info: client.execute( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, ) @@ -374,15 +369,14 @@ def test_serialize_variable_values_exception_operation_name_not_found(): money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} + query.operation_name = "invalid_operation_name" with pytest.raises(GraphQLError) as exc_info: client.execute( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, - operation_name="invalid_operation_name", ) exception = exc_info.value @@ -398,13 +392,12 @@ def test_custom_scalar_subscribe_in_input_variable_values_serialized(): money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} expected_result = {"spend": Money(10, "DM")} for result in client.subscribe( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, parse_result=True, @@ -544,9 +537,9 @@ async def test_custom_scalar_in_input_variable_values_with_transport(aiohttp_ser money_value = 
{"amount": 10, "currency": "DM"} # money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -570,9 +563,9 @@ async def test_custom_scalar_in_input_variable_values_split_with_transport( }""" ) - variable_values = {"amount": 10, "currency": "DM"} + query.variable_values = {"amount": 10, "currency": "DM"} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -590,11 +583,9 @@ async def test_custom_scalar_serialize_variables(aiohttp_server): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -611,12 +602,10 @@ async def test_custom_scalar_serialize_variables_no_schema(aiohttp_server): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} with pytest.raises(TransportQueryError): - await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + await session.execute(query, serialize_variables=True) @pytest.mark.asyncio @@ -643,11 +632,9 @@ async def test_custom_scalar_serialize_variables_schema_from_introspection( query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = await session.execute( - query, 
variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -667,11 +654,9 @@ async def test_update_schema_scalars(aiohttp_server): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -743,11 +728,9 @@ def test_code(): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = session.execute(query, serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -767,12 +750,12 @@ def test_code(): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} results = session.execute_batch( [ - GraphQLRequest(document=query, variable_values=variable_values), - GraphQLRequest(document=query, variable_values=variable_values), + query, + query, ], serialize_variables=True, ) @@ -795,12 +778,12 @@ async def test_custom_scalar_serialize_variables_async_transport(aiohttp_server) query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} results = await session.execute_batch( [ - GraphQLRequest(document=query, variable_values=variable_values), - GraphQLRequest(document=query, variable_values=variable_values), + query, + query, ], 
serialize_variables=True, ) diff --git a/tests/custom_scalars/test_parse_results.py b/tests/custom_scalars/test_parse_results.py index e3c6d6f6..32812818 100644 --- a/tests/custom_scalars/test_parse_results.py +++ b/tests/custom_scalars/test_parse_results.py @@ -93,6 +93,5 @@ def test_parse_results_null_mapping(): } }""" ) - assert client.execute(query, variable_values={"count": 2}) == { - "test": static_result - } + query.variable_values = {"count": 2} + assert client.execute(query) == {"test": static_result} diff --git a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py index e4653d48..67c2e739 100644 --- a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py +++ b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py @@ -65,10 +65,10 @@ def test_issue_447(): client.validate(q) # Creating a tree from the DocumentNode created by dsl_gql - dsl_tree = node_tree(q) + dsl_tree = node_tree(q.document) # Creating a tree from the DocumentNode created by gql - gql_tree = node_tree(gql(print_ast(q))) + gql_tree = node_tree(gql(print_ast(q.document)).document) print("=======") print(dsl_tree) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index d96435fc..e47a97d8 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -143,7 +143,7 @@ def test_use_variable_definition_multiple_times(ds): query = dsl_gql(op) assert ( - print_ast(query) + print_ast(query.document) == """mutation \ ($badReview: ReviewInput, $episode: Episode, $goodReview: ReviewInput) { badReview: createReview(review: $badReview, episode: $episode) { @@ -157,7 +157,9 @@ def test_use_variable_definition_multiple_times(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_add_variable_definitions(ds): @@ -171,7 
+173,7 @@ def test_add_variable_definitions(ds): query = dsl_gql(op) assert ( - print_ast(query) + print_ast(query.document) == """mutation ($review: ReviewInput, $episode: Episode) { createReview(review: $review, episode: $episode) { stars @@ -180,7 +182,9 @@ def test_add_variable_definitions(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_add_variable_definitions_with_default_value_enum(ds): @@ -194,7 +198,7 @@ def test_add_variable_definitions_with_default_value_enum(ds): query = dsl_gql(op) assert ( - print_ast(query) + print_ast(query.document) == """mutation ($review: ReviewInput, $episode: Episode = NEWHOPE) { createReview(review: $review, episode: $episode) { stars @@ -216,7 +220,7 @@ def test_add_variable_definitions_with_default_value_input_object(ds): query = dsl_gql(op) assert ( - strip_braces_spaces(print_ast(query)) + strip_braces_spaces(print_ast(query.document)) == """ mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { createReview(review: $review, episode: $episode) { @@ -226,7 +230,9 @@ def test_add_variable_definitions_with_default_value_input_object(ds): }""".strip() ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_add_variable_definitions_in_input_object(ds): @@ -241,7 +247,7 @@ def test_add_variable_definitions_in_input_object(ds): query = dsl_gql(op) assert ( - strip_braces_spaces(print_ast(query)) + strip_braces_spaces(print_ast(query.document)) == """mutation ($stars: Int, $commentary: String, $episode: Episode) { createReview( review: {stars: $stars, commentary: $commentary} @@ -253,7 +259,9 @@ def test_add_variable_definitions_in_input_object(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == 
node_tree( + gql(print_ast(query.document)).document + ) def test_invalid_field_on_type_query(ds): @@ -416,7 +424,9 @@ def test_hero_name_query_result(ds, client): result = client.execute(query) expected = {"hero": {"name": "R2-D2"}} assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_arg_serializer_list(ds, client): @@ -436,7 +446,9 @@ def test_arg_serializer_list(ds, client): ] } assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_arg_serializer_enum(ds, client): @@ -444,7 +456,9 @@ def test_arg_serializer_enum(ds, client): result = client.execute(query) expected = {"hero": {"name": "Luke Skywalker"}} assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_create_review_mutation_result(ds, client): @@ -459,7 +473,9 @@ def test_create_review_mutation_result(ds, client): result = client.execute(query) expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_subscription(ds): @@ -472,7 +488,7 @@ def test_subscription(ds): ) ) assert ( - print_ast(query) + print_ast(query.document) == """subscription { reviewAdded(episode: JEDI) { stars @@ -481,7 +497,9 @@ def test_subscription(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_field_does_not_exit_in_type(ds): @@ -522,7 +540,9 @@ def test_multiple_root_fields(ds, client): 
"hero_of_episode_5": {"name": "Luke Skywalker"}, } assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_root_fields_aliased(ds, client): @@ -538,7 +558,9 @@ def test_root_fields_aliased(ds, client): "hero_of_episode_5": {"name": "Luke Skywalker"}, } assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_operation_name(ds): @@ -549,7 +571,7 @@ def test_operation_name(ds): ) assert ( - print_ast(query) + print_ast(query.document) == """query GetHeroName { hero { name @@ -557,7 +579,9 @@ def test_operation_name(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_multiple_operations(ds): @@ -571,7 +595,7 @@ def test_multiple_operations(ds): ) assert ( - strip_braces_spaces(print_ast(query)) + strip_braces_spaces(print_ast(query.document)) == """query GetHeroName { hero { name @@ -589,7 +613,9 @@ def test_multiple_operations(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_inline_fragments(ds): @@ -656,12 +682,14 @@ def test_fragments(ds): query_dsl = DSLQuery(ds.Query.hero.select(name_and_appearances)) - document = dsl_gql(name_and_appearances, query_dsl) + request = dsl_gql(name_and_appearances, query_dsl) + + document = request.document print(print_ast(document)) assert query == print_ast(document) - assert node_tree(document) == node_tree(gql(print_ast(document))) + assert node_tree(document) == node_tree(gql(print_ast(document)).document) def test_fragment_without_type_condition_error(ds): @@ -753,12 +781,14 @@ def 
test_dsl_nested_query_with_fragment(ds): ) ) - document = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + request = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + + document = request.document print(print_ast(document)) assert query == print_ast(document) - assert node_tree(document) == node_tree(gql(print_ast(document))) + assert node_tree(document) == node_tree(gql(print_ast(document)).document) # Same thing, but incrementaly @@ -779,12 +809,14 @@ def test_dsl_nested_query_with_fragment(ds): query_dsl = DSLQuery(hero) - document = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + request = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + + document = request.document print(print_ast(document)) assert query == print_ast(document) - assert node_tree(document) == node_tree(gql(print_ast(document))) + assert node_tree(document) == node_tree(gql(print_ast(document)).document) def test_dsl_query_all_fields_should_be_instances_of_DSLField(): @@ -828,7 +860,7 @@ def test_dsl_root_type_not_default(): version } """ - assert print_ast(query) == expected_query.strip() + assert print_ast(query.document) == expected_query.strip() with pytest.raises(GraphQLError) as excinfo: DSLSubscription(ds.QueryNotDefault.version) @@ -837,7 +869,9 @@ def test_dsl_root_type_not_default(): "Invalid field for : " ) in str(excinfo.value) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): @@ -925,7 +959,7 @@ def test_type_hero_query(ds): ) query_dsl = DSLQuery(type_hero) - assert query == str(print_ast(dsl_gql(query_dsl))).strip() + assert query == str(print_ast(dsl_gql(query_dsl).document)).strip() def test_invalid_meta_field_selection(ds): @@ -1000,9 +1034,11 @@ def test_get_introspection_query_ast(option): ) try: - assert 
print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert print_ast(gql(introspection_query).document) == print_ast( + dsl_introspection_query + ) assert node_tree(dsl_introspection_query) == node_tree( - gql(print_ast(dsl_introspection_query)) + gql(print_ast(dsl_introspection_query)).document ) except AssertionError: @@ -1015,9 +1051,11 @@ def test_get_introspection_query_ast(option): input_value_deprecation=option, type_recursion_level=9, ) - assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert print_ast(gql(introspection_query).document) == print_ast( + dsl_introspection_query + ) assert node_tree(dsl_introspection_query) == node_tree( - gql(print_ast(dsl_introspection_query)) + gql(print_ast(dsl_introspection_query)).document ) @@ -1047,7 +1085,7 @@ def test_node_tree_with_loc(ds): } }""".strip() - document = gql(query) + document = gql(query).document node_tree_result = """ DocumentNode @@ -1232,4 +1270,4 @@ def test_legacy_fragment_with_variables(ds): } } """.strip() - assert print_ast(query) == expected + assert print_ast(query.document) == expected diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py index 8020b586..2ae94ea8 100644 --- a/tests/starwars/test_parse_results.py +++ b/tests/starwars/test_parse_results.py @@ -22,6 +22,7 @@ def test_hero_name_and_friends_query(): } """ ) + result = { "hero": { "id": "2001", @@ -34,7 +35,7 @@ def test_hero_name_and_friends_query(): } } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -58,6 +59,7 @@ def test_hero_name_and_friends_query_with_fragment(): } """ ) + result = { "hero": { "id": "2001", @@ -70,7 +72,7 @@ def test_hero_name_and_friends_query_with_fragment(): } } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) 
assert result == parsed_result @@ -91,7 +93,7 @@ def test_key_not_found_in_result(): # Should be impossible. In that case, we ignore the missing key result: Dict[str, Any] = {} - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -112,7 +114,7 @@ def test_invalid_result_raise_error(): with pytest.raises(GraphQLError) as exc_info: - parse_result(StarWarsSchema, query, result) + parse_result(StarWarsSchema, query.document, result) assert "Invalid result for container of field id: 5" in str(exc_info) @@ -141,7 +143,7 @@ def test_fragment(): "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -164,7 +166,7 @@ def test_fragment_not_found(): with pytest.raises(GraphQLError) as exc_info: - parse_result(StarWarsSchema, query, result) + parse_result(StarWarsSchema, query.document, result) assert 'Fragment "HumanFragment" not found in document!' 
in str(exc_info) @@ -183,7 +185,7 @@ def test_return_none_if_result_is_none(): result = None - assert parse_result(StarWarsSchema, query, result) is None + assert parse_result(StarWarsSchema, query.document, result) is None def test_null_result_is_allowed(): @@ -200,7 +202,7 @@ def test_null_result_is_allowed(): result = {"hero": None} - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -224,6 +226,6 @@ def test_inline_fragment(): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index 7a2a8084..ff2af7d7 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -1,5 +1,5 @@ import pytest -from graphql import GraphQLError, Source +from graphql import GraphQLError from gql import Client, gql from tests.starwars.schema import StarWarsSchema @@ -136,11 +136,11 @@ def test_fetch_some_id_query(client): } """ ) - params = { + query.variable_values = { "someId": "1000", } expected = {"human": {"name": "Luke Skywalker"}} - result = client.execute(query, variable_values=params) + result = client.execute(query) assert result == expected @@ -154,11 +154,11 @@ def test_fetch_some_id_query2(client): } """ ) - params = { + query.variable_values = { "someId": "1002", } expected = {"human": {"name": "Han Solo"}} - result = client.execute(query, variable_values=params) + result = client.execute(query) assert result == expected @@ -172,11 +172,11 @@ def test_invalid_id_query(client): } """ ) - params = { + query.variable_values = { "id": "not a valid id", } expected = {"human": None} - result = client.execute(query, variable_values=params) + result = client.execute(query) assert result == expected @@ 
-316,24 +316,10 @@ def test_mutation_result(client): } """ ) - params = { + query.variable_values = { "ep": "JEDI", "review": {"stars": 5, "commentary": "This is a great movie!"}, } expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} - result = client.execute(query, variable_values=params) - assert result == expected - - -def test_query_from_source(client): - source = Source("{ hero { name } }") - query = gql(source) - expected = {"hero": {"name": "R2-D2"}} result = client.execute(query) assert result == expected - - -def test_already_parsed_query(client): - query = gql("{ hero { name } }") - with pytest.raises(TypeError, match="must be passed as a string"): - gql(query) # type: ignore diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index bbaafd5c..4f5f425b 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -3,7 +3,7 @@ import pytest from graphql import ExecutionResult, GraphQLError, subscribe -from gql import Client, GraphQLRequest, gql +from gql import Client, gql from .fixtures import reviews from .schema import StarWarsSchema @@ -41,7 +41,7 @@ async def test_subscription_support(): expected = [{**review, "episode": "JEDI"} for review in reviews[6]] ai = await await_if_coroutine( - subscribe(StarWarsSchema, subs, variable_values=params) + subscribe(StarWarsSchema, subs.document, variable_values=params) ) result = [result.data["reviewAdded"] async for result in ai] @@ -59,14 +59,14 @@ async def test_subscription_support_using_client(): subs = gql(subscription_str) - params = {"ep": "JEDI"} + subs.variable_values = {"ep": "JEDI"} expected = [{**review, "episode": "JEDI"} for review in reviews[6]] async with Client(schema=StarWarsSchema) as session: results = [ result["reviewAdded"] async for result in await await_if_coroutine( - session.subscribe(subs, variable_values=params, parse_result=False) + session.subscribe(subs, parse_result=False) ) ] @@ -85,7 
+85,7 @@ async def test_subscription_support_using_client_invalid_field(): subs = gql(subscription_invalid_str) - params = {"ep": "JEDI"} + subs.variable_values = {"ep": "JEDI"} async with Client(schema=StarWarsSchema) as session: @@ -93,9 +93,7 @@ async def test_subscription_support_using_client_invalid_field(): results = [ result async for result in await await_if_coroutine( - session.transport.subscribe( - GraphQLRequest(subs, variable_values=params) - ) + session.transport.subscribe(subs) ) ] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 24f82c9d..e3ac08c4 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -6,7 +6,7 @@ import pytest -from gql import Client, FileVar, GraphQLRequest, gql +from gql import Client, FileVar, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -421,7 +421,7 @@ async def handler(request): query = gql(query1_str) with pytest.raises(TransportClosed): - await transport.execute(GraphQLRequest(query)) + await transport.execute(query) @pytest.mark.asyncio @@ -491,14 +491,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = await session.execute( - query, variable_values=params, operation_name="getEurope" - ) + result = await session.execute(query) continent = result["continent"] @@ -528,14 +527,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = await session.execute( - query, variable_values=params, operation_name="getEurope" - ) + result = await session.execute(query) continent = result["continent"] @@ -660,16 +658,14 @@ async def 
test_aiohttp_file_upload(aiohttp_server): # Using an opened file with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} # Execute query asynchronously with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -677,22 +673,19 @@ async def test_aiohttp_file_upload(aiohttp_server): # Using an opened file inside a FileVar object with open(file_path, "rb") as f: - params = {"file": FileVar(f), "other_var": 42} + query.variable_values = {"file": FileVar(f), "other_var": 42} with warnings.catch_warnings(): warnings.simplefilter("error") # Turn warnings into errors - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success # Using an filename string inside a FileVar object - params = {"file": FileVar(file_path), "other_var": 42} - result = await session.execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -735,15 +728,13 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): # Setting the content_type f.content_type = "application/pdf" # type: ignore - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -751,7 
+742,7 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): # Using an opened file inside a FileVar object with open(file_path, "rb") as f: - params = { + query.variable_values = { "file": FileVar( f, content_type="application/pdf", @@ -759,15 +750,13 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): "other_var": 42, } - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success # Using an filename string inside a FileVar object - params = { + query.variable_values = { "file": FileVar( file_path, content_type="application/pdf", @@ -775,9 +764,7 @@ async def test_aiohttp_file_upload_with_content_type(aiohttp_server): "other_var": 42, } - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -815,16 +802,14 @@ async def test_aiohttp_file_upload_default_filename_is_basename(aiohttp_server): query = gql(file_upload_mutation_1) - params = { + query.variable_values = { "file": FileVar( file_path, ), "other_var": 42, } - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -861,7 +846,7 @@ async def test_aiohttp_file_upload_with_filename(aiohttp_server): query = gql(file_upload_mutation_1) - params = { + query.variable_values = { "file": FileVar( file_path, filename="filename1.txt", @@ -869,9 +854,7 @@ async def test_aiohttp_file_upload_with_filename(aiohttp_server): "other_var": 42, } - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -908,9 +891,9 @@ def test_code(): file_path = 
test_file.filename - params = {"file": FileVar(file_path), "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} - result = client.execute(query, variable_values=params, upload_files=True) + result = client.execute(query, upload_files=True) success = result["success"] assert success @@ -952,12 +935,10 @@ async def test_aiohttp_binary_file_upload(aiohttp_server): file_path = test_file.filename - params = {"file": FileVar(file_path), "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] @@ -1003,15 +984,13 @@ async def binary_data_handler(request): query = gql(file_upload_mutation_1) async with ClientSession() as client: async with client.get(binary_data_url) as resp: - params = {"file": resp.content, "other_var": 42} + query.variable_values = {"file": resp.content, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -1021,11 +1000,9 @@ async def binary_data_handler(request): query = gql(file_upload_mutation_1) async with ClientSession() as client: async with client.get(binary_data_url) as resp: - params = {"file": FileVar(resp.content), "other_var": 42} + query.variable_values = {"file": FileVar(resp.content), "other_var": 42} - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -1074,15 +1051,13 @@ async def file_sender(file_name): # Not using FileVar async with Client(transport=transport) as session: 
- params = {"file": file_sender(file_path), "other_var": 42} + query.variable_values = {"file": file_sender(file_path), "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -1090,12 +1065,13 @@ async def file_sender(file_name): # Using FileVar async with Client(transport=transport) as session: - params = {"file": FileVar(file_sender(file_path)), "other_var": 42} + query.variable_values = { + "file": FileVar(file_sender(file_path)), + "other_var": 42, + } # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -1103,15 +1079,13 @@ async def file_sender(file_name): # Using FileVar with new streaming support async with Client(transport=transport) as session: - params = { + query.variable_values = { "file": FileVar(file_path, streaming=True), "other_var": 42, } # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -1171,14 +1145,12 @@ async def test_aiohttp_file_upload_two_files(aiohttp_server): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - params = { + query.variable_values = { "file1": FileVar(file_path_1), "file2": FileVar(file_path_2), } - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] @@ -1241,7 +1213,7 @@ async def test_aiohttp_file_upload_list_of_two_files(aiohttp_server): file_path_1 = test_file_1.filename file_path_2 = 
test_file_2.filename - params = { + query.variable_values = { "files": [ FileVar(file_path_1), FileVar(file_path_2), @@ -1249,9 +1221,7 @@ async def test_aiohttp_file_upload_list_of_two_files(aiohttp_server): } # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] @@ -1829,3 +1799,104 @@ async def handler(request): assert africa["code"] == "AF" await connector.close() + + +@pytest.mark.asyncio +async def test_aiohttp_deprecation_warning_using_document_node_execute(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.warns( + DeprecationWarning, + match="Using a DocumentNode is deprecated", + ): + result = await session.execute(query.document) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +async def test_aiohttp_deprecation_warning_execute_variable_values(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query2_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = gql(query2_str) + + with pytest.warns( + DeprecationWarning, + 
match=( + "Using variable_values and operation_name arguments of " + "execute and subscribe methods is deprecated" + ), + ): + result = await session.execute( + query, + variable_values={"code": "EU"}, + operation_name="getEurope", + ) + + continent = result["continent"] + + assert continent["name"] == "Europe" + + +@pytest.mark.asyncio +async def test_aiohttp_type_error_execute(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query2_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + with pytest.raises(TypeError) as exc_info: + await session.execute("qmlsdkfj") + + assert "request should be a GraphQLRequest object" in str(exc_info.value) diff --git a/tests/test_aiohttp_batch.py b/tests/test_aiohttp_batch.py index e3407a4d..ad9924a0 100644 --- a/tests/test_aiohttp_batch.py +++ b/tests/test_aiohttp_batch.py @@ -70,7 +70,7 @@ async def handler(request): async with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Execute query asynchronously results = await session.execute_batch(query) @@ -286,7 +286,7 @@ def test_code(): client = Client(transport=transport) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] results = client.execute_batch(query) @@ -330,7 +330,7 @@ async def handler(request): async with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportQueryError): await session.execute_batch(query) @@ -368,7 +368,7 @@ async def handler(request): async with Client(transport=transport) 
as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportProtocolError): await session.execute_batch(query) @@ -398,7 +398,7 @@ async def handler(request): transport = AIOHTTPTransport(url=url, timeout=10) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportClosed): await transport.execute_batch(query) @@ -433,7 +433,7 @@ async def handler(request): async with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Passing extra arguments to the post method of aiohttp results = await session.execute_batch( @@ -480,7 +480,7 @@ async def handler(request): transport = AIOHTTPTransport(url=url) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] async with Client(transport=transport) as session: @@ -504,15 +504,13 @@ async def test_aiohttp_batch_online_manual(): transport=AIOHTTPTransport(url=ONLINE_URL, timeout=10), ) - query = gql( - """ + query = """ query getContinentName($continent_code: ID!) 
{ continent(code: $continent_code) { name } } - """ - ) + """ async with client as session: diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 22dd1004..e03ad8f9 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -425,10 +425,9 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 32daf038..f06046df 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -444,10 +444,9 @@ async def test_aiohttp_websocket_subscription_with_operation_name( count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -751,15 +750,13 @@ async def test_async_aiohttp_client_validation(server, subscription_str, client_ async with client as session: - variable_values = {"ep": "JEDI"} - subscription = gql(subscription_str) + subscription.variable_values = {"ep": "JEDI"} + expected = [] - async for result in session.subscribe( - subscription, variable_values=variable_values, parse_result=False - ): + async for result in session.subscribe(subscription, parse_result=False): review = result["reviewAdded"] expected.append(review) diff 
--git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 0be04034..b2299960 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -514,10 +514,10 @@ async def test_appsync_execute_method_not_allowed(server): }""" ) - variable_values = {"message": "Hello world!"} + query.variable_values = {"message": "Hello world!"} with pytest.raises(AssertionError) as exc_info: - await session.execute(query, variable_values=variable_values) + await session.execute(query) assert ( "execute method is not allowed for AppSyncWebsocketsTransport " @@ -693,10 +693,11 @@ async def test_appsync_subscription_variable_values_and_operation_name(server): async with client as session: subscription = gql(on_create_message_subscription_str) + subscription.variable_values = {"key1": "val1"} + subscription.operation_name = "onCreateMessage" + async for execution_result in session.subscribe( subscription, - operation_name="onCreateMessage", - variable_values={"key1": "val1"}, get_execution_result=True, ): diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index c256e5dd..ec73593e 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -97,15 +97,13 @@ async def test_async_client_validation(server, subscription_str, client_params): async with client as session: - variable_values = {"ep": "JEDI"} - subscription = gql(subscription_str) + subscription.variable_values = {"ep": "JEDI"} + expected = [] - async for result in session.subscribe( - subscription, variable_values=variable_values, parse_result=False - ): + async for result in session.subscribe(subscription, parse_result=False): review = result["reviewAdded"] expected.append(review) @@ -144,14 +142,12 @@ async def test_async_client_validation_invalid_query( async with client as session: - variable_values = {"ep": "JEDI"} - subscription = gql(subscription_str) + subscription.variable_values = {"ep": "JEDI"} + 
with pytest.raises(graphql.error.GraphQLError): - async for _result in session.subscribe( - subscription, variable_values=variable_values - ): + async for _result in session.subscribe(subscription): pass diff --git a/tests/test_client.py b/tests/test_client.py index 3412059e..4e2e9bca 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -94,7 +94,7 @@ def test_retries_on_transport(execute_mock): assert execute_mock.call_count == expected_retries + 1 execute_mock.reset_mock() - queries = map(lambda d: GraphQLRequest(document=d), [query, query, query]) + queries = [query, query, query] with client as session: # We're using the client as context manager with pytest.raises(Exception): @@ -143,7 +143,7 @@ def test_execute_result_error(): Batching is not supported anymore on countries backend with pytest.raises(TransportQueryError) as exc_info: - client.execute_batch([GraphQLRequest(document=failing_query)]) + client.execute_batch([GraphQLRequest(failing_query)]) assert 'Cannot query field "id" on type "Continent".' 
in str(exc_info.value) """ @@ -171,7 +171,7 @@ def test_http_transport_verify_error(http_transport_query): Batching is not supported anymore on countries backend with pytest.warns(Warning) as record: - client.execute_batch([GraphQLRequest(document=http_transport_query)]) + client.execute_batch([GraphQLRequest(http_transport_query)]) assert len(record) == 1 assert "Unverified HTTPS request is being made to host" in str( @@ -197,7 +197,7 @@ def test_http_transport_specify_method_valid(http_transport_query): """ Batching is not supported anymore on countries backend - result = client.execute_batch([GraphQLRequest(document=http_transport_query)]) + result = client.execute_batch([GraphQLRequest(http_transport_query)]) assert result is not None """ diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py index 346dc00e..ea255c7d 100644 --- a/tests/test_graphql_request.py +++ b/tests/test_graphql_request.py @@ -18,7 +18,7 @@ ) from graphql.utilities import value_from_ast_untyped -from gql import GraphQLRequest, gql +from gql import GraphQLRequest from .conftest import MS, strip_braces_spaces @@ -188,12 +188,12 @@ async def subscribe_spend_all(_root, _info, money): def test_serialize_variables_using_money_example(): - req = GraphQLRequest(document=gql("{balance}")) + req = GraphQLRequest("{balance}") money_value = Money(10, "DM") req = GraphQLRequest( - document=gql("query myquery($money: Money) {toEuros(money: $money)}"), + "query myquery($money: Money) {toEuros(money: $money)}", variable_values={"money": money_value}, ) @@ -210,3 +210,29 @@ def test_graphql_request_using_string_instead_of_document(): print(request) assert str(request) == strip_braces_spaces(expected_payload) + + +def test_graphql_request_init_with_graphql_request(): + money_value_1 = Money(10, "DM") + money_value_2 = Money(20, "DM") + + request_1 = GraphQLRequest( + "query myquery($money: Money) {toEuros(money: $money)}", + variable_values={"money": money_value_1}, + ) + request_2 = 
GraphQLRequest( + request_1, + ) + request_3 = GraphQLRequest( + request_1, + variable_values={"money": money_value_2}, + ) + + assert request_1.document == request_2.document + assert request_2.document == request_3.document + assert isinstance(request_1.variable_values, Dict) + assert isinstance(request_2.variable_values, Dict) + assert isinstance(request_3.variable_values, Dict) + assert request_1.variable_values["money"] == money_value_1 + assert request_2.variable_values["money"] == money_value_1 + assert request_3.variable_values["money"] == money_value_2 diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 45e7aba4..416726aa 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -425,10 +425,9 @@ async def test_graphqlws_subscription_with_operation_name( count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") diff --git a/tests/test_httpx.py b/tests/test_httpx.py index b944391f..3a424355 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -3,7 +3,7 @@ import pytest -from gql import Client, FileVar, GraphQLRequest, gql +from gql import Client, FileVar, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -470,7 +470,7 @@ def test_code(): query = gql(query1_str) with pytest.raises(TransportClosed): - transport.execute(GraphQLRequest(query)) + transport.execute(query) await run_sync_test(server, test_code) @@ -573,35 +573,29 @@ def test_code(): # Using an opened file with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using 
FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: - params = {"file": FileVar(f), "other_var": 42} - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = {"file": FileVar(f), "other_var": 42} + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] # Using an filename string inside a FileVar object - params = { + query.variable_values = { "file": FileVar(file_path), "other_var": 42, } - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -645,25 +639,21 @@ def test_code(): # Setting the content_type f.content_type = "application/pdf" # type: ignore - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] # Using FileVar - params = { + query.variable_values = { "file": FileVar(file_path, content_type="application/pdf"), "other_var": 42, } - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -706,13 +696,11 @@ def test_code(): query = gql(file_upload_mutation_1) # Using FileVar - params = { + query.variable_values = { "file": FileVar(file_path), "other_var": 42, } - execution_result = session.execute( - 
query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -750,10 +738,8 @@ def test_code(): file_path = test_file.filename - params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -795,11 +781,9 @@ def test_code(): file_path = test_file.filename - params = {"file": FileVar(file_path), "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -862,14 +846,12 @@ def test_code(): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - params = { + query.variable_values = { "file1": FileVar(file_path_1), "file2": FileVar(file_path_2), } - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -933,16 +915,14 @@ def test_code(): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - params = { + query.variable_values = { "files": [ FileVar(file_path_1), FileVar(file_path_2), ], } - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 56c65873..25fd27aa 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, FileVar, GraphQLRequest, gql +from gql import 
Client, FileVar, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -429,7 +429,7 @@ async def handler(request): query = gql(query1_str) with pytest.raises(TransportClosed): - await transport.execute(GraphQLRequest(query)) + await transport.execute(query) @pytest.mark.aiohttp @@ -498,14 +498,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = await session.execute( - query, variable_values=params, operation_name="getEurope" - ) + result = await session.execute(query) continent = result["continent"] @@ -536,14 +535,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = await session.execute( - query, variable_values=params, operation_name="getEurope" - ) + result = await session.execute(query) continent = result["continent"] @@ -671,16 +669,14 @@ async def test_httpx_file_upload(aiohttp_server): # Using an opened file with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} # Execute query asynchronously with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -688,23 +684,19 @@ async def test_httpx_file_upload(aiohttp_server): # Using an opened file inside a FileVar object with open(file_path, "rb") as f: - params = {"file": FileVar(f), "other_var": 42} + query.variable_values = {"file": FileVar(f), 
"other_var": 42} # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success # Using an filename string inside a FileVar object - params = {"file": FileVar(file_path), "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -742,9 +734,9 @@ def test_code(): file_path = test_file.filename - params = {"file": FileVar(file_path), "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} - result = client.execute(query, variable_values=params, upload_files=True) + result = client.execute(query, upload_files=True) success = result["success"] @@ -788,12 +780,10 @@ async def test_httpx_binary_file_upload(aiohttp_server): file_path = test_file.filename - params = {"file": FileVar(file_path), "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] @@ -855,14 +845,12 @@ async def test_httpx_file_upload_two_files(aiohttp_server): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - params = { + query.variable_values = { "file1": FileVar(file_path_1), "file2": FileVar(file_path_2), } - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success @@ -925,7 +913,7 @@ async def test_httpx_file_upload_list_of_two_files(aiohttp_server): file_path_1 = 
test_file_1.filename file_path_2 = test_file_2.filename - params = { + query.variable_values = { "files": [ FileVar(file_path_1), FileVar(file_path_2), @@ -933,9 +921,7 @@ async def test_httpx_file_upload_list_of_two_files(aiohttp_server): } # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] assert success diff --git a/tests/test_httpx_batch.py b/tests/test_httpx_batch.py index 9e5b9b93..63472dab 100644 --- a/tests/test_httpx_batch.py +++ b/tests/test_httpx_batch.py @@ -2,7 +2,7 @@ import pytest -from gql import Client, GraphQLRequest, gql +from gql import Client, GraphQLRequest from gql.transport.exceptions import ( TransportClosed, TransportProtocolError, @@ -54,7 +54,7 @@ async def handler(request): async with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Execute query asynchronously results = await session.execute_batch(query) @@ -98,7 +98,7 @@ async def handler(request): def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] results = session.execute_batch(query) @@ -143,7 +143,7 @@ def test_code(): client = Client(transport=transport) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] results = client.execute_batch(query) @@ -188,7 +188,7 @@ async def handler(request): async with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportQueryError): await session.execute_batch(query) @@ -227,7 +227,7 @@ async def handler(request): async with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with 
pytest.raises(TransportProtocolError): await session.execute_batch(query) @@ -255,7 +255,7 @@ async def handler(request): transport = HTTPXAsyncTransport(url=url, timeout=10) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportClosed): await transport.execute_batch(query) @@ -283,7 +283,7 @@ async def handler(request): transport = HTTPXTransport(url=url, timeout=10) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportClosed): transport.execute_batch(query) @@ -316,7 +316,7 @@ async def handler(request): async with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Passing extra arguments to the post method results = await session.execute_batch( @@ -364,7 +364,7 @@ async def handler(request): transport = HTTPXAsyncTransport(url=url) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] async with Client(transport=transport) as session: @@ -388,15 +388,13 @@ async def test_httpx_batch_online_async_manual(): transport=HTTPXAsyncTransport(url=ONLINE_URL), ) - query = gql( - """ + query = """ query getContinentName($continent_code: ID!) { continent(code: $continent_code) { name } } - """ - ) + """ async with client as session: @@ -419,15 +417,13 @@ async def test_httpx_batch_online_sync_manual(): transport=HTTPXTransport(url=ONLINE_URL), ) - query = gql( - """ + query = """ query getContinentName($continent_code: ID!) 
{ continent(code: $continent_code) { name } } - """ - ) + """ with client as session: diff --git a/tests/test_requests.py b/tests/test_requests.py index ff6a5651..45901875 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, FileVar, GraphQLRequest, gql +from gql import Client, FileVar, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -471,7 +471,7 @@ def test_code(): query = gql(query1_str) with pytest.raises(TransportClosed): - transport.execute(GraphQLRequest(query)) + transport.execute(query) await run_sync_test(server, test_code) @@ -574,35 +574,29 @@ def test_code(): # Using an opened file with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: - params = {"file": FileVar(f), "other_var": 42} + query.variable_values = {"file": FileVar(f), "other_var": 42} with warnings.catch_warnings(): warnings.simplefilter("error") # Turn warnings into errors - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] # Using an filename string inside a FileVar object - params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -646,27 +640,23 @@ 
def test_code(): # Setting the content_type f.content_type = "application/pdf" # type: ignore - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: - params = { + query.variable_values = { "file": FileVar(f, content_type="application/pdf"), "other_var": 42, } - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -709,13 +699,11 @@ def test_code(): with Client(transport=transport) as session: query = gql(file_upload_mutation_1) - params = { + query.variable_values = { "file": FileVar(file_path), "other_var": 42, } - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -756,13 +744,11 @@ def test_code(): with open(file_path, "rb") as f: - params = { + query.variable_values = { "file": FileVar(f, filename="filename1.txt"), "other_var": 42, } - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -802,14 +788,12 @@ def test_code(): with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + 
execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -853,15 +837,13 @@ def test_code(): with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -928,7 +910,7 @@ def test_code(): f1 = open(file_path_1, "rb") f2 = open(file_path_2, "rb") - params_1 = { + query.variable_values = { "file1": f1, "file2": f2, } @@ -937,9 +919,7 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params_1, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -953,14 +933,12 @@ def test_code(): f1 = open(file_path_1, "rb") f2 = open(file_path_2, "rb") - params_2 = { + query.variable_values = { "file1": FileVar(f1), "file2": FileVar(f2), } - execution_result = session.execute( - query, variable_values=params_2, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -1031,15 +1009,13 @@ def test_code(): f1 = open(file_path_1, "rb") f2 = open(file_path_2, "rb") - params = {"files": [f1, f2]} + query.variable_values = {"files": [f1, f2]} with pytest.warns( DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session.execute( - query, variable_values=params, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] @@ -1053,11 +1029,9 @@ def test_code(): f1 = open(file_path_1, "rb") f2 = open(file_path_2, "rb") - params_2 = {"files": 
[FileVar(f1), FileVar(f2)]} + query.variable_values = {"files": [FileVar(f1), FileVar(f2)]} - execution_result = session.execute( - query, variable_values=params_2, upload_files=True - ) + execution_result = session.execute(query, upload_files=True) assert execution_result["success"] diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 38850d56..a2f0cdbf 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -71,7 +71,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Execute query synchronously results = session.execute_batch(query) @@ -225,7 +225,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Execute query synchronously results = session.execute_batch(query) @@ -265,7 +265,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportServerError) as exc_info: session.execute_batch(query) @@ -353,7 +353,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportServerError) as exc_info: session.execute_batch(query) @@ -388,7 +388,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportServerError): session.execute_batch(query) @@ -422,7 +422,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportQueryError): session.execute_batch(query) @@ -464,7 +464,7 @@ def test_code(): with 
Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportProtocolError): session.execute_batch(query) @@ -493,7 +493,7 @@ async def handler(request): def test_code(): transport = RequestsHTTPTransport(url=url) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportClosed): transport.execute_batch(query) @@ -536,7 +536,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] execution_results = session.execute_batch(query, get_execution_result=True) @@ -626,15 +626,13 @@ def test_requests_sync_batch_auto_execute_future(): batch_max=3, ) - query = gql( - """ + query = """ query getContinentName($continent_code: ID!) { continent(code: $continent_code) { name } } - """ - ) + """ with client as session: @@ -661,15 +659,13 @@ def test_requests_sync_batch_manual(): transport=RequestsHTTPTransport(url=ONLINE_URL), ) - query = gql( - """ + query = """ query getContinentName($continent_code: ID!) 
{ continent(code: $continent_code) { name } } - """ - ) + """ with client as session: diff --git a/tests/test_transport.py b/tests/test_transport.py index 87b31eb1..7c2a5a8f 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -100,9 +100,10 @@ def test_query_with_variable(client): } """ ) + query.variable_values = {"id": "UGxhbmV0OjEw"} expected = {"planet": {"id": "UGxhbmV0OjEw", "name": "Kamino"}} with use_cassette("queries"): - result = client.execute(query, variable_values={"id": "UGxhbmV0OjEw"}) + result = client.execute(query) assert result == expected @@ -123,9 +124,10 @@ def test_named_query(client): } """ ) + query.operation_name = "Planet2" expected = {"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}} with use_cassette("queries"): - result = client.execute(query, operation_name="Planet2") + result = client.execute(query) assert result == expected diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index 0b2a3158..671858e7 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -2,7 +2,7 @@ import pytest -from gql import Client, GraphQLRequest, gql +from gql import Client, gql # We serve https://github.com/graphql-python/swapi-graphene locally: URL = "http://127.0.0.1:8000/graphql" @@ -87,7 +87,7 @@ def test_hero_name_query(client): } ] with use_cassette("queries_batch"): - results = client.execute_batch([GraphQLRequest(document=query)]) + results = client.execute_batch([query]) assert results == expected @@ -102,11 +102,10 @@ def test_query_with_variable(client): } """ ) + query.variable_values = {"id": "UGxhbmV0OjEw"} expected = [{"planet": {"id": "UGxhbmV0OjEw", "name": "Kamino"}}] with use_cassette("queries_batch"): - results = client.execute_batch( - [GraphQLRequest(document=query, variable_values={"id": "UGxhbmV0OjEw"})] - ) + results = client.execute_batch([query]) assert results == expected @@ 
-127,11 +126,10 @@ def test_named_query(client): } """ ) + query.operation_name = "Planet2" expected = [{"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}}] with use_cassette("queries_batch"): - results = client.execute_batch( - [GraphQLRequest(document=query, operation_name="Planet2")] - ) + results = client.execute_batch([query]) assert results == expected @@ -149,7 +147,7 @@ def test_header_query(client): expected = [{"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}}] with use_cassette("queries_batch"): results = client.execute_batch( - [GraphQLRequest(document=query)], + [query], extra_args={"headers": {"authorization": "xxx-123"}}, ) assert results == expected diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 487b9ba5..5baa0b4e 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -370,10 +370,9 @@ async def test_websocket_subscription_with_operation_name( count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py index f0448c79..31422487 100644 --- a/tests/test_websockets_adapter.py +++ b/tests/test_websockets_adapter.py @@ -39,7 +39,7 @@ async def test_websockets_adapter_simple_query(server): url = f"ws://{server.hostname}:{server.port}/graphql" - query = print_ast(gql(query1_str)) + query = print_ast(gql(query1_str).document) print("query=", query) adapter = WebSocketsAdapter(url) @@ -71,7 +71,7 @@ async def test_websockets_adapter_edge_cases(server): url = f"ws://{server.hostname}:{server.port}/graphql" - query = print_ast(gql(query1_str)) + query = print_ast(gql(query1_str).document) print("query=", query) 
adapter = WebSocketsAdapter(url, headers={"a": "r1"}, ssl=False, connect_args={}) From 4fa1553fd5ba7c99b7176fe9eb6a18d8cc2cc5d4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 28 May 2025 14:14:42 +0000 Subject: [PATCH 222/239] Refactor transports (#557) * Refactor requests transport to be similar to aiohttp and httpx * Removing _prepare_batch_request to avoid code duplication --- gql/transport/aiohttp.py | 46 +++--- gql/transport/httpx.py | 55 +++----- gql/transport/requests.py | 287 ++++++++++++++++++++------------------ 3 files changed, 188 insertions(+), 200 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 61d01fb4..40e212cf 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -127,7 +127,7 @@ async def connect(self) -> None: # Adding custom parameters passed from init if self.client_session_args: - client_session_args.update(self.client_session_args) # type: ignore + client_session_args.update(self.client_session_args) log.debug("Connecting transport") @@ -164,36 +164,22 @@ async def close(self) -> None: self.session = None - def _prepare_batch_request( - self, - reqs: List[GraphQLRequest], - extra_args: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - - payload = [req.payload for req in reqs] - - post_args = {"json": payload} - - # Log the payload - if log.isEnabledFor(logging.DEBUG): - log.debug(">>> %s", self.json_serialize(post_args["json"])) - - # Pass post_args to aiohttp post method - if extra_args: - post_args.update(extra_args) - - return post_args - def _prepare_request( self, - request: GraphQLRequest, + request: Union[GraphQLRequest, List[GraphQLRequest]], extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> Dict[str, Any]: - payload = request.payload + payload: Dict | List + if isinstance(request, GraphQLRequest): + payload = request.payload + else: + payload = [req.payload for req in request] if upload_files: + assert isinstance(payload, Dict) + assert 
isinstance(request, GraphQLRequest) post_args = self._prepare_file_uploads(request, payload) else: post_args = {"json": payload} @@ -379,15 +365,15 @@ async def execute( :returns: an ExecutionResult object. """ + if self.session is None: + raise TransportClosed("Transport is not connected") + post_args = self._prepare_request( request, extra_args, upload_files, ) - if self.session is None: - raise TransportClosed("Transport is not connected") - try: async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: return await self._prepare_result(resp) @@ -413,14 +399,14 @@ async def execute_batch( if an error occurred. """ - post_args = self._prepare_batch_request( + if self.session is None: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_request( reqs, extra_args, ) - if self.session is None: - raise TransportClosed("Transport is not connected") - async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: return await self._prepare_batch_result(reqs, resp) diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index afb1360c..7fe2a7db 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -59,15 +59,22 @@ def __init__( def _prepare_request( self, - req: GraphQLRequest, + request: Union[GraphQLRequest, List[GraphQLRequest]], + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> Dict[str, Any]: - payload = req.payload + payload: Dict | List + if isinstance(request, GraphQLRequest): + payload = request.payload + else: + payload = [req.payload for req in request] if upload_files: - post_args = self._prepare_file_uploads(req, payload) + assert isinstance(payload, Dict) + assert isinstance(request, GraphQLRequest) + post_args = self._prepare_file_uploads(request, payload) else: post_args = {"json": payload} @@ -81,26 +88,6 @@ def _prepare_request( return post_args - def _prepare_batch_request( - self, - reqs: List[GraphQLRequest], - extra_args: 
Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - - payload = [req.payload for req in reqs] - - post_args = {"json": payload} - - # Log the payload - if log.isEnabledFor(logging.DEBUG): - log.debug(">>> %s", self.json_serialize(payload)) - - # Pass post_args to aiohttp post method - if extra_args: - post_args.update(extra_args) - - return post_args - def _prepare_file_uploads( self, request: GraphQLRequest, @@ -244,7 +231,7 @@ def connect(self): self.client = httpx.Client(**self.kwargs) - def execute( # type: ignore + def execute( self, request: GraphQLRequest, *, @@ -269,8 +256,8 @@ def execute( # type: ignore post_args = self._prepare_request( request, - extra_args, - upload_files, + extra_args=extra_args, + upload_files=upload_files, ) try: @@ -292,7 +279,7 @@ def execute_batch( :code:`execute_batch` on a client or a session. :param reqs: GraphQL requests as a list of GraphQLRequest objects. - :param extra_args: additional arguments to send to the aiohttp post method + :param extra_args: additional arguments to send to the httpx post method :return: A list of results of execution. For every result `data` is the result of executing the query, `errors` is null if no errors occurred, and is a non-empty array @@ -302,9 +289,9 @@ def execute_batch( if not self.client: raise TransportClosed("Transport is not connected") - post_args = self._prepare_batch_request( + post_args = self._prepare_request( reqs, - extra_args, + extra_args=extra_args, ) response = self.client.post(self.url, **post_args) @@ -361,8 +348,8 @@ async def execute( post_args = self._prepare_request( request, - extra_args, - upload_files, + extra_args=extra_args, + upload_files=upload_files, ) try: @@ -384,7 +371,7 @@ async def execute_batch( :code:`execute_batch` on a client or a session. :param reqs: GraphQL requests as a list of GraphQLRequest objects. 
- :param extra_args: additional arguments to send to the aiohttp post method + :param extra_args: additional arguments to send to the httpx post method :return: A list of results of execution. For every result `data` is the result of executing the query, `errors` is null if no errors occurred, and is a non-empty array @@ -394,9 +381,9 @@ async def execute_batch( if not self.client: raise TransportClosed("Transport is not connected") - post_args = self._prepare_batch_request( + post_args = self._prepare_request( reqs, - extra_args, + extra_args=extra_args, ) response = await self.client.post(self.url, **post_args) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 16d07025..17bf4695 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -137,32 +137,20 @@ def connect(self): else: raise TransportAlreadyConnected("Transport is already connected") - def execute( # type: ignore + def _prepare_request( self, - request: GraphQLRequest, + request: Union[GraphQLRequest, List[GraphQLRequest]], + *, timeout: Optional[int] = None, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, - ) -> ExecutionResult: - """Execute GraphQL query. - - Execute the provided request against the configured remote server. This - uses the requests library to perform a HTTP POST request to the remote server. - - :param request: GraphQL request as a - :class:`GraphQLRequest ` object. - :param timeout: Specifies a default timeout for requests (Default: None). - :param extra_args: additional arguments to send to the requests post method - :param upload_files: Set to True if you want to put files in the variable values - :return: The result of execution. - `data` is the result of executing the query, `errors` is null - if no errors occurred, and is a non-empty array if an error occurred. 
- """ - - if not self.session: - raise TransportClosed("Transport is not connected") + ) -> Dict[str, Any]: - payload = request.payload + payload: Dict | List + if isinstance(request, GraphQLRequest): + payload = request.payload + else: + payload = [req.payload for req in request] post_args: Dict[str, Any] = { "headers": self.headers, @@ -173,111 +161,139 @@ def execute( # type: ignore } if upload_files: - # If the upload_files flag is set, then we need variable_values - assert request.variable_values is not None - - # If we upload files, we will extract the files present in the - # variable_values dict and replace them by null values - nulled_variable_values, files = extract_files( - variables=request.variable_values, - file_classes=self.file_classes, + assert isinstance(payload, Dict) + assert isinstance(request, GraphQLRequest) + post_args = self._prepare_file_uploads( + request=request, + payload=payload, + post_args=post_args, ) - # Opening the files using the FileVar parameters - open_files(list(files.values())) - self.files = files + else: + data_key = "json" if self.use_json else "data" + post_args[data_key] = payload - # Save the nulled variable values in the payload - payload["variables"] = nulled_variable_values + # Log the payload + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(payload)) - # Add the payload to the operations field - operations_str = self.json_serialize(payload) - log.debug("operations %s", operations_str) + # Pass kwargs to requests post method + post_args.update(self.kwargs) - # Generate the file map - # path is nested in a list because the spec allows multiple pointers - # to the same file. But we don't support that. 
- # Will generate something like {"0": ["variables.file"]} - file_map = {str(i): [path] for i, path in enumerate(files)} + # Pass post_args to requests post method + if extra_args: + post_args.update(extra_args) + + return post_args - # Enumerate the file streams - # Will generate something like {'0': FileVar object} - file_vars = {str(i): files[path] for i, path in enumerate(files)} + def _prepare_file_uploads( + self, + request: GraphQLRequest, + *, + payload: Dict[str, Any], + post_args: Dict[str, Any], + ) -> Dict[str, Any]: + # If the upload_files flag is set, then we need variable_values + assert request.variable_values is not None + + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values + nulled_variable_values, files = extract_files( + variables=request.variable_values, + file_classes=self.file_classes, + ) - # Add the file map field - file_map_str = self.json_serialize(file_map) - log.debug("file_map %s", file_map_str) + # Opening the files using the FileVar parameters + open_files(list(files.values())) + self.files = files - fields = {"operations": operations_str, "map": file_map_str} + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values - # Add the extracted files as remaining fields - for k, file_var in file_vars.items(): - assert isinstance(file_var, FileVar) - name = k if file_var.filename is None else file_var.filename + # Add the payload to the operations field + operations_str = self.json_serialize(payload) + log.debug("operations %s", operations_str) - if file_var.content_type is None: - fields[k] = (name, file_var.f) - else: - fields[k] = (name, file_var.f, file_var.content_type) + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. 
+ # Will generate something like {"0": ["variables.file"]} + file_map = {str(i): [path] for i, path in enumerate(files)} - # Prepare requests http to send multipart-encoded data - data = MultipartEncoder(fields=fields) + # Enumerate the file streams + # Will generate something like {'0': FileVar object} + file_vars = {str(i): files[path] for i, path in enumerate(files)} - post_args["data"] = data + # Add the file map field + file_map_str = self.json_serialize(file_map) + log.debug("file_map %s", file_map_str) - if post_args["headers"] is None: - post_args["headers"] = {} + fields = {"operations": operations_str, "map": file_map_str} + + # Add the extracted files as remaining fields + for k, file_var in file_vars.items(): + assert isinstance(file_var, FileVar) + name = k if file_var.filename is None else file_var.filename + + if file_var.content_type is None: + fields[k] = (name, file_var.f) else: - post_args["headers"] = dict(post_args["headers"]) + fields[k] = (name, file_var.f, file_var.content_type) + + # Prepare requests http to send multipart-encoded data + data = MultipartEncoder(fields=fields) - post_args["headers"]["Content-Type"] = data.content_type + post_args["data"] = data + if post_args["headers"] is None: + post_args["headers"] = {} else: - data_key = "json" if self.use_json else "data" - post_args[data_key] = payload + post_args["headers"] = dict(post_args["headers"]) - # Log the payload - if log.isEnabledFor(logging.DEBUG): - log.debug(">>> %s", self.json_serialize(payload)) + post_args["headers"]["Content-Type"] = data.content_type - # Pass kwargs to requests post method - post_args.update(self.kwargs) + return post_args - # Pass post_args to requests post method - if extra_args: - post_args.update(extra_args) + def execute( + self, + request: GraphQLRequest, + timeout: Optional[int] = None, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> ExecutionResult: + """Execute GraphQL query. 
+ + Execute the provided request against the configured remote server. This + uses the requests library to perform a HTTP POST request to the remote server. + + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. + :param timeout: Specifies a default timeout for requests (Default: None). + :param extra_args: additional arguments to send to the requests post method + :param upload_files: Set to True if you want to put files in the variable values + :return: The result of execution. + `data` is the result of executing the query, `errors` is null + if no errors occurred, and is a non-empty array if an error occurred. + """ + + if not self.session: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_request( + request, + timeout=timeout, + extra_args=extra_args, + upload_files=upload_files, + ) # Using the created session to perform requests try: - response = self.session.request( - self.method, self.url, **post_args # type: ignore - ) + response = self.session.request(self.method, self.url, **post_args) finally: if upload_files: close_files(list(self.files.values())) - self.response_headers = response.headers - - try: - if self.json_deserialize == json.loads: - result = response.json() - else: - result = self.json_deserialize(response.text) - - if log.isEnabledFor(logging.DEBUG): - log.debug("<<< %s", response.text) - - except Exception: - self._raise_response_error(response, "Not a JSON answer") - - if "errors" not in result and "data" not in result: - self._raise_response_error(response, 'No "data" or "errors" keys in answer') - - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) + return self._prepare_result(response) @staticmethod def _raise_transport_server_error_if_status_more_than_400( @@ -327,27 +343,27 @@ def execute_batch( if not self.session: raise TransportClosed("Transport is not connected") - # Using the created session to 
perform requests + post_args = self._prepare_request( + reqs, + timeout=timeout, + extra_args=extra_args, + ) + response = self.session.request( self.method, self.url, - **self._build_batch_post_args(reqs, timeout, extra_args), + **post_args, ) - self.response_headers = response.headers - answers = self._extract_response(response) + return self._prepare_batch_result(reqs, response) - try: - return get_batch_execution_result_list(reqs, answers) - except TransportProtocolError: - # Raise a TransportServerError if status > 400 - self._raise_transport_server_error_if_status_more_than_400(response) - # In other cases, raise a TransportProtocolError - raise + def _get_json_result(self, response: requests.Response) -> Any: + + # Saving latest response headers in the transport + self.response_headers = response.headers - def _extract_response(self, response: requests.Response) -> Any: try: - result = response.json() + result = self.json_deserialize(response.text) if log.isEnabledFor(logging.DEBUG): log.debug("<<< %s", response.text) @@ -357,35 +373,34 @@ def _extract_response(self, response: requests.Response) -> Any: return result - def _build_batch_post_args( - self, - reqs: List[GraphQLRequest], - timeout: Optional[int] = None, - extra_args: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - post_args: Dict[str, Any] = { - "headers": self.headers, - "auth": self.auth, - "cookies": self.cookies, - "timeout": timeout or self.default_timeout, - "verify": self.verify, - } + def _prepare_result(self, response: requests.Response) -> ExecutionResult: - data_key = "json" if self.use_json else "data" - post_args[data_key] = [req.payload for req in reqs] + result = self._get_json_result(response) - # Log the payload - if log.isEnabledFor(logging.DEBUG): - log.debug(">>> %s", self.json_serialize(post_args[data_key])) + if "errors" not in result and "data" not in result: + self._raise_response_error(response, 'No "data" or "errors" keys in answer') - # Pass kwargs to 
requests post method - post_args.update(self.kwargs) + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) - # Pass post_args to requests post method - if extra_args: - post_args.update(extra_args) + def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: requests.Response, + ) -> List[ExecutionResult]: - return post_args + answers = self._get_json_result(response) + + try: + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise def close(self): """Closing the transport by closing the inner session""" From 4af703e6a0c0bbb824012a5d4c75978f53a8e25a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 28 May 2025 15:21:57 +0000 Subject: [PATCH 223/239] Trapping dependencies Exceptions into TransportConnectionFailed (#558) --- docs/advanced/error_handling.rst | 5 +++++ gql/gql.py | 9 ++++----- gql/graphql_request.py | 11 ++++------- gql/transport/aiohttp.py | 15 +++++++++++++-- gql/transport/exceptions.py | 5 +++-- gql/transport/httpx.py | 15 +++++++++++++-- gql/transport/requests.py | 16 +++++++++++----- tests/test_aiohttp.py | 18 ++++++++++++------ tests/test_httpx.py | 16 ++++++++++------ tests/test_httpx_async.py | 16 ++++++++++------ tests/test_requests.py | 4 ++-- 11 files changed, 87 insertions(+), 43 deletions(-) diff --git a/docs/advanced/error_handling.rst b/docs/advanced/error_handling.rst index 4e6618c9..458f2667 100644 --- a/docs/advanced/error_handling.rst +++ b/docs/advanced/error_handling.rst @@ -46,6 +46,11 @@ Here are the possible Transport Errors: If you don't need the schema, you can try to create the client with :code:`fetch_schema_from_transport=False` +- :class:`TransportConnectionFailed `: + This exception is generated when an unexpected 
Exception is received from the + transport dependency when trying to connect or to send the request. + For example in case of an SSL error, or if a websocket connection suddenly fails. + - :class:`TransportClosed `: This exception is generated when the client is trying to use the transport while the transport was previously closed. diff --git a/gql/gql.py b/gql/gql.py index f4cd3aea..8a5a1b32 100644 --- a/gql/gql.py +++ b/gql/gql.py @@ -3,15 +3,14 @@ def gql(request_string: str) -> GraphQLRequest: """Given a string containing a GraphQL request, - parse it into a Document and put it into a GraphQLRequest object + parse it into a Document and put it into a GraphQLRequest object. :param request_string: the GraphQL request as a String :return: a :class:`GraphQLRequest ` which can be later executed or subscribed by a - :class:`Client `, by an - :class:`async session ` or by a - :class:`sync session ` - + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` :raises graphql.error.GraphQLError: if a syntax error is encountered. """ return GraphQLRequest(request_string) diff --git a/gql/graphql_request.py b/gql/graphql_request.py index fe3523a9..5e6f3ee4 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -14,22 +14,19 @@ def __init__( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ): - """ - Initialize a GraphQL request. + """Initialize a GraphQL request. :param request: GraphQL request as DocumentNode object or as a string. If string, it will be converted to DocumentNode. :param variable_values: Dictionary of input parameters (Default: None). :param operation_name: Name of the operation that shall be executed. Only required in multi-operation documents (Default: None). 
- :return: a :class:`GraphQLRequest ` which can be later executed or subscribed by a - :class:`Client `, by an - :class:`async session ` or by a - :class:`sync session ` + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` :raises graphql.error.GraphQLError: if a syntax error is encountered. - """ if isinstance(request, str): source = Source(request, "GraphQL request") diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 40e212cf..e3bfdb3b 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -31,6 +31,8 @@ from .exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, + TransportError, TransportProtocolError, TransportServerError, ) @@ -377,6 +379,10 @@ async def execute( try: async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: return await self._prepare_result(resp) + except TransportError: + raise + except Exception as e: + raise TransportConnectionFailed(str(e)) from e finally: if upload_files: close_files(list(self.files.values())) @@ -407,8 +413,13 @@ async def execute_batch( extra_args, ) - async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: - return await self._prepare_batch_result(reqs, resp) + try: + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + return await self._prepare_batch_result(reqs, resp) + except TransportError: + raise + except Exception as e: + raise TransportConnectionFailed(str(e)) from e def subscribe( self, diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 3e63f0bc..0049d5c2 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -62,9 +62,10 @@ class TransportClosed(TransportError): class TransportConnectionFailed(TransportError): - """Transport adapter connection closed. + """Transport connection failed. - This exception is by the connection adapter code when a connection closed. 
+ This exception is by the connection adapter code when a connection closed + or if an unexpected Exception was received when trying to send a request. """ diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 7fe2a7db..0a338639 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -23,6 +23,7 @@ from .exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportServerError, ) @@ -262,6 +263,8 @@ def execute( try: response = self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e finally: if upload_files: close_files(list(self.files.values())) @@ -294,7 +297,10 @@ def execute_batch( extra_args=extra_args, ) - response = self.client.post(self.url, **post_args) + try: + response = self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e return self._prepare_batch_result(reqs, response) @@ -354,6 +360,8 @@ async def execute( try: response = await self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e finally: if upload_files: close_files(list(self.files.values())) @@ -386,7 +394,10 @@ async def execute_batch( extra_args=extra_args, ) - response = await self.client.post(self.url, **post_args) + try: + response = await self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e return self._prepare_batch_result(reqs, response) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 17bf4695..a29f7f0f 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -29,6 +29,7 @@ from .exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportServerError, ) @@ -289,6 +290,8 @@ def execute( # Using the created session to perform requests try: response = 
self.session.request(self.method, self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e finally: if upload_files: close_files(list(self.files.values())) @@ -349,11 +352,14 @@ def execute_batch( extra_args=extra_args, ) - response = self.session.request( - self.method, - self.url, - **post_args, - ) + try: + response = self.session.request( + self.method, + self.url, + **post_args, + ) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e return self._prepare_batch_result(reqs, response) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index e3ac08c4..506b04f4 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -11,6 +11,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -1455,7 +1456,6 @@ async def handler(request): async def test_aiohttp_query_https_self_cert_fail(ssl_aiohttp_server): """By default, we should verify the ssl certificate""" from aiohttp import web - from aiohttp.client_exceptions import ClientConnectorCertificateError from gql.transport.aiohttp import AIOHTTPTransport @@ -1472,16 +1472,22 @@ async def handler(request): transport = AIOHTTPTransport(url=url, timeout=10) - with pytest.raises(ClientConnectorCertificateError) as exc_info: - async with Client(transport=transport) as session: - query = gql(query1_str) + query = gql(query1_str) - # Execute query asynchronously + expected_error = "certificate verify failed: self-signed certificate" + + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: await session.execute(query) - expected_error = "certificate verify failed: self-signed certificate" + assert expected_error in str(exc_info.value) + + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: + await 
session.execute_batch([query]) assert expected_error in str(exc_info.value) + assert transport.session is None diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 3a424355..0411294b 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -7,6 +7,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -150,7 +151,6 @@ async def test_httpx_query_https_self_cert_fail( ): """By default, we should verify the ssl certificate""" from aiohttp import web - from httpx import ConnectError from gql.transport.httpx import HTTPXTransport @@ -180,15 +180,19 @@ def test_code(): **extra_args, ) - with pytest.raises(ConnectError) as exc_info: - with Client(transport=transport) as session: + query = gql(query1_str) - query = gql(query1_str) + expected_error = "certificate verify failed: self-signed certificate" - # Execute query synchronously + with pytest.raises(TransportConnectionFailed) as exc_info: + with Client(transport=transport) as session: session.execute(query) - expected_error = "certificate verify failed: self-signed certificate" + assert expected_error in str(exc_info.value) + + with pytest.raises(TransportConnectionFailed) as exc_info: + with Client(transport=transport) as session: + session.execute_batch([query]) assert expected_error in str(exc_info.value) diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 25fd27aa..690b3ee7 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -1155,7 +1156,6 @@ async def handler(request): @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_httpx_query_https_self_cert_fail(ssl_aiohttp_server, verify_https): from 
aiohttp import web - from httpx import ConnectError from gql.transport.httpx import HTTPXAsyncTransport @@ -1177,15 +1177,19 @@ async def handler(request): transport = HTTPXAsyncTransport(url=url, timeout=10, **extra_args) - with pytest.raises(ConnectError) as exc_info: - async with Client(transport=transport) as session: + query = gql(query1_str) - query = gql(query1_str) + expected_error = "certificate verify failed: self-signed certificate" - # Execute query asynchronously + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: await session.execute(query) - expected_error = "certificate verify failed: self-signed certificate" + assert expected_error in str(exc_info.value) + + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: + await session.execute_batch([query]) assert expected_error in str(exc_info.value) diff --git a/tests/test_requests.py b/tests/test_requests.py index 45901875..fe57f5e3 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -8,6 +8,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -154,7 +155,6 @@ async def test_requests_query_https_self_cert_fail( ): """By default, we should verify the ssl certificate""" from aiohttp import web - from requests.exceptions import SSLError from gql.transport.requests import RequestsHTTPTransport @@ -182,7 +182,7 @@ def test_code(): **extra_args, ) - with pytest.raises(SSLError) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: with Client(transport=transport) as session: query = gql(query1_str) From a90f92322713bd75d8557b8fa79e89c1708fdcff Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 28 May 2025 17:54:52 +0200 Subject: [PATCH 224/239] Bump version number to 4.0.0b0 --- gql/__version__.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 7870304a..65ac68de 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "4.0.0a0" +__version__ = "4.0.0b0" From dba4953add92249800d8d8a92cb33dc66856a8a6 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 17 Aug 2025 14:44:53 +0200 Subject: [PATCH 225/239] Documentation improvements (#561) --- .github/ISSUE_TEMPLATE/bug_report.md | 2 +- README.md | 56 ++++++++++++++--- docs/async/async_intro.rst | 18 ------ docs/async/index.rst | 10 --- docs/code_examples/aiohttp_async.py | 34 +++++----- docs/code_examples/aiohttp_sync.py | 2 +- docs/conf.py | 4 +- docs/index.rst | 5 +- docs/intro.rst | 6 +- docs/{async => usage}/async_usage.rst | 24 ++++++- docs/usage/index.rst | 3 +- docs/usage/subscriptions.rst | 63 ++++++++++++++++--- .../usage/{basic_usage.rst => sync_usage.rst} | 9 ++- 13 files changed, 158 insertions(+), 78 deletions(-) delete mode 100644 docs/async/async_intro.rst delete mode 100644 docs/async/index.rst rename docs/{async => usage}/async_usage.rst (52%) rename docs/usage/{basic_usage.rst => sync_usage.rst} (86%) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index f89a2238..45f01d82 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -9,7 +9,7 @@ assignees: '' **Common problems** - If you receive a TransportQueryError, it means the error is coming from the backend (See [Error Handling](https://gql.readthedocs.io/en/latest/advanced/error_handling.html)) and has probably nothing to do with gql -- If you use IPython (Jupyter, Spyder), then [you need to use the async version](https://gql.readthedocs.io/en/latest/async/async_usage.html#ipython) +- If you use IPython (Jupyter, Spyder), then [you need to use the async 
version](https://gql.readthedocs.io/en/latest/usage/async_usage.html#ipython) - Before sending a bug report, please consider [activating debug logs](https://gql.readthedocs.io/en/latest/advanced/logging.html) to see the messages exchanged between the client and the backend **Describe the bug** diff --git a/README.md b/README.md index e79a63d2..86f380f3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ # GQL -This is a GraphQL client for Python 3.8+. -Plays nicely with `graphene`, `graphql-core`, `graphql-js` and any other GraphQL implementation compatible with the spec. +This is a GraphQL client for Python. +Plays nicely with `graphene`, `graphql-core`, `graphql-js` and any other GraphQL implementation +compatible with the [GraphQL specification](https://spec.graphql.org). GQL architecture is inspired by `React-Relay` and `Apollo-Client`. @@ -37,7 +38,7 @@ The complete documentation for GQL can be found at * AWS AppSync realtime protocol (experimental) * Possibility to [validate the queries locally](https://gql.readthedocs.io/en/latest/usage/validation.html) using a GraphQL schema provided locally or fetched from the backend using an instrospection query * Supports GraphQL queries, mutations and [subscriptions](https://gql.readthedocs.io/en/latest/usage/subscriptions.html) -* Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) +* Supports [sync](https://gql.readthedocs.io/en/latest/usage/sync_usage.html) or 
[async](https://gql.readthedocs.io/en/latest/usage/async_usage.html) usage, [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) * Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html) * Supports [Batching requests](https://gql.readthedocs.io/en/latest/advanced/batching_requests.html) @@ -57,17 +58,17 @@ pip install "gql[all]" ## Usage -### Basic usage +### Sync usage ```python -from gql import gql, Client +from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport # Select your transport with a defined url endpoint transport = AIOHTTPTransport(url="https://countries.trevorblades.com/") # Create a GraphQL client using the defined transport -client = Client(transport=transport, fetch_schema_from_transport=True) +client = Client(transport=transport) # Provide a GraphQL query query = gql( @@ -95,7 +96,48 @@ $ python basic_example.py > **WARNING**: Please note that this basic example won't work if you have an asyncio event loop running. In some > python environments (as with Jupyter which uses IPython) an asyncio event loop is created for you. In that case you -> should use instead the [async usage example](https://gql.readthedocs.io/en/latest/async/async_usage.html#async-usage). +> should use instead the [async usage example](https://gql.readthedocs.io/en/latest/usage/async_usage.html#async-usage). 
+ +### Async usage + +```python +import asyncio + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport + + +async def main(): + + # Select your transport with a defined url endpoint + transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") + + # Create a GraphQL client using the defined transport + client = Client(transport=transport) + + # Provide a GraphQL query + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with client as session: + + # Execute the query + result = await session.execute(query) + print(result) + + +asyncio.run(main()) +``` ## Contributing See [CONTRIBUTING.md](CONTRIBUTING.md) diff --git a/docs/async/async_intro.rst b/docs/async/async_intro.rst deleted file mode 100644 index 6d4fea37..00000000 --- a/docs/async/async_intro.rst +++ /dev/null @@ -1,18 +0,0 @@ -On previous versions of GQL, the code was `sync` only , it means that when you ran -`execute` on the Client, you could do nothing else in the current Thread and had to wait for -an answer or a timeout from the backend to continue. The only http library was `requests`, allowing only sync usage. - -From the version 3 of GQL, we support `sync` and `async` :ref:`transports ` using `asyncio`_. - -With the :ref:`async transports `, there is now the possibility to execute GraphQL requests -asynchronously, :ref:`allowing to execute multiple requests in parallel if needed `. - -If you don't care or need async functionality, it is still possible, with :ref:`async transports `, -to run the `execute` or `subscribe` methods directly from the Client -(as described in the :ref:`Basic Usage ` example) and GQL will execute the request -in a synchronous manner by running an asyncio event loop itself. 
- -This won't work though if you already have an asyncio event loop running. In that case you should use -:ref:`Async Usage ` - -.. _asyncio: https://docs.python.org/3/library/asyncio.html diff --git a/docs/async/index.rst b/docs/async/index.rst deleted file mode 100644 index 3f3d2a8a..00000000 --- a/docs/async/index.rst +++ /dev/null @@ -1,10 +0,0 @@ -Async vs Sync -============= - -.. include:: async_intro.rst - -.. toctree:: - :hidden: - :maxdepth: 1 - - async_usage diff --git a/docs/code_examples/aiohttp_async.py b/docs/code_examples/aiohttp_async.py index 0c1d10dd..bc615fa8 100644 --- a/docs/code_examples/aiohttp_async.py +++ b/docs/code_examples/aiohttp_async.py @@ -6,27 +6,29 @@ async def main(): + # Select your transport with a defined url endpoint transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") + # Create a GraphQL client using the defined transport + client = Client(transport=transport) + + # Provide a GraphQL query + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + # Using `async with` on the client will start a connection on the transport # and provide a `session` variable to execute queries on this connection - async with Client( - transport=transport, - fetch_schema_from_transport=True, - ) as session: - - # Execute single query - query = gql( - """ - query getContinents { - continents { - code - name - } - } - """ - ) + async with client as session: + # Execute the query result = await session.execute(query) print(result) diff --git a/docs/code_examples/aiohttp_sync.py b/docs/code_examples/aiohttp_sync.py index 8b1cf899..18dab8ae 100644 --- a/docs/code_examples/aiohttp_sync.py +++ b/docs/code_examples/aiohttp_sync.py @@ -5,7 +5,7 @@ transport = AIOHTTPTransport(url="https://countries.trevorblades.com/") # Create a GraphQL client using the defined 
transport -client = Client(transport=transport, fetch_schema_from_transport=True) +client = Client(transport=transport) # Provide a GraphQL query query = gql( diff --git a/docs/conf.py b/docs/conf.py index 8289ef4b..024dd9e6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,8 +17,8 @@ # -- Project information ----------------------------------------------------- -project = 'gql 3' -copyright = '2020, graphql-python.org' +project = 'gql' +copyright = '2025, graphql-python.org' author = 'graphql-python.org' # The full version, including alpha/beta/rc tags diff --git a/docs/index.rst b/docs/index.rst index ecb2f0e1..d0ab36f2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,5 +1,5 @@ -Welcome to GQL 3 documentation! -=============================== +GQL documentation +================= Contents -------- @@ -9,7 +9,6 @@ Contents intro usage/index - async/index transports/index advanced/index gql-cli/intro diff --git a/docs/intro.rst b/docs/intro.rst index 3151755d..f47166f6 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -1,7 +1,7 @@ Introduction ============ -`GQL 3`_ is a `GraphQL`_ Client for Python 3.8+ which plays nicely with other +`GQL`_ is a `GraphQL`_ Client for Python which plays nicely with other graphql implementations compatible with the spec. Under the hood, it uses `GraphQL-core`_ which is a Python port of `GraphQL.js`_, @@ -10,7 +10,7 @@ the JavaScript reference implementation for GraphQL. Installation ------------ -You can install GQL 3 and all the extra dependencies using pip_:: +You can install GQL and all the extra dependencies using pip_:: pip install "gql[all]" @@ -93,7 +93,7 @@ Please check the `Contributing`_ file to learn how to make a good pull request. .. _GraphQL: https://graphql.org/ .. _GraphQL-core: https://github.com/graphql-python/graphql-core .. 
_GraphQL.js: https://github.com/graphql/graphql-js -.. _GQL 3: https://github.com/graphql-python/gql +.. _GQL: https://github.com/graphql-python/gql .. _pip: https://pip.pypa.io/ .. _GitHub repository for gql: https://github.com/graphql-python/gql .. _Contributing: https://github.com/graphql-python/gql/blob/master/CONTRIBUTING.md diff --git a/docs/async/async_usage.rst b/docs/usage/async_usage.rst similarity index 52% rename from docs/async/async_usage.rst rename to docs/usage/async_usage.rst index e0e9ee02..a83c4767 100644 --- a/docs/async/async_usage.rst +++ b/docs/usage/async_usage.rst @@ -1,8 +1,28 @@ .. _async_usage: -Async Usage +Async usage =========== +On previous versions of GQL, the code was `sync` only , it means that when you ran +`execute` on the Client, you could do nothing else in the current Thread and had to wait for +an answer or a timeout from the backend to continue. The only http library was `requests`, allowing only sync usage. + +From the version 3 of GQL, we support `sync` and `async` :ref:`transports ` using `asyncio`_. + +With the :ref:`async transports `, there is now the possibility to execute GraphQL requests +asynchronously, :ref:`allowing to execute multiple requests in parallel if needed `. + +If you don't care or need async functionality, it is still possible, with :ref:`async transports `, +to run the `execute` or `subscribe` methods directly from the Client +(as described in the :ref:`Sync Usage ` example) and GQL will execute the request +in a synchronous manner by running an asyncio event loop itself. + +This won't work though if you already have an asyncio event loop running. In that case you should use the async +methods. 
+ +Example +------- + If you use an :ref:`async transport `, you can use GQL asynchronously using `asyncio`_. * put your code in an asyncio coroutine (method starting with :code:`async def`) @@ -10,8 +30,6 @@ If you use an :ref:`async transport `, you can use GQL asynchr * use the :code:`await` keyword to execute requests: :code:`await session.execute(...)` * then run your coroutine in an asyncio event loop by running :code:`asyncio.run` -Example: - .. literalinclude:: ../code_examples/aiohttp_async.py IPython diff --git a/docs/usage/index.rst b/docs/usage/index.rst index f73ac75a..5fb3480f 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -4,7 +4,8 @@ Usage .. toctree:: :maxdepth: 2 - basic_usage + sync_usage + async_usage validation subscriptions variables diff --git a/docs/usage/subscriptions.rst b/docs/usage/subscriptions.rst index 9448328d..549054b9 100644 --- a/docs/usage/subscriptions.rst +++ b/docs/usage/subscriptions.rst @@ -1,29 +1,76 @@ Subscriptions ============= -Using the :ref:`websockets transport `, it is possible to execute GraphQL subscriptions: +Using the :ref:`websockets transport `, it is possible to execute GraphQL subscriptions, +either using the sync or async usage. + +The async usage is recommended for any non-trivial tasks (it allows efficient concurrent queries and subscriptions). + +See :ref:`Async permanent session ` and :ref:`Async advanced usage ` +for more advanced examples. + +.. note:: + + The websockets transport can also execute queries or mutations, it is not restricted to subscriptions. + +Sync +---- .. 
code-block:: python - from gql import gql, Client + from gql import Client, gql from gql.transport.websockets import WebsocketsTransport + # Select your transport with a defined url endpoint transport = WebsocketsTransport(url='wss://your_server/graphql') - client = Client( - transport=transport, - fetch_schema_from_transport=True, - ) + # Create a GraphQL client using the defined transport + client = Client(transport=transport) + # Provide a GraphQL subscription query query = gql(''' subscription yourSubscription { ... } ''') + # Connect and subscribe to the results using a simple 'for' for result in client.subscribe(query): print (result) -.. note:: +Async +----- + +.. code-block:: python + + import asyncio + + from gql import Client, gql + from gql.transport.websockets import WebsocketsTransport + + + async def main(): + + # Select your transport with a defined url endpoint + transport = WebsocketsTransport(url='wss://your_server/graphql') + + # Create a GraphQL client using the defined transport + client = Client(transport=transport) + + # Provide a GraphQL subscription query + query = gql(''' + subscription yourSubscription { + ... + } + ''') + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection + async with client as session: + + # Then get the results using 'async for' + async for result in client.subscribe(query): + print (result) + - The websockets transport can also execute queries or mutations, it is not restricted to subscriptions + asyncio.run(main()) diff --git a/docs/usage/basic_usage.rst b/docs/usage/sync_usage.rst similarity index 86% rename from docs/usage/basic_usage.rst rename to docs/usage/sync_usage.rst index d53c18d5..f1551618 100644 --- a/docs/usage/basic_usage.rst +++ b/docs/usage/sync_usage.rst @@ -1,9 +1,9 @@ -.. _basic_usage: +.. 
_sync_usage: -Basic usage ------------ +Sync usage +========== -In order to execute a GraphQL request against a GraphQL API: +To execute a GraphQL request against a GraphQL API: * create your gql :ref:`transport ` in order to choose the destination url and the protocol used to communicate with it @@ -18,4 +18,3 @@ In order to execute a GraphQL request against a GraphQL API: Please note that this basic example won't work if you have an asyncio event loop running. In some python environments (as with Jupyter which uses IPython) an asyncio event loop is created for you. In that case you should use instead the :ref:`Async Usage example`. - From 7695620579f6b3c94d69fed3fba147950fab4e5c Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 17 Aug 2025 16:09:48 +0200 Subject: [PATCH 226/239] Restrict graphql-core to <3.3 instead of <3.2.7 on stable branch --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0e1e7e63..4f0c8537 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.2,<3.2.7", + "graphql-core>=3.2,<3.3", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", From 059cabaadf69c3ebe94f34db5e4f5a24c918c231 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 17 Aug 2025 16:22:29 +0200 Subject: [PATCH 227/239] Bump version number to 4.0.0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index 65ac68de..ce1305bf 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "4.0.0b0" +__version__ = "4.0.0" From 0778c1952ed862b26a76dc3c30dbaf20a5bc1e13 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 17 Aug 2025 16:47:22 +0200 Subject: [PATCH 228/239] Restore graphql-core alpha version for master branch --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 
4f0c8537..3db1c9f8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages install_requires = [ - "graphql-core>=3.2,<3.3", + "graphql-core>=3.3.0a3,<3.4", "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", From 76ff8adb6baec92a5bd2d32f126c3b7cb2fa1411 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 17 Aug 2025 16:48:33 +0200 Subject: [PATCH 229/239] Bump version number to 4.1.0b0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index ce1305bf..b672be1c 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "4.0.0" +__version__ = "4.1.0b0" From 12478774b552818e94ae29f66cda7c55466931ca Mon Sep 17 00:00:00 2001 From: Katherine Baker <43652476+kasbaker@users.noreply.github.com> Date: Mon, 1 Sep 2025 07:01:11 -0700 Subject: [PATCH 230/239] Add comprehensive directive support to DSL module (#563) - `DSLDirective` class: Represents GraphQL directives with argument validation and AST generation - `DSLDirectable` mixin: Provides reusable `.directives()` method for all DSL elements that support directives - `DSLFragmentSpread` class: Represents fragment spreads with their own directives, separate from fragment definitions - Executable directive location support on query, mutation, subscription, fields, fragments, inline fragments, fragment spreads, and variable definitions ([spec](https://spec.graphql.org/October2021/#sec-Type-System.Directives)) - Automatic schema resolution: Fields automatically use their parent schema for custom directive validation - Fallback on builtin directives: Built-in directives are still available if a schema is not available to validate against The implementation follows the [October 2021 GraphQL specification](https://spec.graphql.org/October2021/) for executable directive locations and maintains 
backward compatibility with existing DSL code. Users can now use both built-in directives (`@skip`, `@include`) and custom schema directives across all supported GraphQL locations. Co-authored-by: Leszek Hanusz --- docs/advanced/dsl_module.rst | 306 ++++++++++++++++++++++- gql/dsl.py | 464 +++++++++++++++++++++++++++++++++-- tests/starwars/schema.py | 116 +++++++++ tests/starwars/test_dsl.py | 292 ++++++++++++++++++++++ 4 files changed, 1150 insertions(+), 28 deletions(-) diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index 1c2c1c82..c6ee035a 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -64,11 +64,11 @@ from the :code:`ds` instance ds.Query.hero.select(ds.Character.name) -The select method return the same instance, so it is possible to chain the calls:: +The select method returns the same instance, so it is possible to chain the calls:: ds.Query.hero.select(ds.Character.name).select(ds.Character.id) -Or do it sequencially:: +Or do it sequentially:: hero_query = ds.Query.hero @@ -279,7 +279,7 @@ will generate the request:: Multiple operations in a document ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -It is possible to create an Document with multiple operations:: +It is possible to create a Document with multiple operations:: query = dsl_gql( operation_name_1=DSLQuery( ... ), @@ -384,6 +384,305 @@ you can use the :class:`DSLMetaField ` class:: DSLMetaField("__typename") ) +Directives +^^^^^^^^^^ + +`Directives`_ provide a way to describe alternate runtime execution and type validation +behavior in a GraphQL document. The DSL module supports both built-in GraphQL directives +(:code:`@skip`, :code:`@include`) and custom schema-defined directives. 
+ +To add directives to DSL elements, use the :meth:`DSLSchema.__call__ ` +factory method and the :meth:`directives ` method:: + + # Using built-in @skip directive with DSLSchema.__call__ factory + ds.Query.hero.select( + ds.Character.name.directives(ds("@skip").args(**{"if": True})) + ) + +Directive Arguments +""""""""""""""""""" + +Directive arguments can be passed using the :meth:`args ` method. +For arguments that don't conflict with Python reserved words, you can pass them directly:: + + # Using the args method for non-reserved names + ds("@custom").args(value="foo", reason="testing") + +It can also be done by calling the directive directly:: + + ds("@custom")(value="foo", reason="testing") + +However, when the GraphQL directive argument name conflicts with a Python reserved word +(like :code:`if`), you need to unpack a dictionary to escape it:: + + # Dictionary unpacking for Python reserved words + ds("@skip").args(**{"if": True}) + ds("@include")(**{"if": False}) + +This ensures that the exact GraphQL argument name is passed to the directive and that +no post-processing of arguments is required. + +The :meth:`DSLSchema.__call__ ` factory method automatically handles +schema lookup and validation for both built-in directives (:code:`@skip`, :code:`@include`) +and custom schema-defined directives using the same syntax. + +Directive Locations +""""""""""""""""""" + +The DSL module supports all executable directive locations from the GraphQL specification: + +.. 
list-table:: + :header-rows: 1 + :widths: 25 35 40 + + * - GraphQL Spec Location + - DSL Class/Method + - Description + * - QUERY + - :code:`DSLQuery.directives()` + - Directives on query operations + * - MUTATION + - :code:`DSLMutation.directives()` + - Directives on mutation operations + * - SUBSCRIPTION + - :code:`DSLSubscription.directives()` + - Directives on subscription operations + * - FIELD + - :code:`DSLField.directives()` + - Directives on fields (including meta-fields) + * - FRAGMENT_DEFINITION + - :code:`DSLFragment.directives()` + - Directives on fragment definitions + * - FRAGMENT_SPREAD + - :code:`DSLFragmentSpread.directives()` + - Directives on fragment spreads (via .spread()) + * - INLINE_FRAGMENT + - :code:`DSLInlineFragment.directives()` + - Directives on inline fragments + * - VARIABLE_DEFINITION + - :code:`DSLVariable.directives()` + - Directives on variable definitions + +Examples by Location +"""""""""""""""""""" + +**Operation directives**:: + + # Query operation + query = DSLQuery(ds.Query.hero.select(ds.Character.name)).directives( + ds("@customQueryDirective") + ) + + # Mutation operation + mutation = DSLMutation( + ds.Mutation.createReview.args(episode=6, review={"stars": 5}).select( + ds.Review.stars + ) + ).directives(ds("@customMutationDirective")) + +**Field directives**:: + + # Single directive on field + ds.Query.hero.select( + ds.Character.name.directives(ds("@customFieldDirective")) + ) + + # Multiple directives on a field + ds.Query.hero.select( + ds.Character.appearsIn.directives( + ds("@repeat").args(value="first"), + ds("@repeat").args(value="second"), + ds("@repeat").args(value="third"), + ) + ) + +**Fragment directives**: + +You can add directives to fragment definitions and to fragment spread instances. 
+To do this, first define your fragment in the usual way:: + + name_and_appearances = ( + DSLFragment("NameAndAppearances") + .on(ds.Character) + .select(ds.Character.name, ds.Character.appearsIn) + ) + +Then, use :meth:`spread() ` when you need to add +directives to the fragment spread:: + + query_with_fragment = DSLQuery( + ds.Query.hero.select( + name_and_appearances.spread().directives( + ds("@customFragmentSpreadDirective") + ) + ) + ) + +The :meth:`spread() ` method creates a +:class:`DSLFragmentSpread ` instance that allows you to add +directives specific to the fragment spread location, separate from directives on the +fragment definition itself. + +Example with fragment definition and spread-specific directives:: + + # Fragment definition with directive + name_and_appearances = ( + DSLFragment("CharacterInfo") + .on(ds.Character) + .select(ds.Character.name, ds.Character.appearsIn) + .directives(ds("@customFragmentDefinitionDirective")) + ) + + # Using fragment with spread-specific directives + query_without_spread_directive = DSLQuery( + # Direct usage (no spread directives) + ds.Query.hero.select(name_and_appearances) + ) + query_with_spread_directive = DSLQuery( + # Enhanced usage with spread directives + name_and_appearances.spread().directives( + ds("@customFragmentSpreadDirective") + ) + ) + + # Don't forget to include the fragment definition in dsl_gql + query = dsl_gql( + name_and_appearances, + BaseQuery=query_without_spread_directive, + QueryWithDirective=query_with_spread_directive, + ) + +This generates GraphQL equivalent to:: + + fragment CharacterInfo on Character @customFragmentDefinitionDirective { + name + appearsIn + } + + { + BaseQuery hero { + ...CharacterInfo + } + QueryWithDirective hero { + ...CharacterInfo @customFragmentSpreadDirective + } + } + +**Inline fragment directives**: + +Inline fragments also support directives using the +:meth:`directives ` method:: + + query_with_directive = ds.Query.hero.args(episode=6).select( + 
ds.Character.name, + DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet).directives( + ds("@customInlineFragmentDirective") + ) + ) + +This generates:: + + { + hero(episode: JEDI) { + name + ... on Human @customInlineFragmentDirective { + homePlanet + } + } + } + +**Variable definition directives**: + +You can also add directives to variable definitions using the +:meth:`directives ` method:: + + var = DSLVariableDefinitions() + var.episode.directives(ds("@customVariableDirective")) + # Note: the directive is attached to the `.episode` variable definition (singular), + # and not the `var` variable definitions (plural) holder. + + op = DSLQuery(ds.Query.hero.args(episode=var.episode).select(ds.Character.name)) + op.variable_definitions = var + +This will generate:: + + query ($episode: Episode @customVariableDirective) { + hero(episode: $episode) { + name + } + } + +Complete Example for Directives +""""""""""""""""""""""""""""""" + +Here's a comprehensive example showing directives on multiple locations: + +.. 
code-block:: python + + from gql.dsl import DSLFragment, DSLInlineFragment, DSLQuery, dsl_gql + + # Create variables for directive conditions + var = DSLVariableDefinitions() + + # Fragment with directive on definition + character_fragment = DSLFragment("CharacterInfo").on(ds.Character).select( + ds.Character.name, ds.Character.appearsIn + ).directives(ds("@fragmentDefinition")) + + # Query with directives on multiple locations + query = DSLQuery( + ds.Query.hero.args(episode=var.episode).select( + # Field with directive + ds.Character.name.directives(ds("@skip").args(**{"if": var.skipName})), + + # Fragment spread with directive + character_fragment.spread().directives( + ds("@include").args(**{"if": var.includeFragment}) + ), + + # Inline fragment with directive + DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet).directives( + ds("@skip").args(**{"if": var.skipHuman}) + ), + + # Meta field with directive + DSLMetaField("__typename").directives( + ds("@include").args(**{"if": var.includeType}) + ) + ) + ).directives(ds("@query")) # Operation directive + + # Variable definition with directive + var.episode.directives(ds("@variableDefinition")) + query.variable_definitions = var + + # Generate the document + document = dsl_gql(character_fragment, query) + +This generates GraphQL equivalent to:: + + fragment CharacterInfo on Character @fragmentDefinition { + name + appearsIn + } + + query ( + $episode: Episode @variableDefinition + $skipName: Boolean! + $includeFragment: Boolean! + $skipHuman: Boolean! + $includeType: Boolean! + ) @query { + hero(episode: $episode) { + name @skip(if: $skipName) + ...CharacterInfo @include(if: $includeFragment) + ... on Human @skip(if: $skipHuman) { + homePlanet + } + __typename @include(if: $includeType) + } + } + Executable examples ------------------- @@ -399,4 +698,5 @@ Sync example .. _Fragment: https://graphql.org/learn/queries/#fragments .. 
_Inline Fragment: https://graphql.org/learn/queries/#inline-fragments +.. _Directives: https://graphql.org/learn/queries/#directives .. _issue #308: https://github.com/graphql-python/gql/issues/308 diff --git a/gql/dsl.py b/gql/dsl.py index 1a8716c2..da4cf64c 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,17 +1,31 @@ """ -.. image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa - :alt: UML diagram +.. image:: https://www.plantuml.com/plantuml/png/hLZXJkGs4FwVft1_NLXOfBR_Lcrrz3Wg93WA2rTL24Kc6LYtMITdErpf5QdFqaVharp6tincS8ZsLlTd8PxnpESltun7UMsTDAvPbichRzm2bY3gKYgT9Bfo8AGLfrNHb73KwDofIjjaCWahWfOca-J_V_yJXIsp-mzbEgbgCD9RziIazvHzL6wHQRc4dPdunSXwSNvo0HyQiCu7aDPbTwPQPW-oR23rltl2FTQGjHlEQWmYo-ltkFwkAk26xx9Wb2pLtr2405cZSM-HhWqlX05T23nkakIbj5OSpa_cUSk559yI8QRJzcStot9PbbcM8lwPiCxipD3nK1d8dNg0u7GFJZfdOh_B5ahoH1d20iKVtNgae2pONahg0-mMtMDMm1rHov0XI-Gs4sH30j1EAUC3JoP_VfJctWwS5vTViZF0xwLHyhQ4GxXJMdar1EWFAuD5JBcxjixizJVSR40GEQDRwvJvmwupfQtNPLENS1t3mFFlYVtz_Hl4As_Rc39tOgq3A25tbGbeBJxXjio2cubvzpW7Xu48wwSkq9DG5jMeYkmEtsBgVriyjrLLhYEc4x_kwoNy5sgbtIYHrmFzoE5n8U2HdYd18WdTiTdR3gSTXKfHKlglWynof1FwVnJbHLKvBsB6PiW_nizWi2CZxvUWtLU9zRL0OGnw3vnLQLq8CnDNMbNwsYSDR-9Obqf3TwAmHkUh3KZlrtjPracdyYU1AlVYW1L6ctOAYlH3wcSunqJ_zY_86-_5YxHVLBCNofgQ2NLQhEcRZQg7yGO40gNiAM0jvQoxLm96kcOoRFepGMRii-Z0u_KSU3E84vqtO1w7aeWVUPRzywkt5xzp4OsN4yjpsZWVQgDKfrUN1vV7P--spZPlRcrkLBrnnldLp_Ct5yU_RfsL14EweZRUtL0aD4JGKn02w2g1EuOGNTXEHgr
EPLEwC0VuneIhpuAkhibZNJSE4wpBp5Ke4GyYxSQF3a8GCZVoEuZIfmm6Tzk2FEfyWRnUNubR1cStLZzj6H8_dj17IWDc7dx3MujlzVhIWQ-yqeNFo5qsPsIq__xM8ZX0035B-8UTqWDD_IzD4uEns6lWJJjAmysKRtFQU8fnyhZZwEqSUsyZGSGxokokNwCXr9jmkPO6T2YRxY9SkPpT_W6vhy0zGJNfmDp97Bgwt2ri-Rmfj738lF7uIdXmQS2skRnfnpZhvBJ5XG1EzWYdot_Phg_8Y2ZSkZFp8j-YnM3QSI9uZ2y0-KeSwmKOvQJEGHWe_Qra5wgsINz6_-6VwJGQws8FDk74PXfOnuF4asYIy8ayJZRWm2w5sCmRKfAmS16IP01LxCH2nkPaY01oew5W20gp9_qdRwTfQj140z2WbGqioV0PU8CRPuEx3WSSlWi6F6Dn9yERkKJHYRFCpMIdTMe9M1HlgcLTMNyRyA8GKt4Y7y68RyMgdWH-8H6cgjnEilwwCPt-H5yYPY8t81rORkTV6yXfi_JVYTJd3PiAKVasPJq4J8e9wBGCmU070-zDfYz6yxr86ollGIWjQDQrErp7F0dBZ_agxQJIbXVg44-D1TlNd_U9somTGJmeARgfAtaDkcYMvMS0 # noqa + :alt: UML diagram - rename png to uml to edit """ import logging import re from abc import ABC, abstractmethod from math import isfinite -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union, cast +from typing import ( + Any, + Dict, + Iterable, + Literal, + Mapping, + Optional, + Set, + Tuple, + Union, + cast, + overload, +) from graphql import ( ArgumentNode, BooleanValueNode, + DirectiveLocation, + DirectiveNode, DocumentNode, EnumValueNode, FieldNode, @@ -19,6 +33,7 @@ FragmentDefinitionNode, FragmentSpreadNode, GraphQLArgument, + GraphQLDirective, GraphQLEnumType, GraphQLError, GraphQLField, @@ -61,6 +76,7 @@ is_non_null_type, is_wrapping_type, print_ast, + specified_directives, ) from graphql.pyutils import inspect @@ -132,8 +148,9 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: Produce a GraphQL Value AST given a Python object. - Raises a GraphQLError instead of returning None if we receive an Undefined - of if we receive a Null value for a Non-Null type. + :raises graphql.error.GraphQLError: + instead of returning None if we receive an Undefined + of if we receive a Null value for a Non-Null type. 
""" if isinstance(value, DSLVariable): return value.set_type(type_).ast_variable_name @@ -274,6 +291,9 @@ class DSLSchema: Attributes of the DSLSchema class are generated automatically with the `__getattr__` dunder method in order to generate instances of :class:`DSLType` + + .. automethod:: __call__ + .. automethod:: __getattr__ """ def __init__(self, schema: GraphQLSchema): @@ -293,7 +313,57 @@ def __init__(self, schema: GraphQLSchema): self._schema: GraphQLSchema = schema + @overload + def __call__( + self, shortcut: Literal["__typename", "__schema", "__type"] + ) -> "DSLMetaField": ... # pragma: no cover + + @overload + def __call__( + self, shortcut: Literal["..."] + ) -> "DSLInlineFragment": ... # pragma: no cover + + @overload + def __call__( + self, shortcut: Literal["fragment"], name: str + ) -> "DSLFragment": ... # pragma: no cover + + @overload + def __call__(self, shortcut: Any) -> "DSLDirective": ... # pragma: no cover + + def __call__( + self, shortcut: str, name: Optional[str] = None + ) -> Union["DSLMetaField", "DSLInlineFragment", "DSLFragment", "DSLDirective"]: + """Factory method for creating DSL objects. + + Currently, supports creating DSLDirective instances when name starts with '@'. + Future support planned for meta-fields (__typename), inline fragments (...), + and fragment definitions (fragment). 
+ + :param shortcut: the name of the object to create + :type shortcut: str + + :return: :class:`DSLDirective` instance + + :raises ValueError: if shortcut format is not supported + """ + if shortcut.startswith("@"): + return DSLDirective(name=shortcut[1:], dsl_schema=self) + # Future support: + # if name.startswith("__"): return DSLMetaField(name) + # if name == "...": return DSLInlineFragment() + # if name.startswith("fragment "): return DSLFragment(name[9:]) + + raise ValueError(f"Unsupported shortcut: {shortcut}") + def __getattr__(self, name: str) -> "DSLType": + """Attributes of the DSLSchema class are generated automatically + with this dunder method in order to generate + instances of :class:`DSLType` + + :return: :class:`DSLType` instance + :raises AttributeError: if the name is not valid + """ type_def: Optional[GraphQLNamedType] = self._schema.get_type(name) @@ -381,7 +451,218 @@ def select( log.debug(f"Added fields: {added_fields} in {self!r}") -class DSLExecutable(DSLSelector): +class DSLDirective: + """The DSLDirective represents a GraphQL directive for the DSL code. + + Directives provide a way to describe alternate runtime execution and type validation + behavior in a GraphQL document. + """ + + def __init__(self, name: str, dsl_schema: "DSLSchema"): + r"""Initialize the DSLDirective with the given name and arguments. 
+ + :param name: the name of the directive + :param dsl_schema: DSLSchema for directive validation and definition lookup + + :raises graphql.error.GraphQLError: if directive not found or not executable + """ + self._dsl_schema = dsl_schema + + # Find directive definition in schema or built-ins + directive_def = self._dsl_schema._schema.get_directive(name) + + if directive_def is None: + # Try to find in built-in directives using specified_directives + builtins = {builtin.name: builtin for builtin in specified_directives} + directive_def = builtins.get(name) + + if directive_def is None: + available: Set[str] = set() + available.update(f"@{d.name}" for d in self._dsl_schema._schema.directives) + available.update(f"@{d.name}" for d in specified_directives) + raise GraphQLError( + f"Directive '@{name}' not found in schema or built-ins. " + f"Available directives: {', '.join(sorted(available))}" + ) + + # Check directive has at least one executable location + executable_locations = { + DirectiveLocation.QUERY, + DirectiveLocation.MUTATION, + DirectiveLocation.SUBSCRIPTION, + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_DEFINITION, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + DirectiveLocation.VARIABLE_DEFINITION, + } + + if not any(loc in executable_locations for loc in directive_def.locations): + raise GraphQLError( + f"Directive '@{name}' is not a valid request executable directive. " + f"It can only be used in type system locations, not in requests." + ) + + self.directive_def: GraphQLDirective = directive_def + self.ast_directive = DirectiveNode(name=NameNode(value=name), arguments=()) + + @property + def name(self) -> str: + """Get the directive name.""" + return self.ast_directive.name.value + + def __call__(self, **kwargs: Any) -> "DSLDirective": + """Add arguments by calling the directive like a function. 
+ + :param kwargs: directive arguments + :return: itself + """ + return self.args(**kwargs) + + def args(self, **kwargs: Any) -> "DSLDirective": + r"""Set the arguments of a directive + + The arguments are parsed to be stored in the AST of this field. + + .. note:: + You can also call the field directly with your arguments. + :code:`ds("@someDirective").args(value="foo")` is equivalent to: + :code:`ds("@someDirective")(value="foo")` + + :param \**kwargs: the arguments (keyword=value) + + :return: itself + + :raises AttributeError: if arguments already set for this directive + :raises graphql.error.GraphQLError: + if argument doesn't exist in directive definition + """ + if len(self.ast_directive.arguments) > 0: + raise AttributeError(f"Arguments for directive @{self.name} already set.") + + errs = [] + for key, value in kwargs.items(): + if key not in self.directive_def.args: + errs.append( + f"Argument '{key}' does not exist in directive '@{self.name}'" + ) + if errs: + raise GraphQLError("\n".join(errs)) + + # Update AST directive with arguments + self.ast_directive = DirectiveNode( + name=NameNode(value=self.name), + arguments=tuple( + ArgumentNode( + name=NameNode(value=key), + value=ast_from_value(value, self.directive_def.args[key].type), + ) + for key, value in kwargs.items() + ), + ) + + return self + + def __repr__(self) -> str: + args_str = ", ".join( + f"{arg.name.value}={getattr(arg.value, 'value')}" + for arg in self.ast_directive.arguments + ) + return f"" + + +class DSLDirectable(ABC): + """Mixin class for DSL elements that can have directives. + + Provides the directives() method for adding GraphQL directives to DSL elements. + Classes that need immediate AST updates should override the directives() method. + """ + + _directives: Tuple[DSLDirective, ...] 
+ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._directives = () + + @abstractmethod + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if a directive is valid for this DSL element. + + :param directive: The DSLDirective to validate + :return: True if the directive can be used at this location + """ + raise NotImplementedError( + "Any DSLDirectable concrete class must have an is_valid_directive method" + ) # pragma: no cover + + def directives(self, *directives: DSLDirective) -> Any: + r"""Add directives to this DSL element. + + :param \*directives: DSLDirective instances to add + :return: itself + + :raises graphql.error.GraphQLError: if directive location is invalid + :raises TypeError: if argument is not a DSLDirective + + Usage: + + .. code-block:: python + + # Using new factory method + element.directives(ds("@include")(**{"if": var.show})) + element.directives(ds("@skip")(**{"if": var.hide})) + """ + validated_directives = [] + + for directive in directives: + if not isinstance(directive, DSLDirective): + raise TypeError( + f"Expected DSLDirective, got {type(directive)}. " + f"Use ds('@directiveName') to create directive instances." + ) + + # Validate directive location using the abstract method + if not self.is_valid_directive(directive): + # Get valid locations for error message + valid_locations = [ + loc.name + for loc in directive.directive_def.locations + if loc + in { + DirectiveLocation.QUERY, + DirectiveLocation.MUTATION, + DirectiveLocation.SUBSCRIPTION, + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_DEFINITION, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + DirectiveLocation.VARIABLE_DEFINITION, + } + ] + raise GraphQLError( + f"Invalid directive location: '@{directive.name}' " + f"cannot be used on {self.__class__.__name__}. 
" + f"Valid locations for this directive: {', '.join(valid_locations)}" + ) + + validated_directives.append(directive) + + # Update stored directives + self._directives = self._directives + tuple(validated_directives) + + log.debug( + f"Added directives {[d.name for d in validated_directives]} to {self!r}" + ) + + return self + + @property + def directives_ast(self) -> Tuple[DirectiveNode, ...]: + """Get AST directive nodes for this element.""" + return tuple(directive.ast_directive for directive in self._directives) + + +class DSLExecutable(DSLSelector, DSLDirectable): """Interface for the root elements which can be executed in the :func:`dsl_gql ` function @@ -430,6 +711,7 @@ def __init__( self.variable_definitions = DSLVariableDefinitions() DSLSelector.__init__(self, *fields, **fields_with_alias) + DSLDirectable.__init__(self) class DSLRootFieldSelector(DSLSelector): @@ -508,7 +790,7 @@ def executable_ast(self) -> OperationDefinitionNode: selection_set=self.selection_set, variable_definitions=self.variable_definitions.get_ast_definitions(), **({"name": NameNode(value=self.name)} if self.name else {}), - directives=(), + directives=self.directives_ast, ) def __repr__(self) -> str: @@ -518,16 +800,28 @@ def __repr__(self) -> str: class DSLQuery(DSLOperation): operation_type = OperationType.QUERY + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Query operations.""" + return DirectiveLocation.QUERY in directive.directive_def.locations + class DSLMutation(DSLOperation): operation_type = OperationType.MUTATION + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Mutation operations.""" + return DirectiveLocation.MUTATION in directive.directive_def.locations + class DSLSubscription(DSLOperation): operation_type = OperationType.SUBSCRIPTION + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Subscription 
operations.""" + return DirectiveLocation.SUBSCRIPTION in directive.directive_def.locations -class DSLVariable: + +class DSLVariable(DSLDirectable): """The DSLVariable represents a single variable defined in a GraphQL operation Instances of this class are generated for you automatically as attributes @@ -545,6 +839,8 @@ def __init__(self, name: str): self.default_value = None self.type: Optional[GraphQLInputType] = None + DSLDirectable.__init__(self) + def to_ast_type(self, type_: GraphQLInputType) -> TypeNode: if is_wrapping_type(type_): if isinstance(type_, GraphQLList): @@ -568,6 +864,18 @@ def default(self, default_value: Any) -> "DSLVariable": self.default_value = default_value return self + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Variable definitions.""" + for arg in directive.ast_directive.arguments: + if isinstance(arg.value, VariableNode): + raise GraphQLError( + f"Directive @{directive.name} argument value has " + f"unexpected variable '${arg.value.name}' in constant location." + ) + return ( + DirectiveLocation.VARIABLE_DEFINITION in directive.directive_def.locations + ) + class DSLVariableDefinitions: """The DSLVariableDefinitions represents variable definitions in a GraphQL operation @@ -579,6 +887,8 @@ class DSLVariableDefinitions: with the `__getattr__` dunder method in order to generate instances of :class:`DSLVariable`, that can then be used as values in the :meth:`args ` method. + + .. 
automethod:: __getattr__ """ def __init__(self): @@ -586,6 +896,12 @@ def __init__(self): self.variables: Dict[str, DSLVariable] = {} def __getattr__(self, name: str) -> "DSLVariable": + """Attributes of the DSLVariableDefinitions class are generated automatically + with this dunder method in order to generate + instances of :class:`DSLVariable` + + :return: :class:`DSLVariable` instance + """ if name not in self.variables: self.variables[name] = DSLVariable(name) return self.variables[name] @@ -605,7 +921,7 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: if var.default_value is None else ast_from_value(var.default_value, var.type) ), - directives=(), + directives=var.directives_ast, ) for var in self.variables.values() if var.type is not None # only variables used @@ -625,6 +941,8 @@ class DSLType: Attributes of the DSLType class are generated automatically with the `__getattr__` dunder method in order to generate instances of :class:`DSLField` + + .. automethod:: __getattr__ """ def __init__( @@ -646,6 +964,13 @@ def __init__( log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": + """Attributes of the DSLType class are generated automatically + with this dunder method in order to generate + instances of :class:`DSLField` + + :return: :class:`DSLField` instance + :raises AttributeError: if the field name does not exist in the type + """ camel_cased_name = to_camel_case(name) if name in self._type.fields: @@ -665,7 +990,7 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" -class DSLSelectable(ABC): +class DSLSelectable(DSLDirectable): """DSLSelectable is an abstract class which indicates that the subclasses can be used as arguments of the :meth:`select ` method. 
@@ -715,7 +1040,7 @@ def is_valid_field(self, field: DSLSelectable) -> bool: assert isinstance(self, (DSLFragment, DSLInlineFragment)) - if isinstance(field, (DSLFragment, DSLInlineFragment)): + if isinstance(field, (DSLFragment, DSLFragmentSpread, DSLInlineFragment)): return True assert isinstance(field, DSLField) @@ -747,7 +1072,7 @@ def is_valid_field(self, field: DSLSelectable) -> bool: assert isinstance(self, DSLField) - if isinstance(field, (DSLFragment, DSLInlineFragment)): + if isinstance(field, (DSLFragment, DSLFragmentSpread, DSLInlineFragment)): return True assert isinstance(field, DSLField) @@ -837,6 +1162,7 @@ def __init__( log.debug(f"Creating {self!r}") DSLSelector.__init__(self) + DSLDirectable.__init__(self) @property def name(self): @@ -903,6 +1229,17 @@ def select( return self + def directives(self, *directives: DSLDirective) -> "DSLField": + """Add directives to this field.""" + super().directives(*directives) + self.ast_field.directives = self.directives_ast + + return self + + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Field locations.""" + return DirectiveLocation.FIELD in directive.directive_def.locations + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.parent_type.name}" f"::{self.name}>" @@ -941,6 +1278,10 @@ def __init__(self, name: str): super().__init__(name, self.meta_type, field) + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for MetaField locations (same as Field).""" + return DirectiveLocation.FIELD in directive.directive_def.locations + class DSLInlineFragment(DSLSelectable, DSLFragmentSelector): """DSLInlineFragment represents an inline fragment for the DSL code.""" @@ -966,6 +1307,7 @@ def __init__( self.ast_field = InlineFragmentNode(directives=()) DSLSelector.__init__(self, *fields, **fields_with_alias) + DSLDirectable.__init__(self) def select( self, *fields: "DSLSelectable", 
**fields_with_alias: "DSLSelectableWithAlias" @@ -987,6 +1329,15 @@ def on(self, type_condition: DSLType) -> "DSLInlineFragment": ) return self + def directives(self, *directives: DSLDirective) -> "DSLInlineFragment": + """Add directives to this inline fragment. + + Inline fragments support all directive types through auto-validation. + """ + super().directives(*directives) + self.ast_field.directives = self.directives_ast + return self + def __repr__(self) -> str: type_info = "" @@ -997,13 +1348,62 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__}{type_info}>" + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Inline Fragment locations.""" + return DirectiveLocation.INLINE_FRAGMENT in directive.directive_def.locations + + +class DSLFragmentSpread(DSLSelectable): + """Represents a fragment spread (usage) with its own directives. + + This class is created by calling .spread() on a DSLFragment and allows + adding directives specific to the FRAGMENT_SPREAD location. + """ + + ast_field: FragmentSpreadNode + _fragment: "DSLFragment" + + def __init__(self, fragment: "DSLFragment"): + """Initialize a fragment spread from a fragment definition. + + :param fragment: The DSLFragment to create a spread from + """ + self._fragment = fragment + self.ast_field = FragmentSpreadNode( + name=NameNode(value=fragment.name), directives=() + ) + + log.debug(f"Creating fragment spread for {fragment.name}") + + DSLDirectable.__init__(self) + + @property + def name(self) -> str: + """:meta private:""" + return self.ast_field.name.value + + def directives(self, *directives: DSLDirective) -> "DSLFragmentSpread": + """Add directives to this fragment spread. + + Fragment spreads support all directive types through auto-validation. 
+ """ + super().directives(*directives) + self.ast_field.directives = self.directives_ast + return self + + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Fragment Spread locations.""" + return DirectiveLocation.FRAGMENT_SPREAD in directive.directive_def.locations + + def __repr__(self) -> str: + return f"" + class DSLFragment(DSLSelectable, DSLFragmentSelector, DSLExecutable): """DSLFragment represents a named GraphQL fragment for the DSL code.""" _type: Optional[Union[GraphQLObjectType, GraphQLInterfaceType]] ast_field: FragmentSpreadNode - name: str def __init__( self, @@ -1017,24 +1417,32 @@ def __init__( DSLExecutable.__init__(self) - self.name = name + self.ast_field = FragmentSpreadNode(name=NameNode(value=name), directives=()) + self._type = None log.debug(f"Creating {self!r}") - @property # type: ignore - def ast_field(self) -> FragmentSpreadNode: # type: ignore - """ast_field property will generate a FragmentSpreadNode with the - provided name. + @property + def name(self) -> str: + """:meta private:""" + return self.ast_field.name.value - Note: We need to ignore the type because of - `issue #4125 of mypy `_. - """ + @name.setter + def name(self, value: str) -> None: + """:meta private:""" + if hasattr(self, "ast_field"): + self.ast_field.name.value = value - spread_node = FragmentSpreadNode(directives=()) - spread_node.name = NameNode(value=self.name) + def spread(self) -> DSLFragmentSpread: + """Create a fragment spread that can have its own directives. - return spread_node + This allows adding directives specific to the FRAGMENT_SPREAD location, + separate from directives on the fragment definition itself. 
+ + :return: DSLFragmentSpread instance for this fragment + """ + return DSLFragmentSpread(self) def select( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" @@ -1096,7 +1504,13 @@ def executable_ast(self) -> FragmentDefinitionNode: selection_set=self.selection_set, **variable_definition_kwargs, name=NameNode(value=self.name), - directives=(), + directives=self.directives_ast, + ) + + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Fragment Definition locations.""" + return ( + DirectiveLocation.FRAGMENT_DEFINITION in directive.directive_def.locations ) def __repr__(self) -> str: diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 8f1efe99..f14a4ea1 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -2,7 +2,9 @@ from typing import cast from graphql import ( + DirectiveLocation, GraphQLArgument, + GraphQLDirective, GraphQLEnumType, GraphQLEnumValue, GraphQLField, @@ -19,6 +21,7 @@ get_introspection_query, graphql_sync, print_schema, + specified_directives, ) from .fixtures import ( @@ -264,12 +267,125 @@ async def resolve_review(review, _info, **_args): }, ) +query_directive = GraphQLDirective( + name="query", + description="Test directive for QUERY location", + locations=[DirectiveLocation.QUERY], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +field_directive = GraphQLDirective( + name="field", + description="Test directive for FIELD location", + locations=[DirectiveLocation.FIELD], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +fragment_spread_directive = GraphQLDirective( + name="fragmentSpread", + description="Test directive for FRAGMENT_SPREAD location", + locations=[DirectiveLocation.FRAGMENT_SPREAD], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + 
}, +) + +inline_fragment_directive = GraphQLDirective( + name="inlineFragment", + description="Test directive for INLINE_FRAGMENT location", + locations=[DirectiveLocation.INLINE_FRAGMENT], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +fragment_definition_directive = GraphQLDirective( + name="fragmentDefinition", + description="Test directive for FRAGMENT_DEFINITION location", + locations=[DirectiveLocation.FRAGMENT_DEFINITION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +mutation_directive = GraphQLDirective( + name="mutation", + description="Test directive for MUTATION location (tests keyword conflict)", + locations=[DirectiveLocation.MUTATION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +subscription_directive = GraphQLDirective( + name="subscription", + description="Test directive for SUBSCRIPTION location", + locations=[DirectiveLocation.SUBSCRIPTION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +variable_definition_directive = GraphQLDirective( + name="variableDefinition", + description="Test directive for VARIABLE_DEFINITION location", + locations=[DirectiveLocation.VARIABLE_DEFINITION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +repeat_directive = GraphQLDirective( + name="repeat", + description="Test repeatable directive for FIELD location", + locations=[DirectiveLocation.FIELD], + args={ + "value": GraphQLArgument( + GraphQLString, + description="A string value for the repeatable directive", + ) + }, + is_repeatable=True, +) + StarWarsSchema = GraphQLSchema( query=query_type, mutation=mutation_type, subscription=subscription_type, types=[human_type, droid_type, review_type, review_input_type], + directives=[ 
+ *specified_directives, + query_directive, + field_directive, + fragment_spread_directive, + inline_fragment_directive, + fragment_definition_directive, + mutation_directive, + subscription_directive, + variable_definition_directive, + repeat_directive, + ], ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index e47a97d8..a3d1ef8c 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -23,7 +23,9 @@ from gql import Client, gql from gql.dsl import ( + DSLField, DSLFragment, + DSLFragmentSpread, DSLInlineFragment, DSLMetaField, DSLMutation, @@ -47,6 +49,12 @@ def ds(): return DSLSchema(StarWarsSchema) +@pytest.fixture +def var(): + """Common DSLVariableDefinitions fixture for directive tests""" + return DSLVariableDefinitions() + + @pytest.fixture def client(): return Client(schema=StarWarsSchema) @@ -659,7 +667,23 @@ def test_fragments_repr(ds): assert repr(DSLInlineFragment()) == "" assert repr(DSLInlineFragment().on(ds.Droid)) == "" assert repr(DSLFragment("fragment_1")) == "" + assert repr(DSLFragment("fragment_1").spread()) == "" assert repr(DSLFragment("fragment_2").on(ds.Droid)) == "" + assert ( + repr(DSLFragment("fragment_2").on(ds.Droid).spread()) + == "" + ) + + +def test_fragment_spread_instances(ds): + """Test that each .spread() creates new DSLFragmentSpread instance""" + fragment = DSLFragment("Test").on(ds.Character).select(ds.Character.name) + spread1 = fragment.spread() + spread2 = fragment.spread() + + assert isinstance(spread1, DSLFragmentSpread) + assert isinstance(spread2, DSLFragmentSpread) + assert spread1 is not spread2 def test_fragments(ds): @@ -1271,3 +1295,271 @@ def test_legacy_fragment_with_variables(ds): } """.strip() assert print_ast(query.document) == expected + + +def test_dsl_schema_call_validation(ds): + with pytest.raises(ValueError, match="(?i)unsupported shortcut"): + ds("foo") + + +def test_executable_directives(ds, var): + """Test ALL executable directive locations and types in one 
document""" + + # Fragment with both built-in and custom directives + fragment = ( + DSLFragment("CharacterInfo") + .on(ds.Character) + .select(ds.Character.name, ds.Character.appearsIn) + .directives(ds("@fragmentDefinition")) + ) + + # Query with multiple directive types + query = DSLQuery( + ds.Query.hero.args(episode=var.episode).select( + # Field with both built-in and custom directives + ds.Character.name.directives( + ds("@skip")(**{"if": var.skipName}), + ds("@field"), # custom field directive + ), + # Field with repeated directives (same directive multiple times) + ds.Character.appearsIn.directives( + ds("@repeat")(value="first"), + ds("@repeat")(value="second"), + ds("@repeat")(value="third"), + ), + # Fragment spread with multiple directives + fragment.spread().directives( + ds("@include")(**{"if": var.includeSpread}), + ds("@fragmentSpread"), + ), + # Inline fragment with directives + DSLInlineFragment() + .on(ds.Human) + .select(ds.Human.homePlanet) + .directives( + ds("@skip")(**{"if": var.skipInline}), + ds("@inlineFragment"), + ), + # Meta field with directive + DSLMetaField("__typename").directives( + ds("@include")(**{"if": var.includeType}) + ), + ) + ).directives(ds("@query")) + + # Mutation with directives + mutation = DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "Great!"} + ).select(ds.Review.stars, ds.Review.commentary) + ).directives(ds("@mutation")) + + # Subscription with directives + subscription = DSLSubscription( + ds.Subscription.reviewAdded.args(episode=6).select( + ds.Review.stars, ds.Review.commentary + ) + ).directives(ds("@subscription")) + + # Variable definitions with directives + var.episode.directives( + # Note that `$episode: Episode @someDirective(value=$someValue)` + # is INVALID GraphQL because variable definitions must be literal values + ds("@variableDefinition"), + ) + query.variable_definitions = var + + # Generate ONE document with everything + doc = dsl_gql( + fragment, 
HeroQuery=query, CreateReview=mutation, ReviewSub=subscription + ) + + expected = """\ +fragment CharacterInfo on Character @fragmentDefinition { + name + appearsIn +} + +query HeroQuery(\ +$episode: Episode @variableDefinition, \ +$skipName: Boolean!, \ +$includeSpread: Boolean!, \ +$skipInline: Boolean!, \ +$includeType: Boolean!\ +) @query { + hero(episode: $episode) { + name @skip(if: $skipName) @field + appearsIn @repeat(value: "first") @repeat(value: "second") @repeat(value: "third") + ...CharacterInfo @include(if: $includeSpread) @fragmentSpread + ... on Human @skip(if: $skipInline) @inlineFragment { + homePlanet + } + __typename @include(if: $includeType) + } +} + +mutation CreateReview @mutation { + createReview(episode: JEDI, review: {stars: 5, commentary: "Great!"}) { + stars + commentary + } +} + +subscription ReviewSub @subscription { + reviewAdded(episode: JEDI) { + stars + commentary + } +}""" + + assert strip_braces_spaces(print_ast(doc.document)) == expected + assert node_tree(doc.document) == node_tree(gql(expected).document) + + +def test_directive_repr(ds): + """Test DSLDirective string representation""" + directive = ds("@include")(**{"if": True}) + expected = "" + assert repr(directive) == expected + + +def test_directive_error_handling(ds): + """Test error handling for directives""" + # Invalid directive argument type + with pytest.raises(TypeError, match="Expected DSLDirective"): + ds.Query.hero.directives(123) + + # Invalid directive name from `__call__ + with pytest.raises(GraphQLError, match="Directive '@nonexistent' not found"): + ds("@nonexistent") + + # Invalid directive argument + with pytest.raises(GraphQLError, match="Argument 'invalid' does not exist"): + ds("@include")(invalid=True) + + # Tried to set arguments twice + with pytest.raises( + AttributeError, match="Arguments for directive @field already set." 
+ ): + ds("@field").args(value="foo").args(value="bar") + + with pytest.raises( + GraphQLError, + match="(?i)Directive '@deprecated' is not a valid request executable directive", + ): + ds("@deprecated") + + with pytest.raises(GraphQLError, match="unexpected variable"): + # variable definitions must be static, literal values defined in the query! + var = DSLVariableDefinitions() + query = DSLQuery( + ds.Query.hero.args(episode=var.episode).select(ds.Character.name) + ) + var.episode.directives( + ds("@variableDefinition").args(value=var.nonStatic), + ) + query.variable_definitions = var + _ = dsl_gql(query).document + + +# Parametrized tests for comprehensive directive location validation +@pytest.fixture( + params=[ + "@query", + "@mutation", + "@subscription", + "@field", + "@fragmentDefinition", + "@fragmentSpread", + "@inlineFragment", + "@variableDefinition", + ] +) +def directive_name(request): + return request.param + + +@pytest.fixture( + params=[ + (DSLQuery, "QUERY"), + (DSLMutation, "MUTATION"), + (DSLSubscription, "SUBSCRIPTION"), + (DSLField, "FIELD"), + (DSLMetaField, "FIELD"), + (DSLFragment, "FRAGMENT_DEFINITION"), + (DSLFragmentSpread, "FRAGMENT_SPREAD"), + (DSLInlineFragment, "INLINE_FRAGMENT"), + (DSLVariable, "VARIABLE_DEFINITION"), + ] +) +def dsl_class_and_location(request): + return request.param + + +@pytest.fixture +def is_valid_combination(directive_name, dsl_class_and_location): + # Map directive names to their expected locations + directive_to_location = { + "@query": "QUERY", + "@mutation": "MUTATION", + "@subscription": "SUBSCRIPTION", + "@field": "FIELD", + "@fragmentDefinition": "FRAGMENT_DEFINITION", + "@fragmentSpread": "FRAGMENT_SPREAD", + "@inlineFragment": "INLINE_FRAGMENT", + "@variableDefinition": "VARIABLE_DEFINITION", + } + expected_location = directive_to_location[directive_name] + _, actual_location = dsl_class_and_location + return expected_location == actual_location + + +def create_dsl_instance(dsl_class, ds): + 
"""Helper function to create DSL instances for testing""" + if dsl_class == DSLQuery: + return DSLQuery(ds.Query.hero.select(ds.Character.name)) + elif dsl_class == DSLMutation: + return DSLMutation( + ds.Mutation.createReview.args(episode=6, review={"stars": 5}).select( + ds.Review.stars + ) + ) + elif dsl_class == DSLSubscription: + return DSLSubscription( + ds.Subscription.reviewAdded.args(episode=6).select(ds.Review.stars) + ) + elif dsl_class == DSLField: + return ds.Query.hero + elif dsl_class == DSLMetaField: + return DSLMetaField("__typename") + elif dsl_class == DSLFragment: + return DSLFragment("test").on(ds.Character).select(ds.Character.name) + elif dsl_class == DSLFragmentSpread: + fragment = DSLFragment("test").on(ds.Character).select(ds.Character.name) + return fragment.spread() + elif dsl_class == DSLInlineFragment: + return DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet) + elif dsl_class == DSLVariable: + var = DSLVariableDefinitions() + return var.testVar + else: + raise ValueError(f"Unknown DSL class: {dsl_class}") + + +def test_directive_location_validation( + ds, directive_name, dsl_class_and_location, is_valid_combination +): + """Test all 64 combinations of 8 directives × 8 DSL classes""" + dsl_class, _ = dsl_class_and_location + directive = ds(directive_name) + + # Create instance of DSL class and try to apply directive + instance = create_dsl_instance(dsl_class, ds) + + if is_valid_combination: + # Should work without error + instance.directives(directive) + else: + # Should raise GraphQLError for invalid location + with pytest.raises(GraphQLError, match="Invalid directive location"): + instance.directives(directive) From 49de08497b87c39cde519052073c8fb9b6ec4195 Mon Sep 17 00:00:00 2001 From: Katherine Baker <43652476+kasbaker@users.noreply.github.com> Date: Fri, 5 Sep 2025 00:57:09 -0700 Subject: [PATCH 231/239] Feature dsl schema shortcuts (#566) --- docs/advanced/dsl_module.rst | 17 +++++++++++++++++ gql/dsl.py | 35 
+++++++++++++++++------------------ tests/starwars/test_dsl.py | 17 +++++++++++++++++ 3 files changed, 51 insertions(+), 18 deletions(-) diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index c6ee035a..e30655b5 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -373,6 +373,14 @@ this can be written in a concise manner:: DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet) ) +Alternatively, you can use the DSL shortcut syntax to create an inline fragment by +passing the string ``"..."`` directly to the :meth:`__call__ ` method:: + + query_with_inline_fragment = ds.Query.hero.args(episode=6).select( + ds.Character.name, + ds("...").on(ds.Human).select(ds.Human.homePlanet) + ) + Meta-fields ^^^^^^^^^^^ @@ -384,6 +392,15 @@ you can use the :class:`DSLMetaField ` class:: DSLMetaField("__typename") ) +Alternatively, you can use the DSL shortcut syntax to create the same meta-field by +passing the ``"__typename"`` string directly to the :meth:`__call__ ` method:: + + query = ds.Query.hero.select( + ds.Character.name, + ds("__typename") + ) + + Directives ^^^^^^^^^^ diff --git a/gql/dsl.py b/gql/dsl.py index da4cf64c..2e6d3967 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -323,36 +323,35 @@ def __call__( self, shortcut: Literal["..."] ) -> "DSLInlineFragment": ... # pragma: no cover - @overload - def __call__( - self, shortcut: Literal["fragment"], name: str - ) -> "DSLFragment": ... # pragma: no cover - @overload def __call__(self, shortcut: Any) -> "DSLDirective": ... # pragma: no cover def __call__( - self, shortcut: str, name: Optional[str] = None - ) -> Union["DSLMetaField", "DSLInlineFragment", "DSLFragment", "DSLDirective"]: - """Factory method for creating DSL objects. + self, shortcut: str + ) -> Union["DSLMetaField", "DSLInlineFragment", "DSLDirective"]: + """Factory method for creating DSL objects from a shortcut string. - Currently, supports creating DSLDirective instances when name starts with '@'. 
- Future support planned for meta-fields (__typename), inline fragments (...), - and fragment definitions (fragment). + The shortcut determines which DSL object is created: - :param shortcut: the name of the object to create + * "__typename", "__schema", "__type" -> :class:`DSLMetaField` + * "..." -> :class:`DSLInlineFragment` + * "@" -> :class:`DSLDirective` + + :param shortcut: The shortcut string identifying the DSL object. :type shortcut: str - :return: :class:`DSLDirective` instance + :return: A DSL object corresponding to the given shortcut. + :rtype: DSLMetaField | DSLInlineFragment | DSLDirective - :raises ValueError: if shortcut format is not supported + :raises ValueError: If the shortcut is not recognized. """ + + if shortcut in ("__typename", "__schema", "__type"): + return DSLMetaField(name=shortcut) + if shortcut == "...": + return DSLInlineFragment() if shortcut.startswith("@"): return DSLDirective(name=shortcut[1:], dsl_schema=self) - # Future support: - # if name.startswith("__"): return DSLMetaField(name) - # if name == "...": return DSLInlineFragment() - # if name.startswith("fragment "): return DSLFragment(name[9:]) raise ValueError(f"Unsupported shortcut: {shortcut}") diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index a3d1ef8c..7f042a07 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -23,6 +23,7 @@ from gql import Client, gql from gql.dsl import ( + DSLDirective, DSLField, DSLFragment, DSLFragmentSpread, @@ -1297,6 +1298,22 @@ def test_legacy_fragment_with_variables(ds): assert print_ast(query.document) == expected +@pytest.mark.parametrize( + "shortcut,expected", + [ + ("__typename", DSLMetaField("__typename")), + ("__schema", DSLMetaField("__schema")), + ("__type", DSLMetaField("__type")), + ("...", DSLInlineFragment()), + ("@skip", DSLDirective(name="skip", dsl_schema=DSLSchema(StarWarsSchema))), + ], +) +def test_dsl_schema_call_shortcuts(ds, shortcut, expected): + actual = ds(shortcut) + 
assert getattr(actual, "name", None) == getattr(expected, "name", None) + assert isinstance(actual, type(expected)) + + def test_dsl_schema_call_validation(ds): with pytest.raises(ValueError, match="(?i)unsupported shortcut"): ds("foo") From 316967026883f122e8f19aa8da187fc132eccd20 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 5 Sep 2025 16:06:07 +0200 Subject: [PATCH 232/239] Bump version number to 4.2.0b0 --- gql/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/__version__.py b/gql/__version__.py index b672be1c..dcb01d6e 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "4.1.0b0" +__version__ = "4.2.0b0" From 4d48c2146979db559e489c9fa55da48367ab75f4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 5 Sep 2025 23:01:22 +0200 Subject: [PATCH 233/239] Fix python 3.9 dev dependencies (#568) --- gql/transport/aiohttp.py | 2 +- gql/transport/httpx.py | 2 +- gql/transport/requests.py | 2 +- setup.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index e3bfdb3b..ab26bd03 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -173,7 +173,7 @@ def _prepare_request( upload_files: bool = False, ) -> Dict[str, Any]: - payload: Dict | List + payload: Union[Dict, List] if isinstance(request, GraphQLRequest): payload = request.payload else: diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 0a338639..7143f263 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -66,7 +66,7 @@ def _prepare_request( upload_files: bool = False, ) -> Dict[str, Any]: - payload: Dict | List + payload: Union[Dict, List] if isinstance(request, GraphQLRequest): payload = request.payload else: diff --git a/gql/transport/requests.py b/gql/transport/requests.py index a29f7f0f..8311c036 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -147,7 +147,7 @@ def _prepare_request( 
upload_files: bool = False, ) -> Dict[str, Any]: - payload: Dict | List + payload: Union[Dict, List] if isinstance(request, GraphQLRequest): payload = request.payload else: diff --git a/setup.py b/setup.py index 3db1c9f8..58d3387c 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,8 @@ "sphinx>=7.0.0,<8;python_version<='3.9'", "sphinx>=8.1.0,<9;python_version>'3.9'", "sphinx_rtd_theme>=3.0.2,<4", - "sphinx-argparse==0.5.2", + "sphinx-argparse==0.5.2; python_version>='3.10'", + "sphinx-argparse==0.4.0; python_version<'3.10'", "types-aiofiles", "types-requests", ] + tests_requires From a3a4597ca56a84c8c6115b317ab8ed5fad20fdf7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 5 Sep 2025 23:32:29 +0200 Subject: [PATCH 234/239] Refactor DSL code to use Self type (#567) --- gql/dsl.py | 450 +++++++++++++++++++++++++++-------------------------- setup.py | 1 + 2 files changed, 229 insertions(+), 222 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 2e6d3967..2dd98d31 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -5,6 +5,7 @@ import logging import re +import sys from abc import ABC, abstractmethod from math import isfinite from typing import ( @@ -83,6 +84,11 @@ from .graphql_request import GraphQLRequest from .utils import to_camel_case +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + log = logging.getLogger(__name__) _re_integer_string = re.compile("^-?(?:0|[1-9][0-9]*)$") @@ -230,61 +236,6 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: raise TypeError(f"Unexpected input type: {inspect(type_)}.") -def dsl_gql( - *operations: "DSLExecutable", **operations_with_name: "DSLExecutable" -) -> GraphQLRequest: - r"""Given arguments instances of :class:`DSLExecutable` - containing GraphQL operations or fragments, - generate a Document which can be executed later in a - gql client or a gql session. 
- - Similar to the :func:`gql.gql` function but instead of parsing a python - string to describe the request, we are using operations which have been generated - dynamically using instances of :class:`DSLField`, generated - by instances of :class:`DSLType` which themselves originated from - a :class:`DSLSchema` class. - - :param \*operations: the GraphQL operations and fragments - :type \*operations: DSLQuery, DSLMutation, DSLSubscription, DSLFragment - :param \**operations_with_name: the GraphQL operations with an operation name - :type \**operations_with_name: DSLQuery, DSLMutation, DSLSubscription - - :return: a :class:`GraphQLRequest ` - which can be later executed or subscribed by a - :class:`Client `, by an - :class:`async session ` or by a - :class:`sync session ` - - :raises TypeError: if an argument is not an instance of :class:`DSLExecutable` - :raises AttributeError: if a type has not been provided in a :class:`DSLFragment` - """ - - # Concatenate operations without and with name - all_operations: Tuple["DSLExecutable", ...] = ( - *operations, - *(operation for operation in operations_with_name.values()), - ) - - # Set the operation name - for name, operation in operations_with_name.items(): - operation.name = name - - # Check the type - for operation in all_operations: - if not isinstance(operation, DSLExecutable): - raise TypeError( - "Operations should be instances of DSLExecutable " - "(DSLQuery, DSLMutation, DSLSubscription or DSLFragment).\n" - f"Received: {type(operation)}." - ) - - document = DocumentNode( - definitions=[operation.executable_ast for operation in all_operations] - ) - - return GraphQLRequest(document) - - class DSLSchema: """The DSLSchema is the root of the DSL code. @@ -378,78 +329,6 @@ def __getattr__(self, name: str) -> "DSLType": return DSLType(type_def, self) -class DSLSelector(ABC): - """DSLSelector is an abstract class which defines the - :meth:`select ` method to select - children fields in the query. 
- - Inherited by - :class:`DSLRootFieldSelector `, - :class:`DSLFieldSelector ` - :class:`DSLFragmentSelector ` - """ - - selection_set: SelectionSetNode - - def __init__( - self, - *fields: "DSLSelectable", - **fields_with_alias: "DSLSelectableWithAlias", - ): - """:meta private:""" - self.selection_set = SelectionSetNode(selections=()) - - if fields or fields_with_alias: - self.select(*fields, **fields_with_alias) - - @abstractmethod - def is_valid_field(self, field: "DSLSelectable") -> bool: - raise NotImplementedError( - "Any DSLSelector subclass must have a is_valid_field method" - ) # pragma: no cover - - def select( - self, - *fields: "DSLSelectable", - **fields_with_alias: "DSLSelectableWithAlias", - ) -> Any: - r"""Select the fields which should be added. - - :param \*fields: fields or fragments - :type \*fields: DSLSelectable - :param \**fields_with_alias: fields or fragments with alias as key - :type \**fields_with_alias: DSLSelectable - - :raises TypeError: if an argument is not an instance of :class:`DSLSelectable` - :raises graphql.error.GraphQLError: if an argument is not a valid field - """ - # Concatenate fields without and with alias - added_fields: Tuple["DSLSelectable", ...] = DSLField.get_aliased_fields( - fields, fields_with_alias - ) - - # Check that each field is valid - for field in added_fields: - if not isinstance(field, DSLSelectable): - raise TypeError( - "Fields should be instances of DSLSelectable. " - f"Received: {type(field)}" - ) - - if not self.is_valid_field(field): - raise GraphQLError(f"Invalid field for {self!r}: {field!r}") - - # Get a list of AST Nodes for each added field - added_selections: Tuple[ - Union[FieldNode, InlineFragmentNode, FragmentSpreadNode], ... 
- ] = tuple(field.ast_field for field in added_fields) - - # Update the current selection list with new selections - self.selection_set.selections = self.selection_set.selections + added_selections - - log.debug(f"Added fields: {added_fields} in {self!r}") - - class DSLDirective: """The DSLDirective represents a GraphQL directive for the DSL code. @@ -457,7 +336,7 @@ class DSLDirective: behavior in a GraphQL document. """ - def __init__(self, name: str, dsl_schema: "DSLSchema"): + def __init__(self, name: str, dsl_schema: DSLSchema): r"""Initialize the DSLDirective with the given name and arguments. :param name: the name of the directive @@ -510,7 +389,7 @@ def name(self) -> str: """Get the directive name.""" return self.ast_directive.name.value - def __call__(self, **kwargs: Any) -> "DSLDirective": + def __call__(self, **kwargs: Any) -> Self: """Add arguments by calling the directive like a function. :param kwargs: directive arguments @@ -518,7 +397,7 @@ def __call__(self, **kwargs: Any) -> "DSLDirective": """ return self.args(**kwargs) - def args(self, **kwargs: Any) -> "DSLDirective": + def args(self, **kwargs: Any) -> Self: r"""Set the arguments of a directive The arguments are parsed to be stored in the AST of this field. @@ -584,7 +463,7 @@ def __init__(self, *args, **kwargs): self._directives = () @abstractmethod - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if a directive is valid for this DSL element. :param directive: The DSLDirective to validate @@ -594,7 +473,7 @@ def is_valid_directive(self, directive: "DSLDirective") -> bool: "Any DSLDirectable concrete class must have an is_valid_directive method" ) # pragma: no cover - def directives(self, *directives: DSLDirective) -> Any: + def directives(self, *directives: DSLDirective) -> Self: r"""Add directives to this DSL element. 
:param \*directives: DSLDirective instances to add @@ -661,6 +540,138 @@ def directives_ast(self) -> Tuple[DirectiveNode, ...]: return tuple(directive.ast_directive for directive in self._directives) +class DSLSelectable(DSLDirectable): + """DSLSelectable is an abstract class which indicates that + the subclasses can be used as arguments of the + :meth:`select ` method. + + Inherited by + :class:`DSLField `, + :class:`DSLFragment ` + :class:`DSLInlineFragment ` + """ + + ast_field: Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] + + @staticmethod + def get_aliased_fields( + fields: Iterable["DSLSelectable"], + fields_with_alias: Dict[str, "DSLSelectableWithAlias"], + ) -> Tuple["DSLSelectable", ...]: + """ + :meta private: + + Concatenate all the fields (with or without alias) in a Tuple. + + Set the requested alias for the fields with alias. + """ + + return ( + *fields, + *(field.alias(alias) for alias, field in fields_with_alias.items()), + ) + + def __str__(self) -> str: + return print_ast(self.ast_field) + + +class DSLSelectableWithAlias(DSLSelectable): + """DSLSelectableWithAlias is an abstract class which indicates that + the subclasses can be selected with an alias. + """ + + ast_field: FieldNode + + def alias(self, alias: str) -> Self: + """Set an alias + + .. note:: + You can also pass the alias directly at the + :meth:`select ` method. + :code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to: + :code:`ds.Query.human.select(ds.Character.name.alias("my_name"))` + + :param alias: the alias + :type alias: str + :return: itself + """ + + self.ast_field.alias = NameNode(value=alias) + return self + + +class DSLSelector(ABC): + """DSLSelector is an abstract class which defines the + :meth:`select ` method to select + children fields in the query. 
+ + Inherited by + :class:`DSLRootFieldSelector `, + :class:`DSLFieldSelector ` + :class:`DSLFragmentSelector ` + """ + + selection_set: SelectionSetNode + + def __init__( + self, + *fields: DSLSelectable, + **fields_with_alias: DSLSelectableWithAlias, + ): + """:meta private:""" + self.selection_set = SelectionSetNode(selections=()) + + if fields or fields_with_alias: + self.select(*fields, **fields_with_alias) + + @abstractmethod + def is_valid_field(self, field: DSLSelectable) -> bool: + raise NotImplementedError( + "Any DSLSelector subclass must have a is_valid_field method" + ) # pragma: no cover + + def select( + self, + *fields: DSLSelectable, + **fields_with_alias: DSLSelectableWithAlias, + ) -> Any: + r"""Select the fields which should be added. + + :param \*fields: fields or fragments + :type \*fields: DSLSelectable + :param \**fields_with_alias: fields or fragments with alias as key + :type \**fields_with_alias: DSLSelectable + + :raises TypeError: if an argument is not an instance of :class:`DSLSelectable` + :raises graphql.error.GraphQLError: if an argument is not a valid field + """ + # Concatenate fields without and with alias + added_fields: Tuple[DSLSelectable, ...] = DSLField.get_aliased_fields( + fields, fields_with_alias + ) + + # Check that each field is valid + for field in added_fields: + if not isinstance(field, DSLSelectable): + raise TypeError( + "Fields should be instances of DSLSelectable. " + f"Received: {type(field)}" + ) + + if not self.is_valid_field(field): + raise GraphQLError(f"Invalid field for {self!r}: {field!r}") + + # Get a list of AST Nodes for each added field + added_selections: Tuple[ + Union[FieldNode, InlineFragmentNode, FragmentSpreadNode], ... 
+ ] = tuple(field.ast_field for field in added_fields) + + # Update the current selection list with new selections + self.selection_set.selections = self.selection_set.selections + added_selections + + log.debug(f"Added fields: {added_fields} in {self!r}") + + class DSLExecutable(DSLSelector, DSLDirectable): """Interface for the root elements which can be executed in the :func:`dsl_gql ` function @@ -684,8 +695,8 @@ def executable_ast(self): def __init__( self, - *fields: "DSLSelectable", - **fields_with_alias: "DSLSelectableWithAlias", + *fields: DSLSelectable, + **fields_with_alias: DSLSelectableWithAlias, ): r"""Given arguments of type :class:`DSLSelectable` containing GraphQL requests, generate an operation which can be converted to a Document @@ -722,7 +733,7 @@ class DSLRootFieldSelector(DSLSelector): :class:`DSLOperation ` """ - def is_valid_field(self, field: "DSLSelectable") -> bool: + def is_valid_field(self, field: DSLSelectable) -> bool: """Check that a field is valid for a root field. 
For operations, the fields arguments should be fields of root GraphQL types @@ -799,7 +810,7 @@ def __repr__(self) -> str: class DSLQuery(DSLOperation): operation_type = OperationType.QUERY - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Query operations.""" return DirectiveLocation.QUERY in directive.directive_def.locations @@ -807,7 +818,7 @@ def is_valid_directive(self, directive: "DSLDirective") -> bool: class DSLMutation(DSLOperation): operation_type = OperationType.MUTATION - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Mutation operations.""" return DirectiveLocation.MUTATION in directive.directive_def.locations @@ -815,7 +826,7 @@ def is_valid_directive(self, directive: "DSLDirective") -> bool: class DSLSubscription(DSLOperation): operation_type = OperationType.SUBSCRIPTION - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Subscription operations.""" return DirectiveLocation.SUBSCRIPTION in directive.directive_def.locations @@ -854,16 +865,16 @@ def to_ast_type(self, type_: GraphQLInputType) -> TypeNode: return NamedTypeNode(name=NameNode(value=type_.name)) - def set_type(self, type_: GraphQLInputType) -> "DSLVariable": + def set_type(self, type_: GraphQLInputType) -> Self: self.type = type_ self.ast_variable_type = self.to_ast_type(type_) return self - def default(self, default_value: Any) -> "DSLVariable": + def default(self, default_value: Any) -> Self: self.default_value = default_value return self - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Variable definitions.""" for arg in 
directive.ast_directive.arguments: if isinstance(arg.value, VariableNode): @@ -894,7 +905,7 @@ def __init__(self): """:meta private:""" self.variables: Dict[str, DSLVariable] = {} - def __getattr__(self, name: str) -> "DSLVariable": + def __getattr__(self, name: str) -> DSLVariable: """Attributes of the DSLVariableDefinitions class are generated automatically with this dunder method in order to generate instances of :class:`DSLVariable` @@ -989,41 +1000,6 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" -class DSLSelectable(DSLDirectable): - """DSLSelectable is an abstract class which indicates that - the subclasses can be used as arguments of the - :meth:`select ` method. - - Inherited by - :class:`DSLField `, - :class:`DSLFragment ` - :class:`DSLInlineFragment ` - """ - - ast_field: Union[FieldNode, InlineFragmentNode, FragmentSpreadNode] - - @staticmethod - def get_aliased_fields( - fields: Iterable["DSLSelectable"], - fields_with_alias: Dict[str, "DSLSelectableWithAlias"], - ) -> Tuple["DSLSelectable", ...]: - """ - :meta private: - - Concatenate all the fields (with or without alias) in a Tuple. - - Set the requested alias for the fields with alias. - """ - - return ( - *fields, - *(field.alias(alias) for alias, field in fields_with_alias.items()), - ) - - def __str__(self) -> str: - return print_ast(self.ast_field) - - class DSLFragmentSelector(DSLSelector): """Class used to define the :meth:`is_valid_field ` method @@ -1090,31 +1066,6 @@ def is_valid_field(self, field: DSLSelectable) -> bool: return False -class DSLSelectableWithAlias(DSLSelectable): - """DSLSelectableWithAlias is an abstract class which indicates that - the subclasses can be selected with an alias. - """ - - ast_field: FieldNode - - def alias(self, alias: str) -> "DSLSelectableWithAlias": - """Set an alias - - .. note:: - You can also pass the alias directly at the - :meth:`select ` method. 
- :code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to: - :code:`ds.Query.human.select(ds.Character.name.alias("my_name"))` - - :param alias: the alias - :type alias: str - :return: itself - """ - - self.ast_field.alias = NameNode(value=alias) - return self - - class DSLField(DSLSelectableWithAlias, DSLFieldSelector): """The DSLField represents a GraphQL field for the DSL code. @@ -1168,10 +1119,10 @@ def name(self): """:meta private:""" return self.ast_field.name.value - def __call__(self, **kwargs: Any) -> "DSLField": + def __call__(self, **kwargs: Any) -> Self: return self.args(**kwargs) - def args(self, **kwargs: Any) -> "DSLField": + def args(self, **kwargs: Any) -> Self: r"""Set the arguments of a field The arguments are parsed to be stored in the AST of this field. @@ -1217,8 +1168,8 @@ def _get_argument(self, name: str) -> GraphQLArgument: return arg def select( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" - ) -> "DSLField": + self, *fields: DSLSelectable, **fields_with_alias: DSLSelectableWithAlias + ) -> Self: """Calling :meth:`select ` method with corrected typing hints """ @@ -1228,14 +1179,14 @@ def select( return self - def directives(self, *directives: DSLDirective) -> "DSLField": + def directives(self, *directives: DSLDirective) -> Self: """Add directives to this field.""" super().directives(*directives) self.ast_field.directives = self.directives_ast return self - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Field locations.""" return DirectiveLocation.FIELD in directive.directive_def.locations @@ -1277,7 +1228,7 @@ def __init__(self, name: str): super().__init__(name, self.meta_type, field) - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for MetaField locations 
(same as Field).""" return DirectiveLocation.FIELD in directive.directive_def.locations @@ -1290,8 +1241,8 @@ class DSLInlineFragment(DSLSelectable, DSLFragmentSelector): def __init__( self, - *fields: "DSLSelectable", - **fields_with_alias: "DSLSelectableWithAlias", + *fields: DSLSelectable, + **fields_with_alias: DSLSelectableWithAlias, ): r"""Initialize the DSLInlineFragment. @@ -1309,8 +1260,8 @@ def __init__( DSLDirectable.__init__(self) def select( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" - ) -> "DSLInlineFragment": + self, *fields: DSLSelectable, **fields_with_alias: DSLSelectableWithAlias + ) -> Self: """Calling :meth:`select ` method with corrected typing hints """ @@ -1319,7 +1270,7 @@ def select( return self - def on(self, type_condition: DSLType) -> "DSLInlineFragment": + def on(self, type_condition: DSLType) -> Self: """Provides the GraphQL type of this inline fragment.""" self._type = type_condition._type @@ -1328,7 +1279,7 @@ def on(self, type_condition: DSLType) -> "DSLInlineFragment": ) return self - def directives(self, *directives: DSLDirective) -> "DSLInlineFragment": + def directives(self, *directives: DSLDirective) -> Self: """Add directives to this inline fragment. Inline fragments support all directive types through auto-validation. @@ -1347,7 +1298,7 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__}{type_info}>" - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Inline Fragment locations.""" return DirectiveLocation.INLINE_FRAGMENT in directive.directive_def.locations @@ -1381,7 +1332,7 @@ def name(self) -> str: """:meta private:""" return self.ast_field.name.value - def directives(self, *directives: DSLDirective) -> "DSLFragmentSpread": + def directives(self, *directives: DSLDirective) -> Self: """Add directives to this fragment spread. 
Fragment spreads support all directive types through auto-validation. @@ -1390,7 +1341,7 @@ def directives(self, *directives: DSLDirective) -> "DSLFragmentSpread": self.ast_field.directives = self.directives_ast return self - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Fragment Spread locations.""" return DirectiveLocation.FRAGMENT_SPREAD in directive.directive_def.locations @@ -1444,8 +1395,8 @@ def spread(self) -> DSLFragmentSpread: return DSLFragmentSpread(self) def select( - self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" - ) -> "DSLFragment": + self, *fields: DSLSelectable, **fields_with_alias: DSLSelectableWithAlias + ) -> Self: """Calling :meth:`select ` method with corrected typing hints """ @@ -1458,7 +1409,7 @@ def select( return self - def on(self, type_condition: DSLType) -> "DSLFragment": + def on(self, type_condition: DSLType) -> Self: """Provides the GraphQL type of this fragment. :param type_condition: the provided type @@ -1506,7 +1457,7 @@ def executable_ast(self) -> FragmentDefinitionNode: directives=self.directives_ast, ) - def is_valid_directive(self, directive: "DSLDirective") -> bool: + def is_valid_directive(self, directive: DSLDirective) -> bool: """Check if directive is valid for Fragment Definition locations.""" return ( DirectiveLocation.FRAGMENT_DEFINITION in directive.directive_def.locations @@ -1514,3 +1465,58 @@ def is_valid_directive(self, directive: "DSLDirective") -> bool: def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.name!s}>" + + +def dsl_gql( + *operations: DSLExecutable, **operations_with_name: DSLExecutable +) -> GraphQLRequest: + r"""Given arguments instances of :class:`DSLExecutable` + containing GraphQL operations or fragments, + generate a Document which can be executed later in a + gql client or a gql session. 
+ + Similar to the :func:`gql.gql` function but instead of parsing a python + string to describe the request, we are using operations which have been generated + dynamically using instances of :class:`DSLField`, generated + by instances of :class:`DSLType` which themselves originated from + a :class:`DSLSchema` class. + + :param \*operations: the GraphQL operations and fragments + :type \*operations: DSLQuery, DSLMutation, DSLSubscription, DSLFragment + :param \**operations_with_name: the GraphQL operations with an operation name + :type \**operations_with_name: DSLQuery, DSLMutation, DSLSubscription + + :return: a :class:`GraphQLRequest ` + which can be later executed or subscribed by a + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` + + :raises TypeError: if an argument is not an instance of :class:`DSLExecutable` + :raises AttributeError: if a type has not been provided in a :class:`DSLFragment` + """ + + # Concatenate operations without and with name + all_operations: Tuple[DSLExecutable, ...] = ( + *operations, + *(operation for operation in operations_with_name.values()), + ) + + # Set the operation name + for name, operation in operations_with_name.items(): + operation.name = name + + # Check the type + for operation in all_operations: + if not isinstance(operation, DSLExecutable): + raise TypeError( + "Operations should be instances of DSLExecutable " + "(DSLQuery, DSLMutation, DSLSubscription or DSLFragment).\n" + f"Received: {type(operation)}." 
+ ) + + document = DocumentNode( + definitions=[operation.executable_ast for operation in all_operations] + ) + + return GraphQLRequest(document) diff --git a/setup.py b/setup.py index 58d3387c..39a6e453 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ "yarl>=1.6,<2.0", "backoff>=1.11.1,<3.0", "anyio>=3.0,<5", + "typing_extensions>=4.0.0; python_version<'3.11'", ] console_scripts = [ From ccf3f2ab2f91e4764f8207bc8bcf72c5800c7213 Mon Sep 17 00:00:00 2001 From: Evgeniy Martynenko <79136602+Enimalojd@users.noreply.github.com> Date: Wed, 12 Nov 2025 23:49:10 +0300 Subject: [PATCH 235/239] chore: Replace archived backoff with tenacity (#573) --- docs/advanced/async_advanced_usage.rst | 24 ++++-- docs/advanced/async_permanent_session.rst | 77 +++++++++++++------ .../reconnecting_mutation_http.py | 10 +-- .../code_examples/reconnecting_mutation_ws.py | 10 +-- gql/client.py | 36 +++++---- setup.py | 2 +- 6 files changed, 102 insertions(+), 57 deletions(-) diff --git a/docs/advanced/async_advanced_usage.rst b/docs/advanced/async_advanced_usage.rst index 4164cb37..78952d0f 100644 --- a/docs/advanced/async_advanced_usage.rst +++ b/docs/advanced/async_advanced_usage.rst @@ -6,7 +6,7 @@ Async advanced usage It is possible to send multiple GraphQL queries (query, mutation or subscription) in parallel, on the same websocket connection, using asyncio tasks. -In order to retry in case of connection failure, we can use the great `backoff`_ module. +In order to retry in case of connection failure, we can use the great `tenacity`_ module. .. code-block:: python @@ -28,10 +28,22 @@ In order to retry in case of connection failure, we can use the great `backoff`_ async for result in session.subscribe(subscription2): print(result) - # Then create a couroutine which will connect to your API and run all your queries as tasks. - # We use a `backoff` decorator to reconnect using exponential backoff in case of connection failure. 
- - @backoff.on_exception(backoff.expo, Exception, max_time=300) + # Then create a couroutine which will connect to your API and run all your + # queries as tasks. We use a `tenacity` retry decorator to reconnect using + # exponential backoff in case of connection failure. + + from tenacity import ( + retry, + retry_if_exception_type, + stop_after_delay, + wait_exponential, + ) + + @retry( + retry=retry_if_exception_type(Exception), + stop=stop_after_delay(300), # max_time in seconds + wait=wait_exponential(), + ) async def graphql_connection(): transport = WebsocketsTransport(url="wss://YOUR_URL") @@ -54,4 +66,4 @@ Subscriptions tasks can be stopped at any time by running task.cancel() -.. _backoff: https://github.com/litl/backoff +.. _tenacity: https://github.com/jd/tenacity diff --git a/docs/advanced/async_permanent_session.rst b/docs/advanced/async_permanent_session.rst index e42010cf..885d2fd2 100644 --- a/docs/advanced/async_permanent_session.rst +++ b/docs/advanced/async_permanent_session.rst @@ -36,19 +36,22 @@ Retries Connection retries ^^^^^^^^^^^^^^^^^^ -With :code:`reconnecting=True`, gql will use the `backoff`_ module to repeatedly try to connect with -exponential backoff and jitter with a maximum delay of 60 seconds by default. +With :code:`reconnecting=True`, gql will use the `tenacity`_ module to repeatedly +try to connect with exponential backoff and jitter with a maximum delay of +60 seconds by default. You can change the default reconnecting profile by providing your own -backoff decorator to the :code:`retry_connect` argument. +retry decorator (from tenacity) to the :code:`retry_connect` argument. .. 
code-block:: python + from tenacity import retry, retry_if_exception_type, wait_exponential + # Here wait maximum 5 minutes between connection retries - retry_connect = backoff.on_exception( - backoff.expo, # wait generator (here: exponential backoff) - Exception, # which exceptions should cause a retry (here: everything) - max_value=300, # max wait time in seconds + retry_connect = retry( + # which exceptions should cause a retry (here: everything) + retry=retry_if_exception_type(Exception), + wait=wait_exponential(max=300), # max wait time in seconds ) session = await client.connect_async( reconnecting=True, @@ -66,32 +69,49 @@ There is no retry in case of a :code:`TransportQueryError` exception as it indic the connection to the backend is working correctly. You can change the default execute retry profile by providing your own -backoff decorator to the :code:`retry_execute` argument. +retry decorator (from tenacity) to the :code:`retry_execute` argument. .. code-block:: python + from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, + ) + # Here Only 3 tries for execute calls - retry_execute = backoff.on_exception( - backoff.expo, - Exception, - max_tries=3, + retry_execute = retry( + retry=retry_if_exception_type(Exception), + stop=stop_after_attempt(3), + wait=wait_exponential(), ) session = await client.connect_async( reconnecting=True, retry_execute=retry_execute, ) -If you don't want any retry on the execute calls, you can disable the retries with :code:`retry_execute=False` +If you don't want any retry on the execute calls, you can disable the retries +with :code:`retry_execute=False` .. note:: If you want to retry even with :code:`TransportQueryError` exceptions, - then you need to make your own backoff decorator on your own method: + then you need to make your own retry decorator (from tenacity) on your own method: .. 
code-block:: python - @backoff.on_exception(backoff.expo, - Exception, - max_tries=3) + from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, + ) + + @retry( + retry=retry_if_exception_type(Exception), + stop=stop_after_attempt(3), + wait=wait_exponential(), + ) async def execute_with_retry(session, query): return await session.execute(query) @@ -100,14 +120,25 @@ Subscription retries There is no :code:`retry_subscribe` as it is not feasible with async generators. If you want retries for your subscriptions, then you can do it yourself -with backoff decorators on your methods. +with retry decorators (from tenacity) on your methods. .. code-block:: python - @backoff.on_exception(backoff.expo, - Exception, - max_tries=3, - giveup=lambda e: isinstance(e, TransportQueryError)) + from tenacity import ( + retry, + retry_if_exception_type, + retry_unless_exception_type, + stop_after_attempt, + wait_exponential, + ) + from gql.transport.exceptions import TransportQueryError + + @retry( + retry=retry_if_exception_type(Exception) + & retry_unless_exception_type(TransportQueryError), + stop=stop_after_attempt(3), + wait=wait_exponential(), + ) async def execute_subscription1(session): async for result in session.subscribe(subscription1): print(result) @@ -123,4 +154,4 @@ Console example .. literalinclude:: ../code_examples/console_async.py .. _difficult to manage: https://github.com/graphql-python/gql/issues/179 -.. _backoff: https://github.com/litl/backoff +.. 
_tenacity: https://github.com/jd/tenacity diff --git a/docs/code_examples/reconnecting_mutation_http.py b/docs/code_examples/reconnecting_mutation_http.py index 5deb5063..1eaf0111 100644 --- a/docs/code_examples/reconnecting_mutation_http.py +++ b/docs/code_examples/reconnecting_mutation_http.py @@ -1,7 +1,7 @@ import asyncio import logging -import backoff +from tenacity import retry, retry_if_exception_type, wait_exponential from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport @@ -17,11 +17,9 @@ async def main(): client = Client(transport=transport) - retry_connect = backoff.on_exception( - backoff.expo, - Exception, - max_value=10, - jitter=None, + retry_connect = retry( + retry=retry_if_exception_type(Exception), + wait=wait_exponential(max=10), ) session = await client.connect_async(reconnecting=True, retry_connect=retry_connect) diff --git a/docs/code_examples/reconnecting_mutation_ws.py b/docs/code_examples/reconnecting_mutation_ws.py index d7e7cfe2..4d083d54 100644 --- a/docs/code_examples/reconnecting_mutation_ws.py +++ b/docs/code_examples/reconnecting_mutation_ws.py @@ -1,7 +1,7 @@ import asyncio import logging -import backoff +from tenacity import retry, retry_if_exception_type, wait_exponential from gql import Client, gql from gql.transport.websockets import WebsocketsTransport @@ -17,11 +17,9 @@ async def main(): client = Client(transport=transport) - retry_connect = backoff.on_exception( - backoff.expo, - Exception, - max_value=10, - jitter=None, + retry_connect = retry( + retry=retry_if_exception_type(Exception), + wait=wait_exponential(max=10), ) session = await client.connect_async(reconnecting=True, retry_connect=retry_connect) diff --git a/gql/client.py b/gql/client.py index e17a0b7c..93c1078c 100644 --- a/gql/client.py +++ b/gql/client.py @@ -21,7 +21,6 @@ overload, ) -import backoff from anyio import fail_after from graphql import ( ExecutionResult, @@ -31,6 +30,13 @@ parse, 
validate, ) +from tenacity import ( + retry, + retry_if_exception_type, + retry_unless_exception_type, + stop_after_attempt, + wait_exponential, +) from .graphql_request import GraphQLRequest, support_deprecated_request from .transport.async_transport import AsyncTransport @@ -1902,11 +1908,12 @@ def __init__( """ :param client: the :class:`client ` used. :param retry_connect: Either a Boolean to activate/deactivate the retries - for the connection to the transport OR a backoff decorator to - provide specific retries parameters for the connections. + for the connection to the transport OR a retry decorator + (e.g., from tenacity) to provide specific retries parameters + for the connections. :param retry_execute: Either a Boolean to activate/deactivate the retries - for the execute method OR a backoff decorator to - provide specific retries parameters for this method. + for the execute method OR a retry decorator (e.g., from tenacity) + to provide specific retries parameters for this method. 
""" self.client = client self._connect_task = None @@ -1917,10 +1924,9 @@ def __init__( if retry_connect is True: # By default, retry again and again, with maximum 60 seconds # between retries - self.retry_connect = backoff.on_exception( - backoff.expo, - Exception, - max_value=60, + self.retry_connect = retry( + retry=retry_if_exception_type(Exception), + wait=wait_exponential(max=60), ) elif retry_connect is False: self.retry_connect = lambda e: e @@ -1930,11 +1936,11 @@ def __init__( if retry_execute is True: # By default, retry 5 times, except if we receive a TransportQueryError - self.retry_execute = backoff.on_exception( - backoff.expo, - Exception, - max_tries=5, - giveup=lambda e: isinstance(e, TransportQueryError), + self.retry_execute = retry( + retry=retry_if_exception_type(Exception) + & retry_unless_exception_type(TransportQueryError), + stop=stop_after_attempt(5), + wait=wait_exponential(), ) elif retry_execute is False: self.retry_execute = lambda e: e @@ -1943,7 +1949,7 @@ def __init__( self.retry_execute = retry_execute # Creating the _execute_with_retries and _connect_with_retries methods - # using the provided backoff decorators + # using the provided retry decorators self._execute_with_retries = self.retry_execute(self._execute_once) self._connect_with_retries = self.retry_connect(self.transport.connect) diff --git a/setup.py b/setup.py index 39a6e453..e0764d5d 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ install_requires = [ "graphql-core>=3.3.0a3,<3.4", "yarl>=1.6,<2.0", - "backoff>=1.11.1,<3.0", + "tenacity>=9.1.2,<10.0", "anyio>=3.0,<5", "typing_extensions>=4.0.0; python_version<'3.11'", ] From 70be65826a319c34a41f99c70050dd3667104e1c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 12 Nov 2025 22:01:25 +0100 Subject: [PATCH 236/239] show @oneOf in introspection query (#569) --- gql/cli.py | 5 +- gql/utilities/get_introspection_query_ast.py | 9 ++++ setup.py | 1 + tests/starwars/test_dsl.py | 48 ++++++++++++++++++++ 
tests/test_cli.py | 2 + 5 files changed, 64 insertions(+), 1 deletion(-) diff --git a/gql/cli.py b/gql/cli.py index 37be3656..01dfb20f 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -140,7 +140,9 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: - input_value_deprecation:false to omit deprecated input fields - specified_by_url:true - schema_description:true - - directive_is_repeatable:true""" + - directive_is_repeatable:true + - input_object_one_of:true + """ ), dest="schema_download", ) @@ -430,6 +432,7 @@ def get_introspection_args(args: Namespace) -> Dict: "directive_is_repeatable", "schema_description", "input_value_deprecation", + "input_object_one_of", ] if args.schema_download is not None: diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index 0422a225..8d981600 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -11,6 +11,8 @@ def get_introspection_query_ast( directive_is_repeatable: bool = False, schema_description: bool = False, input_value_deprecation: bool = True, + input_object_one_of: bool = False, + *, type_recursion_level: int = 7, ) -> DocumentNode: """Get a query for introspection as a document using the DSL module. 
@@ -68,6 +70,13 @@ def get_introspection_query_ast( ) if descriptions: fragment_FullType.select(ds.__Type.description) + if input_object_one_of: + try: + fragment_FullType.select(ds.__Type.isOneOf) + except AttributeError: # pragma: no cover + raise NotImplementedError( + "isOneOf is only supported from graphql-core version 3.3.0a7" + ) if specified_by_url: fragment_FullType.select(ds.__Type.specifiedByURL) diff --git a/setup.py b/setup.py index e0764d5d..e7e9d6de 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ tests_requires = [ "parse==1.20.2", + "packaging>=21.0", "pytest==8.3.4", "pytest-asyncio==0.25.3", "pytest-console-scripts==1.4.1", diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 7f042a07..ca9137a7 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -15,11 +15,15 @@ NonNullTypeNode, NullValueNode, Undefined, +) +from graphql import __version__ as graphql_version +from graphql import ( build_ast_schema, parse, print_ast, ) from graphql.utilities import get_introspection_query +from packaging import version from gql import Client, gql from gql.dsl import ( @@ -1084,6 +1088,50 @@ def test_get_introspection_query_ast(option): ) +@pytest.mark.skipif( + version.parse(graphql_version) < version.parse("3.3.0a7"), + reason="Requires graphql-core >= 3.3.0a7", +) +@pytest.mark.parametrize("option", [True, False]) +def test_get_introspection_query_ast_is_one_of(option): + + introspection_query = print_ast( + gql( + get_introspection_query( + input_value_deprecation=option, + ) + ).document + ) + + # Because the option does not exist yet in graphql-core, + # we add it manually here for now + if option: + introspection_query = introspection_query.replace( + "fields", + "isOneOf\n fields", + ) + + dsl_introspection_query = get_introspection_query_ast( + input_value_deprecation=option, + input_object_one_of=option, + type_recursion_level=9, + ) + + assert introspection_query == print_ast(dsl_introspection_query) + 
+ +@pytest.mark.skipif( + version.parse(graphql_version) >= version.parse("3.3.0a7"), + reason="Test only for older graphql-core versions < 3.3.0a7", +) +def test_get_introspection_query_ast_is_one_of_not_implemented_yet(): + + with pytest.raises(NotImplementedError): + get_introspection_query_ast( + input_object_one_of=True, + ) + + def test_typename_aliased(ds): query = """ hero { diff --git a/tests/test_cli.py b/tests/test_cli.py index 4c6b7d15..df613afc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -407,6 +407,7 @@ def test_cli_parse_schema_download(parser): "specified_by_url:True", "schema_description:true", "directive_is_repeatable:true", + "input_object_one_of:true", "--print-schema", ] ) @@ -419,6 +420,7 @@ def test_cli_parse_schema_download(parser): "specified_by_url": True, "schema_description": True, "directive_is_repeatable": True, + "input_object_one_of": True, } assert introspection_args == expected_args From a253d43cb5900acbf430c7cd6eceadb3ba729b41 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 12 Nov 2025 22:31:14 +0100 Subject: [PATCH 237/239] Support Python 3.14 (#575) --- .github/workflows/tests.yml | 4 +++- setup.py | 1 + tox.ini | 5 +++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8463ac00..37b381c5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "pypy3.10"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "pypy3.10"] os: [ubuntu-24.04, windows-latest] exclude: - os: windows-latest @@ -22,6 +22,8 @@ jobs: python-version: "3.11" - os: windows-latest python-version: "3.13" + - os: windows-latest + python-version: "3.14" - os: windows-latest python-version: "pypy3.10" diff --git a/setup.py b/setup.py index e7e9d6de..85ddd34c 100644 --- a/setup.py +++ b/setup.py @@ -96,6 +96,7 @@ 
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: PyPy", ], keywords="api graphql protocol rest relay gql client", diff --git a/tox.ini b/tox.ini index f6d4b48e..21129e3c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{39,310,311,312,313,py3} + py{39,310,311,312,313,314,py3} [gh-actions] python = @@ -10,6 +10,7 @@ python = 3.11: py311 3.12: py312 3.13: py313 + 3.14: py314 pypy-3: pypy3 [testenv] @@ -28,7 +29,7 @@ deps = -e.[test] commands = pip install -U setuptools ; run "tox -- tests -s" to show output for debugging - py{39,310,311,312,313,py3}: pytest {posargs:tests} + py{39,310,311,312,313,314,py3}: pytest {posargs:tests} py{312}: pytest {posargs:tests --cov-report=term-missing --cov=gql} [testenv:black] From df1214ee551c24df877a56b12c154ae909d18eda Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 12 Nov 2025 22:39:07 +0100 Subject: [PATCH 238/239] Bump pytest-asyncio to 1.2.0 (#576) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 85ddd34c..f4508b86 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ "parse==1.20.2", "packaging>=21.0", "pytest==8.3.4", - "pytest-asyncio==0.25.3", + "pytest-asyncio==1.2.0", "pytest-console-scripts==1.4.1", "pytest-cov==6.0.0", "vcrpy==7.0.0", From 05a8e98b898d75809f80c7087f13cf23c983c213 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 31 Dec 2025 13:39:38 +0100 Subject: [PATCH 239/239] Fix warning in make docs (#579) --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index 024dd9e6..d62be297 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -93,6 +93,7 @@ nitpick_ignore = [ # graphql-core: should be fixed ('py:class', 'graphql.execution.execute.ExecutionResult'), + ('py:class', 
'graphql.execution.incremental_publisher.ExecutionResult'), ('py:class', 'Source'), ('py:class', 'GraphQLSchema'),