diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index e4ce2bd6..e7eff55f 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -425,19 +425,21 @@ async def execute_batch( async def subscribe( self, request: GraphQLRequest, + *, + extra_args: Optional[Dict[str, Any]] = None, ) -> AsyncGenerator[ExecutionResult, None]: """Execute a GraphQL subscription and yield results from multipart response. :param request: GraphQL request to execute + :param extra_args: additional arguments to send to the aiohttp post method :yields: ExecutionResult objects as they arrive in the multipart stream """ if self.session is None: raise TransportClosed("Transport is not connected") - post_args = self._prepare_request(request) + post_args = self._prepare_request(request, extra_args) - # Add headers for multipart subscription - headers = post_args.get("headers", {}) + headers = dict(post_args.get("headers", {})) headers.update( { "Content-Type": "application/json", diff --git a/tests/test_aiohttp_multipart.py b/tests/test_aiohttp_multipart.py index 443e9b60..d71814a4 100644 --- a/tests/test_aiohttp_multipart.py +++ b/tests/test_aiohttp_multipart.py @@ -636,3 +636,31 @@ async def test_aiohttp_multipart_actually_invalid_utf8(multipart_server): # Should skip invalid part and not crash assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_aiohttp_multipart_subscribe_extra_args(multipart_server): + """Test that extra_args are passed through to the post method.""" + from gql.transport.aiohttp import AIOHTTPTransport + + custom_header_received = False + + def check_custom_header(request): + nonlocal custom_header_received + if request.headers.get("X-Custom-Header") == "custom-value": + custom_header_received = True + + parts = create_multipart_response([book1]) + server = await multipart_server(parts, request_handler=check_custom_header) + url = server.make_url("/") + transport = AIOHTTPTransport(url=url) + + query = gql(subscription_str) + + async with Client(transport=transport) as session: + async for result in session.subscribe( + query, extra_args={"headers": {"X-Custom-Header": "custom-value"}} + ): + pass + + assert custom_header_received