Skip to content

Commit 03ddb3c

Browse files
authored
External retriever node label (#451)
* Add a node label to filter on for the external retriever base class and WeaviateRetriever * Add node_label_neo4j parameter to Qdrant and Pinecone retrievers * Update doc + do not force escaping of the node label, to allow node label expressions * Update unit test
1 parent 8ed0eab commit 03ddb3c

File tree

14 files changed

+66
-6
lines changed

14 files changed

+66
-6
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
## Next
44

5+
### Added
6+
7+
- Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j.
8+
9+
510
## 1.10.1
611

712
### Added

docs/source/user_guide_rag.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ Weaviate Retrievers
762762
collection="Movies",
763763
id_property_external="neo4j_id",
764764
id_property_neo4j="id",
765+
node_label_neo4j="Document", # optional
765766
)
766767
767768
Internally, this retriever performs the vector search in Weaviate, finds the corresponding node by matching
@@ -795,6 +796,7 @@ Pinecone Retrievers
795796
index_name="Movies",
796797
id_property_neo4j="id",
797798
embedder=embedder,
799+
node_label_neo4j="Document", # optional
798800
)
799801
800802
Also see :ref:`pineconeneo4jretriever`.
@@ -825,6 +827,7 @@ Qdrant Retrievers
825827
id_property_external="neo4j_id", # The payload field that contains the identifier matching the corresponding Neo4j node's id property
826828
id_property_neo4j="id",
827829
embedder=embedder,
830+
node_label_neo4j="Document", # optional
828831
)
829832
830833
See :ref:`qdrantneo4jretriever`.

src/neo4j_graphrag/retrievers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,10 +454,12 @@ def __init__(
454454
id_property_external: str,
455455
id_property_neo4j: str,
456456
neo4j_database: Optional[str] = None,
457+
node_label_neo4j: Optional[str] = None,
457458
):
458459
super().__init__(driver)
459460
self.id_property_external = id_property_external
460461
self.id_property_neo4j = id_property_neo4j
462+
self.node_label_neo4j = node_label_neo4j
461463
self.neo4j_database = neo4j_database
462464

463465
@abstractmethod

src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class PineconeNeo4jRetriever(ExternalRetriever):
8383
retrieval_query (str): Cypher query that gets appended.
8484
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8585
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
86+
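node_label_neo4j (Optional[str]): The label (or label expression, e.g. "Actor|Director") of the Neo4j node to retrieve. It is inserted into the query without escaping, so it must be properly escaped by the caller if needed, e.g. "`Label with spaces`".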
8687
8788
Raises:
8889
RetrieverInitializationError: If validation of the input arguments fails.
@@ -101,6 +102,7 @@ def __init__(
101102
Callable[[neo4j.Record], RetrieverResultItem]
102103
] = None,
103104
neo4j_database: Optional[str] = None,
105+
node_label_neo4j: Optional[str] = None,
104106
):
105107
try:
106108
driver_model = Neo4jDriverModel(driver=driver)
@@ -116,6 +118,7 @@ def __init__(
116118
retrieval_query=retrieval_query,
117119
result_formatter=result_formatter,
118120
neo4j_database=neo4j_database,
121+
node_label_neo4j=node_label_neo4j,
119122
)
120123
except ValidationError as e:
121124
raise RetrieverInitializationError(e.errors()) from e
@@ -125,6 +128,7 @@ def __init__(
125128
id_property_external="id",
126129
id_property_neo4j=validated_data.id_property_neo4j,
127130
neo4j_database=neo4j_database,
131+
node_label_neo4j=node_label_neo4j,
128132
)
129133
self.driver = validated_data.driver_model.driver
130134
self.client = validated_data.client_model.client
@@ -172,7 +176,8 @@ def get_search_results(
172176
driver=neo4j_driver,
173177
client=pc_client,
174178
index_name="jeopardy",
175-
id_property_neo4j="id"
179+
id_property_neo4j="id",
180+
node_label_neo4j="Document",
176181
)
177182
biology_embedding = ...
178183
retriever.search(query_vector=biology_embedding, top_k=2)
@@ -223,6 +228,7 @@ def get_search_results(
223228
search_query = get_match_query(
224229
return_properties=self.return_properties,
225230
retrieval_query=self.retrieval_query,
231+
node_label=self.node_label_neo4j,
226232
)
227233

228234
parameters = {

src/neo4j_graphrag/retrievers/external/pinecone/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,4 @@ class PineconeNeo4jRetrieverModel(BaseModel):
5959
retrieval_query: Optional[str] = None
6060
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
6161
neo4j_database: Optional[str] = None
62+
node_label_neo4j: Optional[str] = None

src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class QdrantNeo4jRetriever(ExternalRetriever):
7979
return_properties (Optional[list[str]]): List of node properties to return.
8080
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8181
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
82+
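node_label_neo4j (Optional[str]): The label (or label expression, e.g. "Actor|Director") of the Neo4j node to retrieve. It is inserted into the query without escaping, so it must be properly escaped by the caller if needed, e.g. "`Label with spaces`".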
8283
8384
Raises:
8485
RetrieverInitializationError: If validation of the input arguments fails.
@@ -99,6 +100,7 @@ def __init__(
99100
Callable[[neo4j.Record], RetrieverResultItem]
100101
] = None,
101102
neo4j_database: Optional[str] = None,
103+
node_label_neo4j: Optional[str] = None,
102104
):
103105
try:
104106
driver_model = Neo4jDriverModel(driver=driver)
@@ -116,6 +118,7 @@ def __init__(
116118
retrieval_query=retrieval_query,
117119
result_formatter=result_formatter,
118120
neo4j_database=neo4j_database,
121+
node_label_neo4j=node_label_neo4j,
119122
)
120123
except ValidationError as e:
121124
raise RetrieverInitializationError(e.errors()) from e
@@ -125,6 +128,7 @@ def __init__(
125128
id_property_external=validated_data.id_property_external,
126129
id_property_neo4j=validated_data.id_property_neo4j,
127130
neo4j_database=neo4j_database,
131+
node_label_neo4j=node_label_neo4j,
128132
)
129133
self.driver = validated_data.driver_model.driver
130134
self.client = validated_data.client_model.client
@@ -169,7 +173,8 @@ def get_search_results(
169173
driver=neo4j_driver,
170174
client=client,
171175
collection_name="my_collection",
172-
id_property_external="neo4j_id"
176+
id_property_external="neo4j_id",
177+
node_label_neo4j="Document",
173178
)
174179
embedding = ...
175180
retriever.search(query_vector=embedding, top_k=2)
@@ -223,6 +228,7 @@ def get_search_results(
223228
search_query = get_match_query(
224229
return_properties=self.return_properties,
225230
retrieval_query=self.retrieval_query,
231+
node_label=self.node_label_neo4j,
226232
)
227233

228234
parameters = {

src/neo4j_graphrag/retrievers/external/qdrant/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ class QdrantNeo4jRetrieverModel(BaseModel):
5454
retrieval_query: Optional[str] = None
5555
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
5656
neo4j_database: Optional[str] = None
57+
node_label_neo4j: Optional[str] = None

src/neo4j_graphrag/retrievers/external/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,18 @@
2020

2121

2222
def get_match_query(
23-
return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None
23+
return_properties: Optional[list[str]] = None,
24+
retrieval_query: Optional[str] = None,
25+
node_label: Optional[str] = None,
2426
) -> str:
27+
# node_label is not escaped on purpose, allowing users to use any valid
28+
# node label expression, e.g. "Actor|Director". It's up to the user to ensure
29+
# labels are properly escaped, e.g. "`My label with space`".
30+
node_label_expression = f":{node_label}" if node_label else ""
2531
match_query = (
2632
"UNWIND $match_params AS match_param "
2733
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
28-
"MATCH (node) "
34+
f"MATCH (node{node_label_expression}) "
2935
"WHERE node[$id_property] = match_id_value "
3036
)
3137
return match_query + get_query_tail(

src/neo4j_graphrag/retrievers/external/weaviate/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
5757
retrieval_query: Optional[str] = None
5858
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
5959
neo4j_database: Optional[str] = None
60+
node_label_neo4j: Optional[str] = None
6061

6162

6263
class WeaviateNeo4jSearchModel(VectorSearchModel):

src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever):
8181
return_properties (Optional[list[str]]): List of node properties to return.
8282
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8383
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
84+
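node_label_neo4j (Optional[str]): The label (or label expression, e.g. "Actor|Director") of the Neo4j node to retrieve. It is inserted into the query without escaping, so it must be properly escaped by the caller if needed, e.g. "`Label with spaces`".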
8485
8586
Raises:
8687
RetrieverInitializationError: If validation of the input arguments fails.
@@ -100,6 +101,7 @@ def __init__(
100101
Callable[[neo4j.Record], RetrieverResultItem]
101102
] = None,
102103
neo4j_database: Optional[str] = None,
104+
node_label_neo4j: Optional[str] = None,
103105
):
104106
try:
105107
driver_model = Neo4jDriverModel(driver=driver)
@@ -116,12 +118,17 @@ def __init__(
116118
retrieval_query=retrieval_query,
117119
result_formatter=result_formatter,
118120
neo4j_database=neo4j_database,
121+
node_label_neo4j=node_label_neo4j,
119122
)
120123
except ValidationError as e:
121124
raise RetrieverInitializationError(e.errors()) from e
122125

123126
super().__init__(
124-
driver, id_property_external, id_property_neo4j, neo4j_database
127+
driver,
128+
id_property_external,
129+
id_property_neo4j,
130+
neo4j_database,
131+
node_label_neo4j,
125132
)
126133
self.client = validated_data.client_model.client
127134
collection = validated_data.collection
@@ -164,6 +171,7 @@ def get_search_results(
164171
collection="Jeopardy",
165172
id_property_external="neo4j_id",
166173
id_property_neo4j="id",
174+
node_label_neo4j="Document",
167175
)
168176
169177
biology_embedding = ...
@@ -234,6 +242,7 @@ def get_search_results(
234242
search_query = get_match_query(
235243
return_properties=self.return_properties,
236244
retrieval_query=self.retrieval_query,
245+
node_label=self.node_label_neo4j,
237246
)
238247

239248
parameters = {

0 commit comments

Comments
 (0)