diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 941b59106..a744e16e2 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -192,26 +192,19 @@ def handle_get_memories( del memories["total_edges"] preferences: list[TextualMemoryItem] = [] - total_explicit_nodes, total_implicit_nodes = 0, 0 + total_pref = 0 + if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: filter_params: dict[str, Any] = {} if get_mem_req.user_id is not None: filter_params["user_id"] = get_mem_req.user_id if get_mem_req.mem_cube_id is not None: filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - preferences = naive_mem_cube.pref_mem.get_memory_by_filter( + + preferences, total_pref = naive_mem_cube.pref_mem.get_memory_by_filter( filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size ) - - for key, value_list in preferences.items(): - if key in ["explicit_preference", "implicit_preference"]: - formatted_list = [format_memory_item(item) for item in value_list] - preferences[key] = formatted_list - - total_explicit_nodes = preferences["total_explicit_nodes"] - total_implicit_nodes = preferences["total_implicit_nodes"] - del preferences["total_explicit_nodes"] - del preferences["total_implicit_nodes"] + format_preferences = [format_memory_item(item) for item in preferences] return GetMemoryResponse( message="Memories retrieved successfully", @@ -227,9 +220,8 @@ def handle_get_memories( "pref_mem": [ { "cube_id": get_mem_req.mem_cube_id, - "memories": preferences, - "total_explicit_nodes": total_explicit_nodes, - "total_implicit_nodes": total_implicit_nodes, + "memories": format_preferences, + "total_nodes": total_pref, } ], }, diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 75d7d2a4c..cb4f00735 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -1,6 +1,7 @@ import json import os +from datetime import datetime from typing import Any from memos.configs.memory import PreferenceTextMemoryConfig @@ -262,8 +263,11 @@ def get_all(self) -> list[TextualMemoryItem]: return all_memories def get_memory_by_filter( - self, filter: dict[str, Any] | None = None, **kwargs - ) -> list[TextualMemoryItem]: + self, + filter: dict[str, Any] | None = None, + page: int | None = None, + page_size: int | None = None, + ): """Get memories by filter. Args: filter (dict[str, Any]): Filter criteria. @@ -272,14 +276,9 @@ def get_memory_by_filter( """ collection_list = self.vector_db.config.collection_name - memories = {} - total_explicit_nodes = 0 - total_implicit_nodes = 0 + memories = [] for collection_name in collection_list: - memories[collection_name] = [] - db_items, total_count = self.vector_db.get_by_filter( - collection_name=collection_name, filter=filter, count_total=True, **kwargs - ) + db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter) db_items_memory = [ TextualMemoryItem( id=memo.id, @@ -288,16 +287,23 @@ def get_memory_by_filter( ) for memo in db_items ] - memories[collection_name].extend(db_items_memory) + memories.extend(db_items_memory) - if collection_name == "explicit_preference": - total_explicit_nodes = total_count - if collection_name == "implicit_preference": - total_implicit_nodes = total_count - memories["total_explicit_nodes"] = total_explicit_nodes - memories["total_implicit_nodes"] = total_implicit_nodes - - return memories + # sort + sorted_memories = sorted( + memories, + key=lambda item: datetime.fromisoformat(item.metadata.created_at), + reverse=True, + ) + if page and page_size: + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + pick_memories = sorted_memories[(page - 1) * page_size : page * page_size] + return pick_memories, len(sorted_memories) + + return sorted_memories, len(sorted_memories) def delete(self, memory_ids: list[str]) -> None: """Delete memories. diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index b0753b31d..ecbca5815 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -493,14 +493,7 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt return items def get_by_filter( - self, - collection_name: str, - filter: dict[str, Any], - scroll_limit: int = 100, - page: int | None = None, - page_size: int | None = None, - count_total=False, - **kwargs, + self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 ) -> list[MilvusVecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. @@ -513,74 +506,47 @@ def get_by_filter( List of items including vectors and payload that match the filter """ expr = self._dict_to_expr(filter) if filter else "" - if count_total: - total_count = 0 - count_iterator = self.client.query_iterator( - collection_name=collection_name, - filter=expr, - batch_size=scroll_limit, - output_fields=["id"], - ) - try: - while True: - batch = count_iterator.next() - if not batch: - break - total_count += len(batch) - finally: - count_iterator.close() - - result = [] - skipped = 0 - needed = page_size + all_items = [] + # Use query_iterator for efficient pagination iterator = self.client.query_iterator( collection_name=collection_name, filter=expr, batch_size=scroll_limit, - output_fields=["*"], + output_fields=["*"], # Include all fields including payload ) + # Iterate through all batches try: - while needed > 0: - batch = iterator.next() - if not batch: - break - - for entity in batch: - skipped += 1 + while True: + batch_results = iterator.next() - if skipped <= (page - 1) * page_size: - continue + if not batch_results: + break + # Convert batch results to MilvusVecDBItem objects + for entity in batch_results: + # Extract the actual payload from Milvus entity payload = entity.get("payload", {}) - item = MilvusVecDBItem( - id=entity["id"], - memory=entity.get("memory"), - original_text=entity.get("original_text"), - vector=entity.get("vector"), - payload=payload, + all_items.append( + MilvusVecDBItem( + id=entity["id"], + memory=entity.get("memory"), + original_text=entity.get("original_text"), + vector=entity.get("vector"), + payload=payload, + ) ) - result.append(item) - needed -= 1 - - if needed <= 0: - if count_total: - return result, total_count - return result - except Exception as e: - logger.warning(f"Error during iteration: {e}") + logger.warning( + f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far." + ) finally: + # Close the iterator iterator.close() - logger.info( - f"Milvus retrieve by filter completed - " - f"page {page}, page_size {page_size}, got {len(result)} items." - ) - if count_total: - return result, total_count - return result + logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") + return all_items def get_all(self, collection_name: str, scroll_limit=100) -> list[MilvusVecDBItem]: """Retrieve all items in the vector database."""