diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index 1c2c1c82..c6ee035a 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -64,11 +64,11 @@ from the :code:`ds` instance ds.Query.hero.select(ds.Character.name) -The select method return the same instance, so it is possible to chain the calls:: +The select method returns the same instance, so it is possible to chain the calls:: ds.Query.hero.select(ds.Character.name).select(ds.Character.id) -Or do it sequencially:: +Or do it sequentially:: hero_query = ds.Query.hero @@ -279,7 +279,7 @@ will generate the request:: Multiple operations in a document ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -It is possible to create an Document with multiple operations:: +It is possible to create a Document with multiple operations:: query = dsl_gql( operation_name_1=DSLQuery( ... ), @@ -384,6 +384,305 @@ you can use the :class:`DSLMetaField ` class:: DSLMetaField("__typename") ) +Directives +^^^^^^^^^^ + +`Directives`_ provide a way to describe alternate runtime execution and type validation +behavior in a GraphQL document. The DSL module supports both built-in GraphQL directives +(:code:`@skip`, :code:`@include`) and custom schema-defined directives. + +To add directives to DSL elements, use the :meth:`DSLSchema.__call__ ` +factory method and the :meth:`directives ` method:: + + # Using built-in @skip directive with DSLSchema.__call__ factory + ds.Query.hero.select( + ds.Character.name.directives(ds("@skip").args(**{"if": True})) + ) + +Directive Arguments +""""""""""""""""""" + +Directive arguments can be passed using the :meth:`args ` method. +For arguments that don't conflict with Python reserved words, you can pass them directly:: + + # Using the args method for non-reserved names + ds("@custom").args(value="foo", reason="testing") + +It can also be done by calling the directive directly:: + + ds("@custom")(value="foo", reason="testing") + +However, when the GraphQL directive argument name conflicts with a Python reserved word +(like :code:`if`), you need to unpack a dictionary to escape it:: + + # Dictionary unpacking for Python reserved words + ds("@skip").args(**{"if": True}) + ds("@include")(**{"if": False}) + +This ensures that the exact GraphQL argument name is passed to the directive and that +no post-processing of arguments is required. + +The :meth:`DSLSchema.__call__ ` factory method automatically handles +schema lookup and validation for both built-in directives (:code:`@skip`, :code:`@include`) +and custom schema-defined directives using the same syntax. + +Directive Locations +""""""""""""""""""" + +The DSL module supports all executable directive locations from the GraphQL specification: + +.. list-table:: + :header-rows: 1 + :widths: 25 35 40 + + * - GraphQL Spec Location + - DSL Class/Method + - Description + * - QUERY + - :code:`DSLQuery.directives()` + - Directives on query operations + * - MUTATION + - :code:`DSLMutation.directives()` + - Directives on mutation operations + * - SUBSCRIPTION + - :code:`DSLSubscription.directives()` + - Directives on subscription operations + * - FIELD + - :code:`DSLField.directives()` + - Directives on fields (including meta-fields) + * - FRAGMENT_DEFINITION + - :code:`DSLFragment.directives()` + - Directives on fragment definitions + * - FRAGMENT_SPREAD + - :code:`DSLFragmentSpread.directives()` + - Directives on fragment spreads (via .spread()) + * - INLINE_FRAGMENT + - :code:`DSLInlineFragment.directives()` + - Directives on inline fragments + * - VARIABLE_DEFINITION + - :code:`DSLVariable.directives()` + - Directives on variable definitions + +Examples by Location +"""""""""""""""""""" + +**Operation directives**:: + + # Query operation + query = DSLQuery(ds.Query.hero.select(ds.Character.name)).directives( + ds("@customQueryDirective") + ) + + # Mutation operation + mutation = DSLMutation( + ds.Mutation.createReview.args(episode=6, review={"stars": 5}).select( + ds.Review.stars + ) + ).directives(ds("@customMutationDirective")) + +**Field directives**:: + + # Single directive on field + ds.Query.hero.select( + ds.Character.name.directives(ds("@customFieldDirective")) + ) + + # Multiple directives on a field + ds.Query.hero.select( + ds.Character.appearsIn.directives( + ds("@repeat").args(value="first"), + ds("@repeat").args(value="second"), + ds("@repeat").args(value="third"), + ) + ) + +**Fragment directives**: + +You can add directives to fragment definitions and to fragment spread instances. +To do this, first define your fragment in the usual way:: + + name_and_appearances = ( + DSLFragment("NameAndAppearances") + .on(ds.Character) + .select(ds.Character.name, ds.Character.appearsIn) + ) + +Then, use :meth:`spread() ` when you need to add +directives to the fragment spread:: + + query_with_fragment = DSLQuery( + ds.Query.hero.select( + name_and_appearances.spread().directives( + ds("@customFragmentSpreadDirective") + ) + ) + ) + +The :meth:`spread() ` method creates a +:class:`DSLFragmentSpread ` instance that allows you to add +directives specific to the fragment spread location, separate from directives on the +fragment definition itself. + +Example with fragment definition and spread-specific directives:: + + # Fragment definition with directive + name_and_appearances = ( + DSLFragment("CharacterInfo") + .on(ds.Character) + .select(ds.Character.name, ds.Character.appearsIn) + .directives(ds("@customFragmentDefinitionDirective")) + ) + + # Using fragment with spread-specific directives + query_without_spread_directive = DSLQuery( + # Direct usage (no spread directives) + ds.Query.hero.select(name_and_appearances) + ) + query_with_spread_directive = DSLQuery( + # Enhanced usage with spread directives + name_and_appearances.spread().directives( + ds("@customFragmentSpreadDirective") + ) + ) + + # Don't forget to include the fragment definition in dsl_gql + query = dsl_gql( + name_and_appearances, + BaseQuery=query_without_spread_directive, + QueryWithDirective=query_with_spread_directive, + ) + +This generates GraphQL equivalent to:: + + fragment CharacterInfo on Character @customFragmentDefinitionDirective { + name + appearsIn + } + + { + BaseQuery hero { + ...CharacterInfo + } + QueryWithDirective hero { + ...CharacterInfo @customFragmentSpreadDirective + } + } + +**Inline fragment directives**: + +Inline fragments also support directives using the +:meth:`directives ` method:: + + query_with_directive = ds.Query.hero.args(episode=6).select( + ds.Character.name, + DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet).directives( + ds("@customInlineFragmentDirective") + ) + ) + +This generates:: + + { + hero(episode: JEDI) { + name + ... on Human @customInlineFragmentDirective { + homePlanet + } + } + } + +**Variable definition directives**: + +You can also add directives to variable definitions using the +:meth:`directives ` method:: + + var = DSLVariableDefinitions() + var.episode.directives(ds("@customVariableDirective")) + # Note: the directive is attached to the `.episode` variable definition (singular), + # and not the `var` variable definitions (plural) holder. + + op = DSLQuery(ds.Query.hero.args(episode=var.episode).select(ds.Character.name)) + op.variable_definitions = var + +This will generate:: + + query ($episode: Episode @customVariableDirective) { + hero(episode: $episode) { + name + } + } + +Complete Example for Directives +""""""""""""""""""""""""""""""" + +Here's a comprehensive example showing directives on multiple locations: + +.. code-block:: python + + from gql.dsl import DSLFragment, DSLInlineFragment, DSLQuery, dsl_gql + + # Create variables for directive conditions + var = DSLVariableDefinitions() + + # Fragment with directive on definition + character_fragment = DSLFragment("CharacterInfo").on(ds.Character).select( + ds.Character.name, ds.Character.appearsIn + ).directives(ds("@fragmentDefinition")) + + # Query with directives on multiple locations + query = DSLQuery( + ds.Query.hero.args(episode=var.episode).select( + # Field with directive + ds.Character.name.directives(ds("@skip").args(**{"if": var.skipName})), + + # Fragment spread with directive + character_fragment.spread().directives( + ds("@include").args(**{"if": var.includeFragment}) + ), + + # Inline fragment with directive + DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet).directives( + ds("@skip").args(**{"if": var.skipHuman}) + ), + + # Meta field with directive + DSLMetaField("__typename").directives( + ds("@include").args(**{"if": var.includeType}) + ) + ) + ).directives(ds("@query")) # Operation directive + + # Variable definition with directive + var.episode.directives(ds("@variableDefinition")) + query.variable_definitions = var + + # Generate the document + document = dsl_gql(character_fragment, query) + +This generates GraphQL equivalent to:: + + fragment CharacterInfo on Character @fragmentDefinition { + name + appearsIn + } + + query ( + $episode: Episode @variableDefinition + $skipName: Boolean! + $includeFragment: Boolean! + $skipHuman: Boolean! + $includeType: Boolean! + ) @query { + hero(episode: $episode) { + name @skip(if: $skipName) + ...CharacterInfo @include(if: $includeFragment) + ... on Human @skip(if: $skipHuman) { + homePlanet + } + __typename @include(if: $includeType) + } + } + Executable examples ------------------- @@ -399,4 +698,5 @@ Sync example .. _Fragment: https://graphql.org/learn/queries/#fragments .. _Inline Fragment: https://graphql.org/learn/queries/#inline-fragments +.. _Directives: https://graphql.org/learn/queries/#directives .. _issue #308: https://github.com/graphql-python/gql/issues/308 diff --git a/gql/dsl.py b/gql/dsl.py index 1a8716c2..da4cf64c 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,17 +1,31 @@ """ -.. image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa - :alt: UML diagram +.. image:: https://www.plantuml.com/plantuml/png/hLZXJkGs4FwVft1_NLXOfBR_Lcrrz3Wg93WA2rTL24Kc6LYtMITdErpf5QdFqaVharp6tincS8ZsLlTd8PxnpESltun7UMsTDAvPbichRzm2bY3gKYgT9Bfo8AGLfrNHb73KwDofIjjaCWahWfOca-J_V_yJXIsp-mzbEgbgCD9RziIazvHzL6wHQRc4dPdunSXwSNvo0HyQiCu7aDPbTwPQPW-oR23rltl2FTQGjHlEQWmYo-ltkFwkAk26xx9Wb2pLtr2405cZSM-HhWqlX05T23nkakIbj5OSpa_cUSk559yI8QRJzcStot9PbbcM8lwPiCxipD3nK1d8dNg0u7GFJZfdOh_B5ahoH1d20iKVtNgae2pONahg0-mMtMDMm1rHov0XI-Gs4sH30j1EAUC3JoP_VfJctWwS5vTViZF0xwLHyhQ4GxXJMdar1EWFAuD5JBcxjixizJVSR40GEQDRwvJvmwupfQtNPLENS1t3mFFlYVtz_Hl4As_Rc39tOgq3A25tbGbeBJxXjio2cubvzpW7Xu48wwSkq9DG5jMeYkmEtsBgVriyjrLLhYEc4x_kwoNy5sgbtIYHrmFzoE5n8U2HdYd18WdTiTdR3gSTXKfHKlglWynof1FwVnJbHLKvBsB6PiW_nizWi2CZxvUWtLU9zRL0OGnw3vnLQLq8CnDNMbNwsYSDR-9Obqf3TwAmHkUh3KZlrtjPracdyYU1AlVYW1L6ctOAYlH3wcSunqJ_zY_86-_5YxHVLBCNofgQ2NLQhEcRZQg7yGO40gNiAM0jvQoxLm96kcOoRFepGMRii-Z0u_KSU3E84vqtO1w7aeWVUPRzywkt5xzp4OsN4yjpsZWVQgDKfrUN1vV7P--spZPlRcrkLBrnnldLp_Ct5yU_RfsL14EweZRUtL0aD4JGKn02w2g1EuOGNTXEHgrEPLEwC0VuneIhpuAkhibZNJSE4wpBp5Ke4GyYxSQF3a8GCZVoEuZIfmm6Tzk2FEfyWRnUNubR1cStLZzj6H8_dj17IWDc7dx3MujlzVhIWQ-yqeNFo5qsPsIq__xM8ZX0035B-8UTqWDD_IzD4uEns6lWJJjAmysKRtFQU8fnyhZZwEqSUsyZGSGxokokNwCXr9jmkPO6T2YRxY9SkPpT_W6vhy0zGJNfmDp97Bgwt2ri-Rmfj738lF7uIdXmQS2skRnfnpZhvBJ5XG1EzWYdot_Phg_8Y2ZSkZFp8j-YnM3QSI9uZ2y0-KeSwmKOvQJEGHWe_Qra5wgsINz6_-6VwJGQws8FDk74PXfOnuF4asYIy8ayJZRWm2w5sCmRKfAmS16IP01LxCH2nkPaY01oew5W20gp9_qdRwTfQj140z2WbGqioV0PU8CRPuEx3WSSlWi6F6Dn9yERkKJHYRFCpMIdTMe9M1HlgcLTMNyRyA8GKt4Y7y68RyMgdWH-8H6cgjnEilwwCPt-H5yYPY8t81rORkTV6yXfi_JVYTJd3PiAKVasPJq4J8e9wBGCmU070-zDfYz6yxr86ollGIWjQDQrErp7F0dBZ_agxQJIbXVg44-D1TlNd_U9somTGJmeARgfAtaDkcYMvMS0 # noqa + :alt: UML diagram - rename png to uml to edit """ import logging import re from abc import ABC, abstractmethod from math import isfinite -from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union, cast +from typing import ( + Any, + Dict, + Iterable, + Literal, + Mapping, + Optional, + Set, + Tuple, + Union, + cast, + overload, +) from graphql import ( ArgumentNode, BooleanValueNode, + DirectiveLocation, + DirectiveNode, DocumentNode, EnumValueNode, FieldNode, @@ -19,6 +33,7 @@ FragmentDefinitionNode, FragmentSpreadNode, GraphQLArgument, + GraphQLDirective, GraphQLEnumType, GraphQLError, GraphQLField, @@ -61,6 +76,7 @@ is_non_null_type, is_wrapping_type, print_ast, + specified_directives, ) from graphql.pyutils import inspect @@ -132,8 +148,9 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: Produce a GraphQL Value AST given a Python object. - Raises a GraphQLError instead of returning None if we receive an Undefined - of if we receive a Null value for a Non-Null type. + :raises graphql.error.GraphQLError: + instead of returning None if we receive an Undefined + of if we receive a Null value for a Non-Null type. """ if isinstance(value, DSLVariable): return value.set_type(type_).ast_variable_name @@ -274,6 +291,9 @@ class DSLSchema: Attributes of the DSLSchema class are generated automatically with the `__getattr__` dunder method in order to generate instances of :class:`DSLType` + + .. automethod:: __call__ + .. automethod:: __getattr__ """ def __init__(self, schema: GraphQLSchema): @@ -293,7 +313,57 @@ def __init__(self, schema: GraphQLSchema): self._schema: GraphQLSchema = schema + @overload + def __call__( + self, shortcut: Literal["__typename", "__schema", "__type"] + ) -> "DSLMetaField": ... # pragma: no cover + + @overload + def __call__( + self, shortcut: Literal["..."] + ) -> "DSLInlineFragment": ... # pragma: no cover + + @overload + def __call__( + self, shortcut: Literal["fragment"], name: str + ) -> "DSLFragment": ... # pragma: no cover + + @overload + def __call__(self, shortcut: Any) -> "DSLDirective": ... # pragma: no cover + + def __call__( + self, shortcut: str, name: Optional[str] = None + ) -> Union["DSLMetaField", "DSLInlineFragment", "DSLFragment", "DSLDirective"]: + """Factory method for creating DSL objects. + + Currently, supports creating DSLDirective instances when name starts with '@'. + Future support planned for meta-fields (__typename), inline fragments (...), + and fragment definitions (fragment). + + :param shortcut: the name of the object to create + :type shortcut: str + + :return: :class:`DSLDirective` instance + + :raises ValueError: if shortcut format is not supported + """ + if shortcut.startswith("@"): + return DSLDirective(name=shortcut[1:], dsl_schema=self) + # Future support: + # if name.startswith("__"): return DSLMetaField(name) + # if name == "...": return DSLInlineFragment() + # if name.startswith("fragment "): return DSLFragment(name[9:]) + + raise ValueError(f"Unsupported shortcut: {shortcut}") + def __getattr__(self, name: str) -> "DSLType": + """Attributes of the DSLSchema class are generated automatically + with this dunder method in order to generate + instances of :class:`DSLType` + + :return: :class:`DSLType` instance + :raises AttributeError: if the name is not valid + """ type_def: Optional[GraphQLNamedType] = self._schema.get_type(name) @@ -381,7 +451,218 @@ def select( log.debug(f"Added fields: {added_fields} in {self!r}") -class DSLExecutable(DSLSelector): +class DSLDirective: + """The DSLDirective represents a GraphQL directive for the DSL code. + + Directives provide a way to describe alternate runtime execution and type validation + behavior in a GraphQL document. + """ + + def __init__(self, name: str, dsl_schema: "DSLSchema"): + r"""Initialize the DSLDirective with the given name and arguments. + + :param name: the name of the directive + :param dsl_schema: DSLSchema for directive validation and definition lookup + + :raises graphql.error.GraphQLError: if directive not found or not executable + """ + self._dsl_schema = dsl_schema + + # Find directive definition in schema or built-ins + directive_def = self._dsl_schema._schema.get_directive(name) + + if directive_def is None: + # Try to find in built-in directives using specified_directives + builtins = {builtin.name: builtin for builtin in specified_directives} + directive_def = builtins.get(name) + + if directive_def is None: + available: Set[str] = set() + available.update(f"@{d.name}" for d in self._dsl_schema._schema.directives) + available.update(f"@{d.name}" for d in specified_directives) + raise GraphQLError( + f"Directive '@{name}' not found in schema or built-ins. " + f"Available directives: {', '.join(sorted(available))}" + ) + + # Check directive has at least one executable location + executable_locations = { + DirectiveLocation.QUERY, + DirectiveLocation.MUTATION, + DirectiveLocation.SUBSCRIPTION, + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_DEFINITION, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + DirectiveLocation.VARIABLE_DEFINITION, + } + + if not any(loc in executable_locations for loc in directive_def.locations): + raise GraphQLError( + f"Directive '@{name}' is not a valid request executable directive. " + f"It can only be used in type system locations, not in requests." + ) + + self.directive_def: GraphQLDirective = directive_def + self.ast_directive = DirectiveNode(name=NameNode(value=name), arguments=()) + + @property + def name(self) -> str: + """Get the directive name.""" + return self.ast_directive.name.value + + def __call__(self, **kwargs: Any) -> "DSLDirective": + """Add arguments by calling the directive like a function. + + :param kwargs: directive arguments + :return: itself + """ + return self.args(**kwargs) + + def args(self, **kwargs: Any) -> "DSLDirective": + r"""Set the arguments of a directive + + The arguments are parsed to be stored in the AST of this field. + + .. note:: + You can also call the field directly with your arguments. + :code:`ds("@someDirective").args(value="foo")` is equivalent to: + :code:`ds("@someDirective")(value="foo")` + + :param \**kwargs: the arguments (keyword=value) + + :return: itself + + :raises AttributeError: if arguments already set for this directive + :raises graphql.error.GraphQLError: + if argument doesn't exist in directive definition + """ + if len(self.ast_directive.arguments) > 0: + raise AttributeError(f"Arguments for directive @{self.name} already set.") + + errs = [] + for key, value in kwargs.items(): + if key not in self.directive_def.args: + errs.append( + f"Argument '{key}' does not exist in directive '@{self.name}'" + ) + if errs: + raise GraphQLError("\n".join(errs)) + + # Update AST directive with arguments + self.ast_directive = DirectiveNode( + name=NameNode(value=self.name), + arguments=tuple( + ArgumentNode( + name=NameNode(value=key), + value=ast_from_value(value, self.directive_def.args[key].type), + ) + for key, value in kwargs.items() + ), + ) + + return self + + def __repr__(self) -> str: + args_str = ", ".join( + f"{arg.name.value}={getattr(arg.value, 'value')}" + for arg in self.ast_directive.arguments + ) + return f"" + + +class DSLDirectable(ABC): + """Mixin class for DSL elements that can have directives. + + Provides the directives() method for adding GraphQL directives to DSL elements. + Classes that need immediate AST updates should override the directives() method. + """ + + _directives: Tuple[DSLDirective, ...] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._directives = () + + @abstractmethod + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if a directive is valid for this DSL element. + + :param directive: The DSLDirective to validate + :return: True if the directive can be used at this location + """ + raise NotImplementedError( + "Any DSLDirectable concrete class must have an is_valid_directive method" + ) # pragma: no cover + + def directives(self, *directives: DSLDirective) -> Any: + r"""Add directives to this DSL element. + + :param \*directives: DSLDirective instances to add + :return: itself + + :raises graphql.error.GraphQLError: if directive location is invalid + :raises TypeError: if argument is not a DSLDirective + + Usage: + + .. code-block:: python + + # Using new factory method + element.directives(ds("@include")(**{"if": var.show})) + element.directives(ds("@skip")(**{"if": var.hide})) + """ + validated_directives = [] + + for directive in directives: + if not isinstance(directive, DSLDirective): + raise TypeError( + f"Expected DSLDirective, got {type(directive)}. " + f"Use ds('@directiveName') to create directive instances." + ) + + # Validate directive location using the abstract method + if not self.is_valid_directive(directive): + # Get valid locations for error message + valid_locations = [ + loc.name + for loc in directive.directive_def.locations + if loc + in { + DirectiveLocation.QUERY, + DirectiveLocation.MUTATION, + DirectiveLocation.SUBSCRIPTION, + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_DEFINITION, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + DirectiveLocation.VARIABLE_DEFINITION, + } + ] + raise GraphQLError( + f"Invalid directive location: '@{directive.name}' " + f"cannot be used on {self.__class__.__name__}. " + f"Valid locations for this directive: {', '.join(valid_locations)}" + ) + + validated_directives.append(directive) + + # Update stored directives + self._directives = self._directives + tuple(validated_directives) + + log.debug( + f"Added directives {[d.name for d in validated_directives]} to {self!r}" + ) + + return self + + @property + def directives_ast(self) -> Tuple[DirectiveNode, ...]: + """Get AST directive nodes for this element.""" + return tuple(directive.ast_directive for directive in self._directives) + + +class DSLExecutable(DSLSelector, DSLDirectable): """Interface for the root elements which can be executed in the :func:`dsl_gql ` function @@ -430,6 +711,7 @@ def __init__( self.variable_definitions = DSLVariableDefinitions() DSLSelector.__init__(self, *fields, **fields_with_alias) + DSLDirectable.__init__(self) class DSLRootFieldSelector(DSLSelector): @@ -508,7 +790,7 @@ def executable_ast(self) -> OperationDefinitionNode: selection_set=self.selection_set, variable_definitions=self.variable_definitions.get_ast_definitions(), **({"name": NameNode(value=self.name)} if self.name else {}), - directives=(), + directives=self.directives_ast, ) def __repr__(self) -> str: @@ -518,16 +800,28 @@ def __repr__(self) -> str: class DSLQuery(DSLOperation): operation_type = OperationType.QUERY + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Query operations.""" + return DirectiveLocation.QUERY in directive.directive_def.locations + class DSLMutation(DSLOperation): operation_type = OperationType.MUTATION + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Mutation operations.""" + return DirectiveLocation.MUTATION in directive.directive_def.locations + class DSLSubscription(DSLOperation): operation_type = OperationType.SUBSCRIPTION + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Subscription operations.""" + return DirectiveLocation.SUBSCRIPTION in directive.directive_def.locations -class DSLVariable: + +class DSLVariable(DSLDirectable): """The DSLVariable represents a single variable defined in a GraphQL operation Instances of this class are generated for you automatically as attributes @@ -545,6 +839,8 @@ def __init__(self, name: str): self.default_value = None self.type: Optional[GraphQLInputType] = None + DSLDirectable.__init__(self) + def to_ast_type(self, type_: GraphQLInputType) -> TypeNode: if is_wrapping_type(type_): if isinstance(type_, GraphQLList): @@ -568,6 +864,18 @@ def default(self, default_value: Any) -> "DSLVariable": self.default_value = default_value return self + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Variable definitions.""" + for arg in directive.ast_directive.arguments: + if isinstance(arg.value, VariableNode): + raise GraphQLError( + f"Directive @{directive.name} argument value has " + f"unexpected variable '${arg.value.name}' in constant location." + ) + return ( + DirectiveLocation.VARIABLE_DEFINITION in directive.directive_def.locations + ) + class DSLVariableDefinitions: """The DSLVariableDefinitions represents variable definitions in a GraphQL operation @@ -579,6 +887,8 @@ class DSLVariableDefinitions: with the `__getattr__` dunder method in order to generate instances of :class:`DSLVariable`, that can then be used as values in the :meth:`args ` method. + + .. automethod:: __getattr__ """ def __init__(self): @@ -586,6 +896,12 @@ def __init__(self): self.variables: Dict[str, DSLVariable] = {} def __getattr__(self, name: str) -> "DSLVariable": + """Attributes of the DSLVariableDefinitions class are generated automatically + with this dunder method in order to generate + instances of :class:`DSLVariable` + + :return: :class:`DSLVariable` instance + """ if name not in self.variables: self.variables[name] = DSLVariable(name) return self.variables[name] @@ -605,7 +921,7 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: if var.default_value is None else ast_from_value(var.default_value, var.type) ), - directives=(), + directives=var.directives_ast, ) for var in self.variables.values() if var.type is not None # only variables used @@ -625,6 +941,8 @@ class DSLType: Attributes of the DSLType class are generated automatically with the `__getattr__` dunder method in order to generate instances of :class:`DSLField` + + .. automethod:: __getattr__ """ def __init__( @@ -646,6 +964,13 @@ def __init__( log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": + """Attributes of the DSLType class are generated automatically + with this dunder method in order to generate + instances of :class:`DSLField` + + :return: :class:`DSLField` instance + :raises AttributeError: if the field name does not exist in the type + """ camel_cased_name = to_camel_case(name) if name in self._type.fields: @@ -665,7 +990,7 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} {self._type!r}>" -class DSLSelectable(ABC): +class DSLSelectable(DSLDirectable): """DSLSelectable is an abstract class which indicates that the subclasses can be used as arguments of the :meth:`select ` method. @@ -715,7 +1040,7 @@ def is_valid_field(self, field: DSLSelectable) -> bool: assert isinstance(self, (DSLFragment, DSLInlineFragment)) - if isinstance(field, (DSLFragment, DSLInlineFragment)): + if isinstance(field, (DSLFragment, DSLFragmentSpread, DSLInlineFragment)): return True assert isinstance(field, DSLField) @@ -747,7 +1072,7 @@ def is_valid_field(self, field: DSLSelectable) -> bool: assert isinstance(self, DSLField) - if isinstance(field, (DSLFragment, DSLInlineFragment)): + if isinstance(field, (DSLFragment, DSLFragmentSpread, DSLInlineFragment)): return True assert isinstance(field, DSLField) @@ -837,6 +1162,7 @@ def __init__( log.debug(f"Creating {self!r}") DSLSelector.__init__(self) + DSLDirectable.__init__(self) @property def name(self): @@ -903,6 +1229,17 @@ def select( return self + def directives(self, *directives: DSLDirective) -> "DSLField": + """Add directives to this field.""" + super().directives(*directives) + self.ast_field.directives = self.directives_ast + + return self + + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Field locations.""" + return DirectiveLocation.FIELD in directive.directive_def.locations + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.parent_type.name}" f"::{self.name}>" @@ -941,6 +1278,10 @@ def __init__(self, name: str): super().__init__(name, self.meta_type, field) + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for MetaField locations (same as Field).""" + return DirectiveLocation.FIELD in directive.directive_def.locations + class DSLInlineFragment(DSLSelectable, DSLFragmentSelector): """DSLInlineFragment represents an inline fragment for the DSL code.""" @@ -966,6 +1307,7 @@ def __init__( self.ast_field = InlineFragmentNode(directives=()) DSLSelector.__init__(self, *fields, **fields_with_alias) + DSLDirectable.__init__(self) def select( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" @@ -987,6 +1329,15 @@ def on(self, type_condition: DSLType) -> "DSLInlineFragment": ) return self + def directives(self, *directives: DSLDirective) -> "DSLInlineFragment": + """Add directives to this inline fragment. + + Inline fragments support all directive types through auto-validation. + """ + super().directives(*directives) + self.ast_field.directives = self.directives_ast + return self + def __repr__(self) -> str: type_info = "" @@ -997,13 +1348,62 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__}{type_info}>" + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Inline Fragment locations.""" + return DirectiveLocation.INLINE_FRAGMENT in directive.directive_def.locations + + +class DSLFragmentSpread(DSLSelectable): + """Represents a fragment spread (usage) with its own directives. + + This class is created by calling .spread() on a DSLFragment and allows + adding directives specific to the FRAGMENT_SPREAD location. + """ + + ast_field: FragmentSpreadNode + _fragment: "DSLFragment" + + def __init__(self, fragment: "DSLFragment"): + """Initialize a fragment spread from a fragment definition. + + :param fragment: The DSLFragment to create a spread from + """ + self._fragment = fragment + self.ast_field = FragmentSpreadNode( + name=NameNode(value=fragment.name), directives=() + ) + + log.debug(f"Creating fragment spread for {fragment.name}") + + DSLDirectable.__init__(self) + + @property + def name(self) -> str: + """:meta private:""" + return self.ast_field.name.value + + def directives(self, *directives: DSLDirective) -> "DSLFragmentSpread": + """Add directives to this fragment spread. + + Fragment spreads support all directive types through auto-validation. + """ + super().directives(*directives) + self.ast_field.directives = self.directives_ast + return self + + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Fragment Spread locations.""" + return DirectiveLocation.FRAGMENT_SPREAD in directive.directive_def.locations + + def __repr__(self) -> str: + return f"" + class DSLFragment(DSLSelectable, DSLFragmentSelector, DSLExecutable): """DSLFragment represents a named GraphQL fragment for the DSL code.""" _type: Optional[Union[GraphQLObjectType, GraphQLInterfaceType]] ast_field: FragmentSpreadNode - name: str def __init__( self, @@ -1017,24 +1417,32 @@ def __init__( DSLExecutable.__init__(self) - self.name = name + self.ast_field = FragmentSpreadNode(name=NameNode(value=name), directives=()) + self._type = None log.debug(f"Creating {self!r}") - @property # type: ignore - def ast_field(self) -> FragmentSpreadNode: # type: ignore - """ast_field property will generate a FragmentSpreadNode with the - provided name. + @property + def name(self) -> str: + """:meta private:""" + return self.ast_field.name.value - Note: We need to ignore the type because of - `issue #4125 of mypy `_. - """ + @name.setter + def name(self, value: str) -> None: + """:meta private:""" + if hasattr(self, "ast_field"): + self.ast_field.name.value = value - spread_node = FragmentSpreadNode(directives=()) - spread_node.name = NameNode(value=self.name) + def spread(self) -> DSLFragmentSpread: + """Create a fragment spread that can have its own directives. - return spread_node + This allows adding directives specific to the FRAGMENT_SPREAD location, + separate from directives on the fragment definition itself. + + :return: DSLFragmentSpread instance for this fragment + """ + return DSLFragmentSpread(self) def select( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias" @@ -1096,7 +1504,13 @@ def executable_ast(self) -> FragmentDefinitionNode: selection_set=self.selection_set, **variable_definition_kwargs, name=NameNode(value=self.name), - directives=(), + directives=self.directives_ast, + ) + + def is_valid_directive(self, directive: "DSLDirective") -> bool: + """Check if directive is valid for Fragment Definition locations.""" + return ( + DirectiveLocation.FRAGMENT_DEFINITION in directive.directive_def.locations ) def __repr__(self) -> str: diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 8f1efe99..f14a4ea1 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -2,7 +2,9 @@ from typing import cast from graphql import ( + DirectiveLocation, GraphQLArgument, + GraphQLDirective, GraphQLEnumType, GraphQLEnumValue, GraphQLField, @@ -19,6 +21,7 @@ get_introspection_query, graphql_sync, print_schema, + specified_directives, ) from .fixtures import ( @@ -264,12 +267,125 @@ async def resolve_review(review, _info, **_args): }, ) +query_directive = GraphQLDirective( + name="query", + description="Test directive for QUERY location", + locations=[DirectiveLocation.QUERY], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +field_directive = GraphQLDirective( + name="field", + description="Test directive for FIELD location", + locations=[DirectiveLocation.FIELD], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +fragment_spread_directive = GraphQLDirective( + name="fragmentSpread", + description="Test directive for FRAGMENT_SPREAD location", + locations=[DirectiveLocation.FRAGMENT_SPREAD], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +inline_fragment_directive = GraphQLDirective( + name="inlineFragment", + description="Test directive for INLINE_FRAGMENT location", + locations=[DirectiveLocation.INLINE_FRAGMENT], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +fragment_definition_directive = GraphQLDirective( + name="fragmentDefinition", + description="Test directive for FRAGMENT_DEFINITION location", + locations=[DirectiveLocation.FRAGMENT_DEFINITION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +mutation_directive = GraphQLDirective( + name="mutation", + description="Test directive for MUTATION location (tests keyword conflict)", + locations=[DirectiveLocation.MUTATION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +subscription_directive = GraphQLDirective( + name="subscription", + description="Test directive for SUBSCRIPTION location", + locations=[DirectiveLocation.SUBSCRIPTION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +variable_definition_directive = GraphQLDirective( + name="variableDefinition", + description="Test directive for VARIABLE_DEFINITION location", + locations=[DirectiveLocation.VARIABLE_DEFINITION], + args={ + "value": GraphQLArgument( + GraphQLString, description="A string value for the variable" + ) + }, +) + +repeat_directive = GraphQLDirective( + name="repeat", + description="Test repeatable directive for FIELD location", + locations=[DirectiveLocation.FIELD], + args={ + "value": GraphQLArgument( + GraphQLString, + description="A string value for the repeatable directive", + ) + }, + is_repeatable=True, +) + StarWarsSchema = GraphQLSchema( query=query_type, mutation=mutation_type, subscription=subscription_type, types=[human_type, droid_type, review_type, review_input_type], + directives=[ + *specified_directives, + query_directive, + field_directive, + fragment_spread_directive, + inline_fragment_directive, + fragment_definition_directive, + mutation_directive, + subscription_directive, + variable_definition_directive, + repeat_directive, + ], ) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index e47a97d8..a3d1ef8c 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -23,7 +23,9 @@ from gql import Client, gql from gql.dsl import ( + DSLField, DSLFragment, + DSLFragmentSpread, DSLInlineFragment, DSLMetaField, DSLMutation, @@ -47,6 +49,12 @@ def ds(): return DSLSchema(StarWarsSchema) +@pytest.fixture +def var(): + """Common DSLVariableDefinitions fixture for directive tests""" + return DSLVariableDefinitions() + + @pytest.fixture def client(): return Client(schema=StarWarsSchema) @@ -659,7 +667,23 @@ def test_fragments_repr(ds): assert repr(DSLInlineFragment()) == "" assert repr(DSLInlineFragment().on(ds.Droid)) == "" assert repr(DSLFragment("fragment_1")) == "" + assert repr(DSLFragment("fragment_1").spread()) == "" assert repr(DSLFragment("fragment_2").on(ds.Droid)) == "" + assert ( + repr(DSLFragment("fragment_2").on(ds.Droid).spread()) + == "" + ) + + +def test_fragment_spread_instances(ds): + """Test that each .spread() creates new DSLFragmentSpread instance""" + fragment = DSLFragment("Test").on(ds.Character).select(ds.Character.name) + spread1 = fragment.spread() + spread2 = fragment.spread() + + assert isinstance(spread1, DSLFragmentSpread) + assert isinstance(spread2, DSLFragmentSpread) + assert spread1 is not spread2 def test_fragments(ds): @@ -1271,3 +1295,271 @@ def test_legacy_fragment_with_variables(ds): } """.strip() assert print_ast(query.document) == expected + + +def test_dsl_schema_call_validation(ds): + with pytest.raises(ValueError, match="(?i)unsupported shortcut"): + ds("foo") + + +def test_executable_directives(ds, var): + """Test ALL executable directive locations and types in one document""" + + # Fragment with both built-in and custom directives + fragment = ( + DSLFragment("CharacterInfo") + .on(ds.Character) + .select(ds.Character.name, ds.Character.appearsIn) + .directives(ds("@fragmentDefinition")) + ) + + # Query with multiple directive types + query = DSLQuery( + ds.Query.hero.args(episode=var.episode).select( + # Field with both built-in and custom directives + ds.Character.name.directives( + ds("@skip")(**{"if": var.skipName}), + ds("@field"), # custom field directive + ), + # Field with repeated directives (same directive multiple times) + ds.Character.appearsIn.directives( + ds("@repeat")(value="first"), + ds("@repeat")(value="second"), + ds("@repeat")(value="third"), + ), + # Fragment spread with multiple directives + fragment.spread().directives( + ds("@include")(**{"if": var.includeSpread}), + ds("@fragmentSpread"), + ), + # Inline fragment with directives + DSLInlineFragment() + .on(ds.Human) + .select(ds.Human.homePlanet) + .directives( + ds("@skip")(**{"if": var.skipInline}), + ds("@inlineFragment"), + ), + # Meta field with directive + DSLMetaField("__typename").directives( + ds("@include")(**{"if": var.includeType}) + ), + ) + ).directives(ds("@query")) + + # Mutation with directives + mutation = DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "Great!"} + ).select(ds.Review.stars, ds.Review.commentary) + ).directives(ds("@mutation")) + + # Subscription with directives + subscription = DSLSubscription( + ds.Subscription.reviewAdded.args(episode=6).select( + ds.Review.stars, ds.Review.commentary + ) + ).directives(ds("@subscription")) + + # Variable definitions with directives + var.episode.directives( + # Note that `$episode: Episode @someDirective(value=$someValue)` + # is INVALID GraphQL because variable definitions must be literal values + ds("@variableDefinition"), + ) + query.variable_definitions = var + + # Generate ONE document with everything + doc = dsl_gql( + fragment, HeroQuery=query, CreateReview=mutation, ReviewSub=subscription + ) + + expected = """\ +fragment CharacterInfo on Character @fragmentDefinition { + name + appearsIn +} + +query HeroQuery(\ +$episode: Episode @variableDefinition, \ +$skipName: Boolean!, \ +$includeSpread: Boolean!, \ +$skipInline: Boolean!, \ +$includeType: Boolean!\ +) @query { + hero(episode: $episode) { + name @skip(if: $skipName) @field + appearsIn @repeat(value: "first") @repeat(value: "second") @repeat(value: "third") + ...CharacterInfo @include(if: $includeSpread) @fragmentSpread + ... on Human @skip(if: $skipInline) @inlineFragment { + homePlanet + } + __typename @include(if: $includeType) + } +} + +mutation CreateReview @mutation { + createReview(episode: JEDI, review: {stars: 5, commentary: "Great!"}) { + stars + commentary + } +} + +subscription ReviewSub @subscription { + reviewAdded(episode: JEDI) { + stars + commentary + } +}""" + + assert strip_braces_spaces(print_ast(doc.document)) == expected + assert node_tree(doc.document) == node_tree(gql(expected).document) + + +def test_directive_repr(ds): + """Test DSLDirective string representation""" + directive = ds("@include")(**{"if": True}) + expected = "" + assert repr(directive) == expected + + +def test_directive_error_handling(ds): + """Test error handling for directives""" + # Invalid directive argument type + with pytest.raises(TypeError, match="Expected DSLDirective"): + ds.Query.hero.directives(123) + + # Invalid directive name from `__call__ + with pytest.raises(GraphQLError, match="Directive '@nonexistent' not found"): + ds("@nonexistent") + + # Invalid directive argument + with pytest.raises(GraphQLError, match="Argument 'invalid' does not exist"): + ds("@include")(invalid=True) + + # Tried to set arguments twice + with pytest.raises( + AttributeError, match="Arguments for directive @field already set." + ): + ds("@field").args(value="foo").args(value="bar") + + with pytest.raises( + GraphQLError, + match="(?i)Directive '@deprecated' is not a valid request executable directive", + ): + ds("@deprecated") + + with pytest.raises(GraphQLError, match="unexpected variable"): + # variable definitions must be static, literal values defined in the query! + var = DSLVariableDefinitions() + query = DSLQuery( + ds.Query.hero.args(episode=var.episode).select(ds.Character.name) + ) + var.episode.directives( + ds("@variableDefinition").args(value=var.nonStatic), + ) + query.variable_definitions = var + _ = dsl_gql(query).document + + +# Parametrized tests for comprehensive directive location validation +@pytest.fixture( + params=[ + "@query", + "@mutation", + "@subscription", + "@field", + "@fragmentDefinition", + "@fragmentSpread", + "@inlineFragment", + "@variableDefinition", + ] +) +def directive_name(request): + return request.param + + +@pytest.fixture( + params=[ + (DSLQuery, "QUERY"), + (DSLMutation, "MUTATION"), + (DSLSubscription, "SUBSCRIPTION"), + (DSLField, "FIELD"), + (DSLMetaField, "FIELD"), + (DSLFragment, "FRAGMENT_DEFINITION"), + (DSLFragmentSpread, "FRAGMENT_SPREAD"), + (DSLInlineFragment, "INLINE_FRAGMENT"), + (DSLVariable, "VARIABLE_DEFINITION"), + ] +) +def dsl_class_and_location(request): + return request.param + + +@pytest.fixture +def is_valid_combination(directive_name, dsl_class_and_location): + # Map directive names to their expected locations + directive_to_location = { + "@query": "QUERY", + "@mutation": "MUTATION", + "@subscription": "SUBSCRIPTION", + "@field": "FIELD", + "@fragmentDefinition": "FRAGMENT_DEFINITION", + "@fragmentSpread": "FRAGMENT_SPREAD", + "@inlineFragment": "INLINE_FRAGMENT", + "@variableDefinition": "VARIABLE_DEFINITION", + } + expected_location = directive_to_location[directive_name] + _, actual_location = dsl_class_and_location + return expected_location == actual_location + + +def create_dsl_instance(dsl_class, ds): + """Helper function to create DSL instances for testing""" + if dsl_class == DSLQuery: + return DSLQuery(ds.Query.hero.select(ds.Character.name)) + elif dsl_class == DSLMutation: + return DSLMutation( + ds.Mutation.createReview.args(episode=6, review={"stars": 5}).select( + ds.Review.stars + ) + ) + elif dsl_class == DSLSubscription: + return DSLSubscription( + ds.Subscription.reviewAdded.args(episode=6).select(ds.Review.stars) + ) + elif dsl_class == DSLField: + return ds.Query.hero + elif dsl_class == DSLMetaField: + return DSLMetaField("__typename") + elif dsl_class == DSLFragment: + return DSLFragment("test").on(ds.Character).select(ds.Character.name) + elif dsl_class == DSLFragmentSpread: + fragment = DSLFragment("test").on(ds.Character).select(ds.Character.name) + return fragment.spread() + elif dsl_class == DSLInlineFragment: + return DSLInlineFragment().on(ds.Human).select(ds.Human.homePlanet) + elif dsl_class == DSLVariable: + var = DSLVariableDefinitions() + return var.testVar + else: + raise ValueError(f"Unknown DSL class: {dsl_class}") + + +def test_directive_location_validation( + ds, directive_name, dsl_class_and_location, is_valid_combination +): + """Test all 64 combinations of 8 directives × 8 DSL classes""" + dsl_class, _ = dsl_class_and_location + directive = ds(directive_name) + + # Create instance of DSL class and try to apply directive + instance = create_dsl_instance(dsl_class, ds) + + if is_valid_combination: + # Should work without error + instance.directives(directive) + else: + # Should raise GraphQLError for invalid location + with pytest.raises(GraphQLError, match="Invalid directive location"): + instance.directives(directive)