Skip to content

Commit d1291a0

Browse files
authored
Add id property getter to qdrant retriever (#453)
* Add id property getter to qdrant retriever * Missing import * List => tuple
1 parent 03ddb3c commit d1291a0

File tree

7 files changed

+92
-16
lines changed

7 files changed

+92
-16
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
### Added
66

77
- Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j.
8-
8+
- Added an optional `id_property_getter` callable parameter in the Qdrant retriever to allow for custom ID retrieval.
99

1010
## 1.10.1
1111

src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def get_search_results(
221221
)
222222

223223
result_tuples = [
224-
[f"{o[self.id_property_external]}", o["score"] or 0.0]
224+
(f"{o[self.id_property_external]}", o["score"] or 0.0)
225225
for o in response["matches"]
226226
]
227227

src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import neo4j
2121
from pydantic import ValidationError
2222
from qdrant_client import QdrantClient
23+
from qdrant_client.conversions.common_types import ScoredPoint
2324

2425
from neo4j_graphrag.embeddings.base import Embedder
2526
from neo4j_graphrag.exceptions import (
@@ -80,6 +81,7 @@ class QdrantNeo4jRetriever(ExternalRetriever):
8081
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8182
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
8283
node_label_neo4j (Optional[str]): The label of the Neo4j node to retrieve. This label must be properly escaped if needed, eg "`Label with spaces`".
84+
id_property_getter (Optional[Callable[[ScoredPoint], Any]]): Function to get the id property from a ScoredPoint. If not provided, defaults to point.payload.get(id_property_external, point.id).
8385
8486
Raises:
8587
RetrieverInitializationError: If validation of the input arguments fails.
@@ -101,6 +103,7 @@ def __init__(
101103
] = None,
102104
neo4j_database: Optional[str] = None,
103105
node_label_neo4j: Optional[str] = None,
106+
id_property_getter: Optional[Callable[[ScoredPoint], Any]] = None,
104107
):
105108
try:
106109
driver_model = Neo4jDriverModel(driver=driver)
@@ -142,6 +145,14 @@ def __init__(
142145
self.return_properties = validated_data.return_properties
143146
self.retrieval_query = validated_data.retrieval_query
144147
self.result_formatter = validated_data.result_formatter
148+
self.id_property_getter = id_property_getter
149+
150+
def get_match_id_from_point(self, point: ScoredPoint) -> Any:
151+
if self.id_property_getter:
152+
return self.id_property_getter(point)
153+
if point.payload is None:
154+
raise ValueError(f"Payload is None for point {point}")
155+
return point.payload.get(self.id_property_external, point.id)
145156

146157
def get_search_results(
147158
self,
@@ -220,10 +231,7 @@ def get_search_results(
220231

221232
result_tuples = []
222233
for point in points:
223-
assert point.payload is not None
224-
result_tuples.append(
225-
[point.payload.get(self.id_property_external, point.id), point.score]
226-
)
234+
result_tuples.append((self.get_match_id_from_point(point), point.score))
227235

228236
search_query = get_match_query(
229237
return_properties=self.return_properties,

src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def get_search_results(
235235
logger.debug("Response: %s", response)
236236

237237
result_tuples = [
238-
[f"{o.properties[self.id_property_external]}", o.metadata.certainty or 0.0]
238+
(f"{o.properties[self.id_property_external]}", o.metadata.certainty or 0.0)
239239
for o in response.objects
240240
]
241241

tests/unit/retrievers/external/test_pinecone.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_pinecone_retriever_search_happy_path(
9595
driver.execute_query.assert_called_once_with(
9696
search_query,
9797
{
98-
"match_params": [[f"node_{i}", i / top_k] for i in range(top_k)],
98+
"match_params": [(f"node_{i}", i / top_k) for i in range(top_k)],
9999
"id_property": "sync_id",
100100
},
101101
database_=None,
@@ -168,7 +168,7 @@ def test_pinecone_retriever_search_return_properties(
168168
driver.execute_query.assert_called_once_with(
169169
search_query,
170170
{
171-
"match_params": [[f"node_{i}", i / top_k] for i in range(top_k)],
171+
"match_params": [(f"node_{i}", i / top_k) for i in range(top_k)],
172172
"id_property": "sync_id",
173173
},
174174
database_=None,
@@ -228,7 +228,7 @@ def test_pinecone_retriever_search_retrieval_query(
228228
driver.execute_query.assert_called_once_with(
229229
search_query,
230230
{
231-
"match_params": [[f"node_{i}", i / top_k] for i in range(top_k)],
231+
"match_params": [(f"node_{i}", i / top_k) for i in range(top_k)],
232232
"id_property": "sync_id",
233233
},
234234
database_=None,

tests/unit/retrievers/external/test_qdrant.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
from typing import Any
1516
from unittest import mock
1617
from unittest.mock import MagicMock
1718

@@ -70,7 +71,7 @@ def test_qdrant_retriever_search_happy_path(
7071
driver.execute_query.assert_called_once_with(
7172
search_query,
7273
{
73-
"match_params": [[f"node_{i}", i / top_k] for i in range(top_k)],
74+
"match_params": [(f"node_{i}", i / top_k) for i in range(top_k)],
7475
"id_property": "sync_id",
7576
},
7677
database_=None,
@@ -149,7 +150,7 @@ def test_qdrant_retriever_search_return_properties(
149150
driver.execute_query.assert_called_once_with(
150151
search_query,
151152
{
152-
"match_params": [[f"node_{i}", i / top_k] for i in range(top_k)],
153+
"match_params": [(f"node_{i}", i / top_k) for i in range(top_k)],
153154
"id_property": "sync_id",
154155
},
155156
database_=None,
@@ -215,7 +216,7 @@ def test_qdrant_retriever_search_retrieval_query(
215216
driver.execute_query.assert_called_once_with(
216217
search_query,
217218
{
218-
"match_params": [[f"node_{i}", i / top_k] for i in range(top_k)],
219+
"match_params": [(f"node_{i}", i / top_k) for i in range(top_k)],
219220
"id_property": "sync_id",
220221
},
221222
database_=None,
@@ -267,3 +268,70 @@ def test_qdrant_retriever_invalid_retrieval_query(
267268

268269
assert "retrieval_query" in str(exc_info.value)
269270
assert "Input should be a valid string" in str(exc_info.value)
271+
272+
273+
def test_qdrant_retriever_search_custom_match_id_getter(
274+
driver: MagicMock, client: MagicMock
275+
) -> None:
276+
def my_id_getter(point: ScoredPoint) -> Any:
277+
if point.payload is None:
278+
raise Exception("Payload is None")
279+
return point.payload["data"]["id"]
280+
281+
retriever = QdrantNeo4jRetriever(
282+
driver=driver,
283+
client=client,
284+
collection_name="dummy-text",
285+
id_property_neo4j="sync_id",
286+
id_property_getter=my_id_getter,
287+
)
288+
with mock.patch.object(retriever, "client") as mock_client:
289+
top_k = 5
290+
mock_client.query_points.return_value = QueryResponse(
291+
points=[
292+
ScoredPoint(
293+
id=i,
294+
version=0,
295+
score=i / top_k,
296+
payload={
297+
"data": {"id": f"node_{i}"},
298+
},
299+
)
300+
for i in range(top_k)
301+
]
302+
)
303+
driver.execute_query.return_value = (
304+
[
305+
neo4j.Record({"node": {"sync_id": f"node_{i}"}, "score": i / top_k})
306+
for i in range(top_k)
307+
],
308+
None,
309+
None,
310+
)
311+
query_vector = [1.0 for _ in range(1536)]
312+
search_query = get_match_query()
313+
records = retriever.search(query_vector=query_vector)
314+
315+
driver.execute_query.assert_called_once_with(
316+
search_query,
317+
{
318+
"match_params": [(f"node_{i}", i / top_k) for i in range(top_k)],
319+
"id_property": "sync_id",
320+
},
321+
database_=None,
322+
routing_=neo4j.RoutingControl.READ,
323+
)
324+
325+
assert records == RetrieverResult(
326+
items=[
327+
RetrieverResultItem(
328+
content="<Record node={'sync_id': "
329+
+ f"'node_{i}'"
330+
+ "} "
331+
+ f"score={i / top_k}>",
332+
metadata=None,
333+
)
334+
for i in range(top_k)
335+
],
336+
metadata={"__retriever": "QdrantNeo4jRetriever"},
337+
)

tests/unit/retrievers/external/test_weaviate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_text_search_remote_vector_store_happy_path(driver: MagicMock) -> None:
7575
search_query,
7676
{
7777
"match_params": [
78-
[node_id_value, node_match_score],
78+
(node_id_value, node_match_score),
7979
],
8080
"id_property": "sync_id",
8181
},
@@ -142,7 +142,7 @@ def test_text_search_remote_vector_store_return_properties(driver: MagicMock) ->
142142
search_query,
143143
{
144144
"match_params": [
145-
[node_id_value, node_match_score],
145+
(node_id_value, node_match_score),
146146
],
147147
"id_property": "sync_id",
148148
},
@@ -190,7 +190,7 @@ def test_text_search_remote_vector_store_retrieval_query(driver: MagicMock) -> N
190190
search_query,
191191
{
192192
"match_params": [
193-
[node_id_value, node_match_score],
193+
(node_id_value, node_match_score),
194194
],
195195
"id_property": "sync_id",
196196
},

0 commit comments

Comments
 (0)