Skip to content

Commit bb2c317

Browse files
authored
GENKGB-412 (#455)
* feat: add ConstraintType to GraphSchema for constraint extraction * feat: add tests for ConstraintType * feat: add tests for ConstraintType * feat: add tests for ConstraintType * constraints type check * constraints type check * constraints type check
1 parent 06fe74f commit bb2c317

File tree

3 files changed

+535
-3
lines changed

3 files changed

+535
-3
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class PropertyType(BaseModel):
7676
]
7777
description: str = ""
7878
required: bool = False
79-
8079
model_config = ConfigDict(
8180
frozen=True,
8281
)
@@ -161,6 +160,22 @@ def property_type_from_name(self, name: str) -> Optional[PropertyType]:
161160
return None
162161

163162

163+
class ConstraintType(BaseModel):
164+
"""
165+
Represents a constraint on a node in the graph.
166+
"""
167+
168+
type: Literal[
169+
"UNIQUENESS"
170+
] # TODO: add other constraint types ["propertyExistence", "propertyType", "key"]
171+
node_type: str
172+
property_name: str
173+
174+
model_config = ConfigDict(
175+
frozen=True,
176+
)
177+
178+
164179
class GraphSchema(DataModel):
165180
"""This model represents the expected
166181
node and relationship types in the graph.
@@ -177,6 +192,7 @@ class GraphSchema(DataModel):
177192
node_types: Tuple[NodeType, ...]
178193
relationship_types: Tuple[RelationshipType, ...] = tuple()
179194
patterns: Tuple[Tuple[str, str, str], ...] = tuple()
195+
constraints: Tuple[ConstraintType, ...] = tuple()
180196

181197
additional_node_types: bool = Field(
182198
default_factory=default_additional_item("node_types")
@@ -239,6 +255,34 @@ def validate_additional_parameters(self) -> Self:
239255
)
240256
return self
241257

258+
@model_validator(mode="after")
259+
def validate_constraints_against_node_types(self) -> Self:
260+
if not self.constraints:
261+
return self
262+
for constraint in self.constraints:
263+
# Only validate UNIQUENESS constraints (other types will be added)
264+
if constraint.type != "UNIQUENESS":
265+
continue
266+
267+
if not constraint.property_name:
268+
raise SchemaValidationError(
269+
f"Constraint has no property name: {constraint}. Property name is required."
270+
)
271+
if constraint.node_type not in self._node_type_index:
272+
raise SchemaValidationError(
273+
f"Constraint references undefined node type: {constraint.node_type}"
274+
)
275+
# Check if property_name exists on the node type
276+
node_type = self._node_type_index[constraint.node_type]
277+
valid_property_names = {p.name for p in node_type.properties}
278+
if constraint.property_name not in valid_property_names:
279+
raise SchemaValidationError(
280+
f"Constraint references undefined property '{constraint.property_name}' "
281+
f"on node type '{constraint.node_type}'. "
282+
f"Valid properties: {valid_property_names}"
283+
)
284+
return self
285+
242286
def node_type_from_label(self, label: str) -> Optional[NodeType]:
243287
return self._node_type_index.get(label)
244288

@@ -382,6 +426,7 @@ def create_schema_model(
382426
node_types: Sequence[NodeType],
383427
relationship_types: Optional[Sequence[RelationshipType]] = None,
384428
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
429+
constraints: Optional[Sequence[ConstraintType]] = None,
385430
**kwargs: Any,
386431
) -> GraphSchema:
387432
"""
@@ -403,6 +448,7 @@ def create_schema_model(
403448
node_types=node_types,
404449
relationship_types=relationship_types or (),
405450
patterns=patterns or (),
451+
constraints=constraints or (),
406452
**kwargs,
407453
)
408454
)
@@ -415,6 +461,7 @@ async def run(
415461
node_types: Sequence[NodeType],
416462
relationship_types: Optional[Sequence[RelationshipType]] = None,
417463
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
464+
constraints: Optional[Sequence[ConstraintType]] = None,
418465
**kwargs: Any,
419466
) -> GraphSchema:
420467
"""
@@ -432,6 +479,7 @@ async def run(
432479
node_types,
433480
relationship_types,
434481
patterns,
482+
constraints,
435483
**kwargs,
436484
)
437485

@@ -555,6 +603,69 @@ def _filter_relationships_without_labels(
555603
relationship_types, "relationship type"
556604
)
557605

606+
def _filter_invalid_constraints(
607+
self, constraints: List[Dict[str, Any]], node_types: List[Dict[str, Any]]
608+
) -> List[Dict[str, Any]]:
609+
"""Filter out constraints that reference undefined node types, have no property name, are not UNIQUENESS type
610+
or reference a property that doesn't exist on the node type."""
611+
if not constraints:
612+
return []
613+
614+
if not node_types:
615+
logging.info(
616+
"Filtering out all constraints because no node types are defined. "
617+
"Constraints reference node types that must be defined."
618+
)
619+
return []
620+
621+
# Build a mapping of node_type label -> set of property names
622+
node_type_properties: Dict[str, set[str]] = {}
623+
for node_type_dict in node_types:
624+
label = node_type_dict.get("label")
625+
if label:
626+
properties = node_type_dict.get("properties", [])
627+
property_names = {p.get("name") for p in properties if p.get("name")}
628+
node_type_properties[label] = property_names
629+
630+
valid_node_labels = set(node_type_properties.keys())
631+
632+
filtered_constraints = []
633+
for constraint in constraints:
634+
# Only process UNIQUENESS constraints (other types will be added)
635+
if constraint.get("type") != "UNIQUENESS":
636+
logging.info(
637+
f"Filtering out constraint: {constraint}. "
638+
f"Only UNIQUENESS constraints are supported."
639+
)
640+
continue
641+
642+
# check if the property_name is provided
643+
if not constraint.get("property_name"):
644+
logging.info(
645+
f"Filtering out constraint: {constraint}. "
646+
f"Property name is not provided."
647+
)
648+
continue
649+
# check if the node_type is valid
650+
node_type = constraint.get("node_type")
651+
if node_type not in valid_node_labels:
652+
logging.info(
653+
f"Filtering out constraint: {constraint}. "
654+
f"Node type '{node_type}' is not valid. Valid node types: {valid_node_labels}"
655+
)
656+
continue
657+
# check if the property_name exists on the node type
658+
property_name = constraint.get("property_name")
659+
if property_name not in node_type_properties.get(node_type, set()):
660+
logging.info(
661+
f"Filtering out constraint: {constraint}. "
662+
f"Property '{property_name}' does not exist on node type '{node_type}'. "
663+
f"Valid properties: {node_type_properties.get(node_type, set())}"
664+
)
665+
continue
666+
filtered_constraints.append(constraint)
667+
return filtered_constraints
668+
558669
def _clean_json_content(self, content: str) -> str:
559670
content = content.strip()
560671

@@ -624,6 +735,9 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
624735
extracted_patterns: Optional[List[Tuple[str, str, str]]] = extracted_schema.get(
625736
"patterns"
626737
)
738+
extracted_constraints: Optional[List[Dict[str, Any]]] = extracted_schema.get(
739+
"constraints"
740+
)
627741

628742
# Filter out nodes and relationships without labels
629743
extracted_node_types = self._filter_nodes_without_labels(extracted_node_types)
@@ -638,11 +752,18 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
638752
extracted_patterns, extracted_node_types, extracted_relationship_types
639753
)
640754

755+
# Filter out invalid constraints
756+
if extracted_constraints:
757+
extracted_constraints = self._filter_invalid_constraints(
758+
extracted_constraints, extracted_node_types
759+
)
760+
641761
return GraphSchema.model_validate(
642762
{
643763
"node_types": extracted_node_types,
644764
"relationship_types": extracted_relationship_types,
645765
"patterns": extracted_patterns,
766+
"constraints": extracted_constraints or [],
646767
}
647768
)
648769

src/neo4j_graphrag/generation/prompts.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,12 @@ class SchemaExtractionTemplate(PromptTemplate):
218218
5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types.
219219
6. Do not create node types that aren't clearly mentioned in the text.
220220
7. Keep your schema minimal and focused on clearly identifiable patterns in the text.
221+
8. UNIQUENESS CONSTRAINTS:
222+
8.1 UNIQUENESS is optional; each node_type may or may not have exactly one uniqueness constraint.
223+
8.2 Only use properties that seem to not have too many missing values in the sample.
224+
8.3 Constraints reference node_types by label and specify which property is unique.
225+
8.4 If a property appears in a uniqueness constraint it MUST also appear in the corresponding node_type as a property.
226+
221227
222228
Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,
223229
LOCAL_DATETIME, LOCAL_TIME, POINT, STRING, ZONED_DATETIME, ZONED_TIME.
@@ -233,18 +239,26 @@ class SchemaExtractionTemplate(PromptTemplate):
233239
"type": "STRING"
234240
}}
235241
]
236-
}},
242+
}}
237243
...
238244
],
239245
"relationship_types": [
240246
{{
241247
"label": "WORKS_FOR"
242-
}},
248+
}}
243249
...
244250
],
245251
"patterns": [
246252
["Person", "WORKS_FOR", "Company"],
247253
...
254+
],
255+
"constraints": [
256+
{{
257+
"type": "UNIQUENESS",
258+
"node_type": "Person",
259+
"property_name": "name"
260+
}}
261+
...
248262
]
249263
}}
250264

0 commit comments

Comments
 (0)