@@ -184,6 +184,24 @@ def _build_schema_from_introspection(
184184 self .introspection = cast (IntrospectionQuery , execution_result .data )
185185 self .schema = build_client_schema (self .introspection )
186186
187+ @staticmethod
188+ def _get_event_loop () -> asyncio .AbstractEventLoop :
189+ """Get the current asyncio event loop.
190+
191+ Or create a new event loop if there isn't one (in a new Thread).
192+ """
193+ try :
194+ with warnings .catch_warnings ():
195+ warnings .filterwarnings (
196+ "ignore" , message = "There is no current event loop"
197+ )
198+ loop = asyncio .get_event_loop ()
199+ except RuntimeError :
200+ loop = asyncio .new_event_loop ()
201+ asyncio .set_event_loop (loop )
202+
203+ return loop
204+
187205 @overload
188206 def execute_sync (
189207 self ,
@@ -358,6 +376,58 @@ async def execute_async(
358376 ** kwargs ,
359377 )
360378
379+ @overload
380+ async def execute_batch_async (
381+ self ,
382+ requests : List [GraphQLRequest ],
383+ * ,
384+ serialize_variables : Optional [bool ] = None ,
385+ parse_result : Optional [bool ] = None ,
386+ get_execution_result : Literal [False ] = ...,
387+ ** kwargs : Any ,
388+ ) -> List [Dict [str , Any ]]: ... # pragma: no cover
389+
390+ @overload
391+ async def execute_batch_async (
392+ self ,
393+ requests : List [GraphQLRequest ],
394+ * ,
395+ serialize_variables : Optional [bool ] = None ,
396+ parse_result : Optional [bool ] = None ,
397+ get_execution_result : Literal [True ],
398+ ** kwargs : Any ,
399+ ) -> List [ExecutionResult ]: ... # pragma: no cover
400+
401+ @overload
402+ async def execute_batch_async (
403+ self ,
404+ requests : List [GraphQLRequest ],
405+ * ,
406+ serialize_variables : Optional [bool ] = None ,
407+ parse_result : Optional [bool ] = None ,
408+ get_execution_result : bool ,
409+ ** kwargs : Any ,
410+ ) -> Union [List [Dict [str , Any ]], List [ExecutionResult ]]: ... # pragma: no cover
411+
412+ async def execute_batch_async (
413+ self ,
414+ requests : List [GraphQLRequest ],
415+ * ,
416+ serialize_variables : Optional [bool ] = None ,
417+ parse_result : Optional [bool ] = None ,
418+ get_execution_result : bool = False ,
419+ ** kwargs : Any ,
420+ ) -> Union [List [Dict [str , Any ]], List [ExecutionResult ]]:
421+ """:meta private:"""
422+ async with self as session :
423+ return await session .execute_batch (
424+ requests ,
425+ serialize_variables = serialize_variables ,
426+ parse_result = parse_result ,
427+ get_execution_result = get_execution_result ,
428+ ** kwargs ,
429+ )
430+
361431 @overload
362432 def execute (
363433 self ,
@@ -430,17 +500,7 @@ def execute(
430500 """
431501
432502 if isinstance (self .transport , AsyncTransport ):
433- # Get the current asyncio event loop
434- # Or create a new event loop if there isn't one (in a new Thread)
435- try :
436- with warnings .catch_warnings ():
437- warnings .filterwarnings (
438- "ignore" , message = "There is no current event loop"
439- )
440- loop = asyncio .get_event_loop ()
441- except RuntimeError :
442- loop = asyncio .new_event_loop ()
443- asyncio .set_event_loop (loop )
503+ loop = self ._get_event_loop ()
444504
445505 assert not loop .is_running (), (
446506 "Cannot run client.execute(query) if an asyncio loop is running."
@@ -537,7 +597,24 @@ def execute_batch(
537597 """
538598
539599 if isinstance (self .transport , AsyncTransport ):
540- raise NotImplementedError ("Batching is not implemented for async yet." )
600+ loop = self ._get_event_loop ()
601+
602+ assert not loop .is_running (), (
603+ "Cannot run client.execute_batch(query) if an asyncio loop is running."
604+ " Use 'await client.execute_batch(query)' instead."
605+ )
606+
607+ data = loop .run_until_complete (
608+ self .execute_batch_async (
609+ requests ,
610+ serialize_variables = serialize_variables ,
611+ parse_result = parse_result ,
612+ get_execution_result = get_execution_result ,
613+ ** kwargs ,
614+ )
615+ )
616+
617+ return data
541618
542619 else : # Sync transports
543620 return self .execute_batch_sync (
@@ -675,17 +752,12 @@ def subscribe(
675752 We need an async transport for this functionality.
676753 """
677754
678- # Get the current asyncio event loop
679- # Or create a new event loop if there isn't one (in a new Thread)
680- try :
681- with warnings .catch_warnings ():
682- warnings .filterwarnings (
683- "ignore" , message = "There is no current event loop"
684- )
685- loop = asyncio .get_event_loop ()
686- except RuntimeError :
687- loop = asyncio .new_event_loop ()
688- asyncio .set_event_loop (loop )
755+ loop = self ._get_event_loop ()
756+
757+ assert not loop .is_running (), (
758+ "Cannot run client.subscribe(query) if an asyncio loop is running."
759+ " Use 'await client.subscribe_async(query)' instead."
760+ )
689761
690762 async_generator : Union [
691763 AsyncGenerator [Dict [str , Any ], None ], AsyncGenerator [ExecutionResult , None ]
@@ -699,11 +771,6 @@ def subscribe(
699771 ** kwargs ,
700772 )
701773
702- assert not loop .is_running (), (
703- "Cannot run client.subscribe(query) if an asyncio loop is running."
704- " Use 'await client.subscribe_async(query)' instead."
705- )
706-
707774 try :
708775 while True :
709776 # Note: we need to create a task here in order to be able to close
@@ -1626,6 +1693,149 @@ async def execute(
16261693
16271694 return result .data
16281695
1696+ async def _execute_batch (
1697+ self ,
1698+ requests : List [GraphQLRequest ],
1699+ * ,
1700+ serialize_variables : Optional [bool ] = None ,
1701+ parse_result : Optional [bool ] = None ,
1702+ validate_document : Optional [bool ] = True ,
1703+ ** kwargs : Any ,
1704+ ) -> List [ExecutionResult ]:
1705+ """Execute multiple GraphQL requests in a batch, using
1706+ the async transport, returning a list of ExecutionResult objects.
1707+
1708+ :param requests: List of requests that will be executed.
1709+ :param serialize_variables: whether the variable values should be
1710+ serialized. Used for custom scalars and/or enums.
1711+ By default use the serialize_variables argument of the client.
1712+ :param parse_result: Whether gql will deserialize the result.
1713+ By default use the parse_results argument of the client.
1714+ :param validate_document: Whether we still need to validate the document.
1715+
1716+ The extra arguments are passed to the transport execute_batch method."""
1717+
1718+ # Validate document
1719+ if self .client .schema :
1720+
1721+ if validate_document :
1722+ for req in requests :
1723+ self .client .validate (req .document )
1724+
1725+ # Parse variable values for custom scalars if requested
1726+ if serialize_variables or (
1727+ serialize_variables is None and self .client .serialize_variables
1728+ ):
1729+ requests = [
1730+ (
1731+ req .serialize_variable_values (self .client .schema )
1732+ if req .variable_values is not None
1733+ else req
1734+ )
1735+ for req in requests
1736+ ]
1737+
1738+ results = await self .transport .execute_batch (requests , ** kwargs )
1739+
1740+ # Unserialize the result if requested
1741+ if self .client .schema :
1742+ if parse_result or (parse_result is None and self .client .parse_results ):
1743+ for result in results :
1744+ result .data = parse_result_fn (
1745+ self .client .schema ,
1746+ req .document ,
1747+ result .data ,
1748+ operation_name = req .operation_name ,
1749+ )
1750+
1751+ return results
1752+
1753+ @overload
1754+ async def execute_batch (
1755+ self ,
1756+ requests : List [GraphQLRequest ],
1757+ * ,
1758+ serialize_variables : Optional [bool ] = None ,
1759+ parse_result : Optional [bool ] = None ,
1760+ get_execution_result : Literal [False ] = ...,
1761+ ** kwargs : Any ,
1762+ ) -> List [Dict [str , Any ]]: ... # pragma: no cover
1763+
1764+ @overload
1765+ async def execute_batch (
1766+ self ,
1767+ requests : List [GraphQLRequest ],
1768+ * ,
1769+ serialize_variables : Optional [bool ] = None ,
1770+ parse_result : Optional [bool ] = None ,
1771+ get_execution_result : Literal [True ],
1772+ ** kwargs : Any ,
1773+ ) -> List [ExecutionResult ]: ... # pragma: no cover
1774+
1775+ @overload
1776+ async def execute_batch (
1777+ self ,
1778+ requests : List [GraphQLRequest ],
1779+ * ,
1780+ serialize_variables : Optional [bool ] = None ,
1781+ parse_result : Optional [bool ] = None ,
1782+ get_execution_result : bool ,
1783+ ** kwargs : Any ,
1784+ ) -> Union [List [Dict [str , Any ]], List [ExecutionResult ]]: ... # pragma: no cover
1785+
1786+ async def execute_batch (
1787+ self ,
1788+ requests : List [GraphQLRequest ],
1789+ * ,
1790+ serialize_variables : Optional [bool ] = None ,
1791+ parse_result : Optional [bool ] = None ,
1792+ get_execution_result : bool = False ,
1793+ ** kwargs : Any ,
1794+ ) -> Union [List [Dict [str , Any ]], List [ExecutionResult ]]:
1795+ """Execute multiple GraphQL requests in a batch, using
1796+ the async transport. This method sends the requests to the server all at once.
1797+
1798+ Raises a TransportQueryError if an error has been returned in any
1799+ ExecutionResult.
1800+
1801+ :param requests: List of requests that will be executed.
1802+ :param serialize_variables: whether the variable values should be
1803+ serialized. Used for custom scalars and/or enums.
1804+ By default use the serialize_variables argument of the client.
1805+ :param parse_result: Whether gql will deserialize the result.
1806+ By default use the parse_results argument of the client.
1807+ :param get_execution_result: return the full ExecutionResult instance instead of
1808+ only the "data" field. Necessary if you want to get the "extensions" field.
1809+
1810+ The extra arguments are passed to the transport execute method."""
1811+
1812+ # Validate and execute on the transport
1813+ results = await self ._execute_batch (
1814+ requests ,
1815+ serialize_variables = serialize_variables ,
1816+ parse_result = parse_result ,
1817+ ** kwargs ,
1818+ )
1819+
1820+ for result in results :
1821+ # Raise an error if an error is returned in the ExecutionResult object
1822+ if result .errors :
1823+ raise TransportQueryError (
1824+ str_first_element (result .errors ),
1825+ errors = result .errors ,
1826+ data = result .data ,
1827+ extensions = result .extensions ,
1828+ )
1829+
1830+ assert (
1831+ result .data is not None
1832+ ), "Transport returned an ExecutionResult without data or errors"
1833+
1834+ if get_execution_result :
1835+ return results
1836+
1837+ return cast (List [Dict [str , Any ]], [result .data for result in results ])
1838+
16291839 async def fetch_schema (self ) -> None :
16301840 """Fetch the GraphQL schema explicitly using introspection.
16311841
0 commit comments