|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
| 15 | +from typing import Any |
15 | 16 | from unittest import mock |
16 | 17 | from unittest.mock import MagicMock |
17 | 18 |
|
@@ -70,7 +71,7 @@ def test_qdrant_retriever_search_happy_path( |
70 | 71 | driver.execute_query.assert_called_once_with( |
71 | 72 | search_query, |
72 | 73 | { |
73 | | - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], |
| 74 | + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], |
74 | 75 | "id_property": "sync_id", |
75 | 76 | }, |
76 | 77 | database_=None, |
@@ -149,7 +150,7 @@ def test_qdrant_retriever_search_return_properties( |
149 | 150 | driver.execute_query.assert_called_once_with( |
150 | 151 | search_query, |
151 | 152 | { |
152 | | - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], |
| 153 | + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], |
153 | 154 | "id_property": "sync_id", |
154 | 155 | }, |
155 | 156 | database_=None, |
@@ -215,7 +216,7 @@ def test_qdrant_retriever_search_retrieval_query( |
215 | 216 | driver.execute_query.assert_called_once_with( |
216 | 217 | search_query, |
217 | 218 | { |
218 | | - "match_params": [[f"node_{i}", i / top_k] for i in range(top_k)], |
| 219 | + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], |
219 | 220 | "id_property": "sync_id", |
220 | 221 | }, |
221 | 222 | database_=None, |
@@ -267,3 +268,70 @@ def test_qdrant_retriever_invalid_retrieval_query( |
267 | 268 |
|
268 | 269 | assert "retrieval_query" in str(exc_info.value) |
269 | 270 | assert "Input should be a valid string" in str(exc_info.value) |
| 271 | + |
| 272 | + |
| 273 | +def test_qdrant_retriever_search_custom_match_id_getter( |
| 274 | + driver: MagicMock, client: MagicMock |
| 275 | +) -> None: |
| 276 | + def my_id_getter(point: ScoredPoint) -> Any: |
| 277 | + if point.payload is None: |
| 278 | + raise Exception("Payload is None") |
| 279 | + return point.payload["data"]["id"] |
| 280 | + |
| 281 | + retriever = QdrantNeo4jRetriever( |
| 282 | + driver=driver, |
| 283 | + client=client, |
| 284 | + collection_name="dummy-text", |
| 285 | + id_property_neo4j="sync_id", |
| 286 | + id_property_getter=my_id_getter, |
| 287 | + ) |
| 288 | + with mock.patch.object(retriever, "client") as mock_client: |
| 289 | + top_k = 5 |
| 290 | + mock_client.query_points.return_value = QueryResponse( |
| 291 | + points=[ |
| 292 | + ScoredPoint( |
| 293 | + id=i, |
| 294 | + version=0, |
| 295 | + score=i / top_k, |
| 296 | + payload={ |
| 297 | + "data": {"id": f"node_{i}"}, |
| 298 | + }, |
| 299 | + ) |
| 300 | + for i in range(top_k) |
| 301 | + ] |
| 302 | + ) |
| 303 | + driver.execute_query.return_value = ( |
| 304 | + [ |
| 305 | + neo4j.Record({"node": {"sync_id": f"node_{i}"}, "score": i / top_k}) |
| 306 | + for i in range(top_k) |
| 307 | + ], |
| 308 | + None, |
| 309 | + None, |
| 310 | + ) |
| 311 | + query_vector = [1.0 for _ in range(1536)] |
| 312 | + search_query = get_match_query() |
| 313 | + records = retriever.search(query_vector=query_vector) |
| 314 | + |
| 315 | + driver.execute_query.assert_called_once_with( |
| 316 | + search_query, |
| 317 | + { |
| 318 | + "match_params": [(f"node_{i}", i / top_k) for i in range(top_k)], |
| 319 | + "id_property": "sync_id", |
| 320 | + }, |
| 321 | + database_=None, |
| 322 | + routing_=neo4j.RoutingControl.READ, |
| 323 | + ) |
| 324 | + |
| 325 | + assert records == RetrieverResult( |
| 326 | + items=[ |
| 327 | + RetrieverResultItem( |
| 328 | + content="<Record node={'sync_id': " |
| 329 | + + f"'node_{i}'" |
| 330 | + + "} " |
| 331 | + + f"score={i / top_k}>", |
| 332 | + metadata=None, |
| 333 | + ) |
| 334 | + for i in range(top_k) |
| 335 | + ], |
| 336 | + metadata={"__retriever": "QdrantNeo4jRetriever"}, |
| 337 | + ) |
0 commit comments