Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,21 @@ async def execute_batch(
async def subscribe(
self,
request: GraphQLRequest,
*,
extra_args: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[ExecutionResult, None]:
"""Execute a GraphQL subscription and yield results from multipart response.

:param request: GraphQL request to execute
:param extra_args: additional arguments to send to the aiohttp post method
:yields: ExecutionResult objects as they arrive in the multipart stream
"""
if self.session is None:
raise TransportClosed("Transport is not connected")

post_args = self._prepare_request(request)
post_args = self._prepare_request(request, extra_args)

# Add headers for multipart subscription
headers = post_args.get("headers", {})
headers = dict(post_args.get("headers", {}))
headers.update(
{
"Content-Type": "application/json",
Expand Down
28 changes: 28 additions & 0 deletions tests/test_aiohttp_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,31 @@ async def test_aiohttp_multipart_actually_invalid_utf8(multipart_server):

# Should skip invalid part and not crash
assert len(results) == 0


@pytest.mark.asyncio
async def test_aiohttp_multipart_subscribe_extra_args(multipart_server):
"""Test that extra_args are passed through to the post method."""
from gql.transport.aiohttp import AIOHTTPTransport

custom_header_received = False

def check_custom_header(request):
nonlocal custom_header_received
if request.headers.get("X-Custom-Header") == "custom-value":
custom_header_received = True

parts = create_multipart_response([book1])
server = await multipart_server(parts, request_handler=check_custom_header)
url = server.make_url("/")
transport = AIOHTTPTransport(url=url)

query = gql(subscription_str)

async with Client(transport=transport) as session:
async for result in session.subscribe(
query, extra_args={"headers": {"X-Custom-Header": "custom-value"}}
):
pass

assert custom_header_received
Loading