From 83d157689c5d1dc4a2a10d8176e81dff97e2b9c6 Mon Sep 17 00:00:00 2001
From: chentang
Date: Thu, 18 Dec 2025 14:38:19 +0800
Subject: [PATCH 01/21] fix: publish web logs reliably in _submit_web_logs

---
 src/memos/mem_scheduler/base_scheduler.py | 37 ++++++++++-------------
 1 file changed, 16 insertions(+), 21 deletions(-)

diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 81defaa0f..9ab356f1d 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -846,28 +846,23 @@ def _submit_web_logs(
             f"[DIAGNOSTIC] base_scheduler._submit_web_logs called. Message to publish: {message.model_dump_json(indent=2)}"
         )
 
-        if self.rabbitmq_config is None:
-            logger.info(
-                "[DIAGNOSTIC] base_scheduler._submit_web_logs: RabbitMQ config not loaded; skipping publish."
-            )
-            return
-
-        for message in messages:
-            message_info = message.debug_info()
-            logger.info(f"[DIAGNOSTIC] base_scheduler._submit_web_logs: submitted {message_info}")
+        try:
+            for message in messages:
+                # Always call publish; the publisher now caches when offline and flushes after reconnect
+                logger.info(
+                    f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}"
+                )
+                self.rabbitmq_publish_message(message=message.to_dict())
+                logger.info(
+                    "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched "
+                    "item_id=%s task_id=%s label=%s",
+                    message.item_id,
+                    message.task_id,
+                    message.label,
+                )
+        except Exception as e:
+            logger.error(f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True)
 
-        # Always call publish; the publisher now caches when offline and flushes after reconnect
-        logger.info(
-            f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message_info}"
-        )
-        self.rabbitmq_publish_message(message=message.to_dict())
-        logger.info(
-            "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched "
-            "item_id=%s task_id=%s label=%s",
-            message.item_id,
-            message.task_id,
-            message.label,
-        )
         logger.debug(
             f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}"
         )

From e50c56cf817cb6d63b8e8e882aeaa4de12c444b8 Mon Sep 17 00:00:00 2001
From: chentang
Date: Thu, 18 Dec 2025 15:00:28 +0800
Subject: [PATCH 02/21] fix: handle publish errors per message

---
 src/memos/mem_scheduler/base_scheduler.py    | 13 +++++--------
 .../webservice_modules/rabbitmq_service.py   |  6 ++++--
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 9ab356f1d..1e0ecaadb 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -842,12 +842,7 @@ def _submit_web_logs(
             messages = [messages]  # transform single message to list
 
         for message in messages:
-            logger.info(
-                f"[DIAGNOSTIC] base_scheduler._submit_web_logs called. 
Message to publish: {message.model_dump_json(indent=2)}"
-            )
-
-        try:
-            for message in messages:
+            try:
                 # Always call publish; the publisher now caches when offline and flushes after reconnect
                 logger.info(
                     f"[DIAGNOSTIC] base_scheduler._submit_web_logs: enqueue publish {message.model_dump_json(indent=2)}"
                 )
                 self.rabbitmq_publish_message(message=message.to_dict())
                 logger.info(
                     "[DIAGNOSTIC] base_scheduler._submit_web_logs: publish dispatched "
                     "item_id=%s task_id=%s label=%s",
                     message.item_id,
                     message.task_id,
                     message.label,
                 )
-        except Exception as e:
-            logger.error(f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True)
+            except Exception as e:
+                logger.error(
+                    f"[DIAGNOSTIC] base_scheduler._submit_web_logs failed: {e}", exc_info=True
+                )
 
         logger.debug(
             f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue. additional_log_info: {additional_log_info}"
diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
index a8a09760c..db8320879 100644
--- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
+++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
@@ -7,6 +7,8 @@
 from pathlib import Path
 from queue import Empty
 
+from pyglet.libs.win32.constants import FALSE
+
 from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig
 from memos.context.context import ContextThread
 from memos.dependency import require_python_package
@@ -325,14 +327,14 @@ def rabbitmq_publish_message(self, message: dict):
                 f"[DIAGNOSTIC] Publishing {label} message in Cloud Env. "
                 f"Exchange: {exchange_name}, Routing Key: '{routing_key}'."
             )
-            logger.info(f" - Message Content: {json.dumps(message, indent=2)}")
+            logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}")
         elif label == "knowledgeBaseUpdate":
             # Original diagnostic logging for knowledgeBaseUpdate if NOT in cloud env
             logger.info(
                 f"[DIAGNOSTIC] Publishing knowledgeBaseUpdate message (Local Env). "
                 f"Current configured Exchange: {exchange_name}, Routing Key: '{routing_key}'."
             )
-            logger.info(f" - Message Content: {json.dumps(message, indent=2)}")
+            logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}")
 
         with self._rabbitmq_lock:
             logger.info(

From 58eb6b81af34437677e929e629f25dd3ddf0c1ff Mon Sep 17 00:00:00 2001
From: chentang
Date: Thu, 18 Dec 2025 15:13:21 +0800
Subject: [PATCH 03/21] fix: drop stray pyglet FALSE import

---
 .../mem_scheduler/webservice_modules/rabbitmq_service.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
index db8320879..43d24c5b9 100644
--- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
+++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
@@ -7,8 +7,6 @@
 from pathlib import Path
 from queue import Empty
 
-from pyglet.libs.win32.constants import FALSE
-
 from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig
 from memos.context.context import ContextThread
 from memos.dependency import require_python_package
@@ -327,14 +325,14 @@ def rabbitmq_publish_message(self, message: dict):
                 f"[DIAGNOSTIC] Publishing {label} message in Cloud Env. "
                 f"Exchange: {exchange_name}, Routing Key: '{routing_key}'." 
) - logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}") + logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") elif label == "knowledgeBaseUpdate": # Original diagnostic logging for knowledgeBaseUpdate if NOT in cloud env logger.info( f"[DIAGNOSTIC] Publishing knowledgeBaseUpdate message (Local Env). " f"Current configured Exchange: {exchange_name}, Routing Key: '{routing_key}'." ) - logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=FALSE)}") + logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") with self._rabbitmq_lock: logger.info( From 0d72ce7669f3a9b30aa6849893a0e6ec6f991063 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 18 Dec 2025 15:59:20 +0800 Subject: [PATCH 04/21] refactor: modify examples --- examples/mem_scheduler/memos_w_scheduler.py | 40 --------------------- 1 file changed, 40 deletions(-) diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index 09aec4cba..ef7d853df 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -4,7 +4,6 @@ from datetime import datetime from pathlib import Path -from queue import Queue from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig @@ -12,7 +11,6 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS -from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, @@ -160,42 +158,6 @@ def _first_content() -> str: return title, _truncate_with_rules(_first_content()) -def show_web_logs(mem_scheduler: GeneralScheduler): - """Display all web log entries from the scheduler's log queue. 
- - Args: - mem_scheduler: The scheduler instance containing web logs to display - """ - if mem_scheduler._web_log_message_queue.empty(): - print("Web log queue is currently empty.") - return - - print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) - - # Create a temporary queue to preserve the original queue contents - temp_queue = Queue() - collected: list[ScheduleLogForWebItem] = [] - - while not mem_scheduler._web_log_message_queue.empty(): - log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() - collected.append(log_item) - temp_queue.put(log_item) - - for idx, log_item in enumerate(sorted(collected, key=lambda x: x.timestamp, reverse=True), 1): - title, content = _format_entry(log_item) - print(f"\nLog Entry #{idx}:") - print(title) - print(content) - print("-" * 50) - - # Restore items back to the original queue - while not temp_queue.empty(): - mem_scheduler._web_log_message_queue.put(temp_queue.get()) - - print(f"\nTotal {len(collected)} web log entries displayed.") - print("=" * 110 + "\n") - - def run_with_scheduler_init(): print("==== run_with_automatic_scheduler_init ====") conversations, questions = init_task() @@ -253,8 +215,6 @@ def run_with_scheduler_init(): response = mos.chat(query=query, user_id=user_id) print(f"Answer:\n {response}\n") - show_web_logs(mem_scheduler=mos.mem_scheduler) - mos.mem_scheduler.stop() From 2fe965be240ea0e68c511b5573d88e9599b7cbd2 Mon Sep 17 00:00:00 2001 From: chentang Date: Thu, 18 Dec 2025 20:06:40 +0800 Subject: [PATCH 05/21] revise add operation and fix an unbelievable bug --- .../mem_scheduler/try_schedule_modules.py | 47 ------------------- src/memos/mem_reader/simple_struct.py | 2 +- .../webservice_modules/rabbitmq_service.py | 3 +- src/memos/templates/mem_reader_prompts.py | 39 ++++++++------- 4 files changed, 21 insertions(+), 70 deletions(-) diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index a5c5bc737..d942aad4e 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -1,8 +1,6 @@ import sys from pathlib import Path -from queue import Queue -from typing import TYPE_CHECKING from tqdm import tqdm @@ -11,18 +9,11 @@ ) from memos.log import get_logger from memos.mem_scheduler.analyzer.api_analyzer import DirectSearchMemoriesAnalyzer -from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import MEM_UPDATE_TASK_LABEL -if TYPE_CHECKING: - from memos.mem_scheduler.schemas import ( - ScheduleLogForWebItem, - ) - - FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory @@ -105,41 +96,6 @@ def init_task(): return conversations, questions -def show_web_logs(mem_scheduler: BaseScheduler): - """Display all web log entries from the scheduler's log queue. 
- - Args: - mem_scheduler: The scheduler instance containing web logs to display - """ - if mem_scheduler._web_log_message_queue.empty(): - print("Web log queue is currently empty.") - return - - print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) - - # Create a temporary queue to preserve the original queue contents - temp_queue = Queue() - log_count = 0 - - while not mem_scheduler._web_log_message_queue.empty(): - log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() - temp_queue.put(log_item) - log_count += 1 - - # Print log entry details - print(f"\nLog Entry #{log_count}:") - print(f'- "{log_item.label}" log: {log_item}') - - print("-" * 50) - - # Restore items back to the original queue - while not temp_queue.empty(): - mem_scheduler._web_log_message_queue.put(temp_queue.get()) - - print(f"\nTotal {log_count} web log entries displayed.") - print("=" * 110 + "\n") - - class ScheduleModulesRunner(DirectSearchMemoriesAnalyzer): def __init__(self): super().__init__() @@ -215,6 +171,3 @@ def add_msgs( mem_scheduler._memory_update_consumer( messages=[message], ) - - # Show accumulated web logs - show_web_logs(mem_scheduler) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index ac79c246b..b870bf70a 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -522,7 +522,7 @@ def filter_hallucination_in_memories( raw = self.llm.generate([{"role": "user", "content": prompt}]) success, parsed = self._parse_hallucination_filter_response(raw) logger.info( - f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success}" + f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" ) if success: logger.info(f"Hallucination filter result: {parsed}") diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 43d24c5b9..46b2ad3d1 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -108,8 +108,7 @@ def initialize_rabbitmq( elif Path(config_path).exists(): auth_config = AuthConfig.from_local_config(config_path=config_path) else: - logger.error("Fail to initialize auth_config") - return + auth_config = AuthConfig.from_local_env() self.rabbitmq_config = auth_config.rabbitmq elif isinstance(config, RabbitMQConfig): self.rabbitmq_config = config diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 12c445df7..fef3ee6c0 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -625,21 +625,20 @@ SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. -Your task is to compare each memory against the provided user messages (the ground truth) and produce a corrected version only when necessary. Always preserve the original language of the memory—do not translate. +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. Rules: -1. **Language Consistency**: The rewritten memory must be in the exact same language as the original input memory. Never switch languages. -2. **Strict Grounding**: Only use information explicitly stated in the user messages. 
Do not introduce external facts, assumptions, or common sense. -3. **Ambiguity Resolution**: - - Replace vague pronouns (e.g., "he", "it", "they") or unclear references with specific, unambiguous entities based solely on the messages. - - Convert relative time expressions (e.g., "yesterday", "last week", "in two days") into absolute dates or times **only if the messages provide enough context** (e.g., current date is known or implied). -4. **Handling Assistant Inferences**: - - If a memory contains any content **not directly stated by the user**—such as interpretations, summaries, emotional attributions, predictions, causal claims, or generalizations—this is considered an assistant inference. - - In such cases, you **must** set `need_rewrite = true`. - - The `rewritten` text **must explicitly indicate that the statement is an inference**, using a clear and natural prefix in the memory’s language. For English memories, use: - > "The assistant inferred that [rest of the memory]." - - Do **not** present inferred content as factual user statements. -5. **No Rewrite Needed**: If the memory is factually accurate, fully grounded in the messages, unambiguous, and contains no unsupported content, set `need_rewrite = false` and copy the original memory exactly. +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, emotional labels, summaries, or generalizations. +3. **Ambiguity Elimination**: + - Replace vague pronouns (e.g., “he”, “it”, “they”) with clear, specific entities **only if** the messages identify them. + - Convert relative time expressions (e.g., “yesterday”) to absolute dates **only if** the messages provide enough temporal context. +4. **Hallucination Removal**: + - If a memory contains **any content not verbatim or directly implied by the user**, it must be rewritten. + - Do **not** rephrase inferences as facts. Instead, either: + - Remove the unsupported part and retain only the grounded core, or + - If the entire memory is ungrounded, mark it for rewrite and make the lack of user support explicit. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. Inputs: messages: @@ -649,15 +648,15 @@ {memories_inline} Output Format: -- Return a JSON object with string keys ("0", "1", "2", ...) corresponding to the input memory indices. +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. - Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} -- The "reason" should be concise and specific, e.g.: - - "contains assistant inference not stated by user" - - "pronoun 'it' has no clear referent in messages" - - "relative time 'yesterday' converted to 2025-12-16" - - "accurate and directly supported by user message" +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference" + - "vague pronoun with no referent in messages" + - "relative time resolved to 2025-12-16" + - "fully grounded and concise" -Important: Output **only** the JSON. No additional text, explanations, markdown, or fields. +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. 
""" From eecfa5136d9065fefd82867068a8deb12efae8a2 Mon Sep 17 00:00:00 2001 From: chentang Date: Mon, 22 Dec 2025 10:37:34 +0800 Subject: [PATCH 06/21] address the bug issues --- .../task_schedule_modules/redis_queue.py | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index ed8171ade..1c57f18f0 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -699,27 +699,23 @@ def _batch_claim_pending_messages( results = [] try: results = pipe.execute() - except Exception as e: - err_msg = str(e).lower() - if "nogroup" in err_msg or "no such key" in err_msg: - # Fallback: attempt sequential xautoclaim for robustness - for stream_key, need_count, label in claims_spec: - try: - self._ensure_consumer_group(stream_key=stream_key) - res = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), - start_id="0-0", - count=need_count, - justid=False, - ) - results.append(res) - except Exception: - continue - else: - logger.error(f"Pipeline xautoclaim failed: {e}") + except Exception: + # Fallback: attempt sequential xautoclaim for robustness + for stream_key, need_count, label in claims_spec: + try: + self._ensure_consumer_group(stream_key=stream_key) + res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + results.append(res) + except Exception: + continue claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( From f2da3a7bd718cf663b29fb285e602e847f4dc91a Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 19:35:40 +0800 Subject: [PATCH 07/21] the doc file has a format problem which has been fixed in this commit --- docs/README.md | 2 +- .../{task_stop_rerun.py => scheduler_for_async_tasks.py} | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename examples/mem_scheduler/{task_stop_rerun.py => scheduler_for_async_tasks.py} (98%) diff --git a/docs/README.md b/docs/README.md index bf5fea70d..8be17ffb7 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,3 +1,3 @@ All documentation has been moved to a separate repository: https://github.com/MemTensor/MemOS-Docs. Please edit documentation there. 
-所有文档已迁移至独立仓库:https://github.com/MemTensor/MemOS-Docs。请在该仓库中编辑文档。
+所有文档已迁移至独立仓库 https://github.com/MemTensor/MemOS-Docs 。请在该仓库中编辑文档。
diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/scheduler_for_async_tasks.py
similarity index 98%
rename from examples/mem_scheduler/task_stop_rerun.py
rename to examples/mem_scheduler/scheduler_for_async_tasks.py
index b5e62ff8f..a767b57c4 100644
--- a/examples/mem_scheduler/task_stop_rerun.py
+++ b/examples/mem_scheduler/scheduler_for_async_tasks.py
@@ -25,7 +25,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
         task_id = str(msg.item_id)
         file_path = tmp_dir / f"{task_id}.txt"
         try:
-            sleep(1)
+            sleep(5)
             file_path.write_text(f"Task {task_id} processed.\n")
             print(f"writing {file_path} done")
         except Exception as e:
@@ -58,7 +58,7 @@ def submit_tasks():
 mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})
 
 # 10s to restart
-mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 10_000
+mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 5_000
 
 tmp_dir = Path("./tmp")
 tmp_dir.mkdir(exist_ok=True)
@@ -88,6 +88,6 @@ def submit_tasks():
 print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))})")
 
 # 7. Stop the scheduler
+sleep(20)
 print("Stopping the scheduler...")
-sleep(5)
 mem_scheduler.stop()

From a6881b4b064145f032c9b9e58ed0f9772ef33612 Mon Sep 17 00:00:00 2001
From: chentang
Date: Wed, 24 Dec 2025 19:45:48 +0800
Subject: [PATCH 08/21] add a range of new features for the add operation

---
 src/memos/api/config.py                       |  21 +-
 src/memos/llms/openai.py                      |   6 +-
 src/memos/mem_reader/simple_struct.py         | 354 ++++++++++--------
 src/memos/mem_reader/utils.py                 | 210 +++++++++++
 .../mem_scheduler/schemas/general_schemas.py  |   4 +-
 .../task_schedule_modules/redis_queue.py      |  69 ++--
 .../textual/prefer_text_memory/extractor.py   |   4 +
 src/memos/templates/mem_reader_prompts.py     | 156 +++++++-
 8 files changed, 624 insertions(+), 200 deletions(-)
 create mode 100644 src/memos/mem_reader/utils.py

diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index b795c2be6..0cdcb9a92 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -7,16 +7,19 @@
 import re
 import time
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import requests
 
 from dotenv import load_dotenv
 
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
 from memos.context.context import ContextThread
-from memos.mem_cube.general import GeneralMemCube
+
+
+if TYPE_CHECKING:
+    from memos.configs.mem_cube import GeneralMemCubeConfig
+    from memos.configs.mem_os import MOSConfig
+    from memos.mem_cube.general import GeneralMemCube
 
 
 # Load environment variables
@@ -805,8 +808,12 @@ def get_start_default_config() -> dict[str, Any]:
         return config
 
     @staticmethod
-    def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]:
+    def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "GeneralMemCube"]:
         """Create configuration for a specific user."""
+        from memos.configs.mem_cube import GeneralMemCubeConfig
+        from memos.configs.mem_os import MOSConfig
+        from memos.mem_cube.general import GeneralMemCube
+
         openai_config = APIConfig.get_openai_config()
         qwen_config = APIConfig.qwen_config()
         vllm_config = APIConfig.vllm_config()
@@ -933,12 +940,14 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
         return default_config, default_mem_cube
 
     @staticmethod
-    def get_default_cube_config() -> 
GeneralMemCubeConfig | None: + def get_default_cube_config() -> "GeneralMemCubeConfig | None": """Get default cube configuration for product initialization. Returns: GeneralMemCubeConfig | None: Default cube configuration if enabled, None otherwise. """ + from memos.configs.mem_cube import GeneralMemCubeConfig + if not APIConfig.is_default_cube_config_enabled(): return None diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 1d180eebd..752386c91 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -57,8 +57,8 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) if reasoning_content: - return reasoning_content + response_content - return response_content + return reasoning_content + (response_content or "") + return response_content or "" @timed_with_status( log_prefix="OpenAI LLM", @@ -146,7 +146,7 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) else: - return response_content + return response_content or "" def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from Azure OpenAI LLM with optional reasoning support.""" diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index b870bf70a..866b6d988 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -2,7 +2,6 @@ import copy import json import os -import re import traceback from abc import ABC @@ -18,6 +17,13 @@ from memos.llms.factory import LLMFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang +from memos.mem_reader.utils import ( + count_tokens_text, + derive_key, + parse_json_result, + parse_keep_filter_response, + parse_rewritten_response, +) from memos.memories.textual.item import ( SourceMessage, TextualMemoryItem, @@ -89,27 +95,6 @@ def from_config(_config): } -try: - import tiktoken - - try: - _ENC = tiktoken.encoding_for_model("gpt-4o-mini") - except Exception: - _ENC = tiktoken.get_encoding("cl100k_base") - - def _count_tokens_text(s: str) -> int: - return len(_ENC.encode(s or "", disallowed_special=())) -except Exception: - # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars - def _count_tokens_text(s: str) -> int: - if not s: - return 0 - zh_chars = re.findall(r"[\u4e00-\u9fff]", s) - zh = len(zh_chars) - rest = len(s) - zh - return zh + max(1, rest // 4) - - def _build_node(idx, message, info, source_info, llm, parse_json_result, embedder): # generate try: @@ -172,14 +157,6 @@ def _build_node(idx, message, info, source_info, llm, parse_json_result, embedde return None -def _derive_key(text: str, max_len: int = 80) -> str: - """default key when without LLM: first max_len words""" - if not text: - return "" - sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] - return (sent[:max_len]).strip() - - class SimpleStructMemReader(BaseMemReader, ABC): """Naive implementation of MemReader.""" @@ -197,7 +174,8 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.memory_max_length = 8000 # Use token-based windowing; default to ~5000 tokens if not configured self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) - self._count_tokens = _count_tokens_text + self._count_tokens = count_tokens_text + self.searcher = None def _make_memory_item( self, 
@@ -224,7 +202,7 @@ def _make_memory_item( memory_type=memory_type, status="activated", tags=tags or [], - key=key if key is not None else _derive_key(value), + key=key if key is not None else derive_key(value), embedding=self.embedder.embed([value])[0], usage=[], sources=sources or [], @@ -254,7 +232,7 @@ def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) - response_json = self.parse_json_result(response_text) + response_json = parse_json_result(response_text) except Exception as e: logger.error(f"[LLM] Exception during chat generation: {e}") response_json = { @@ -456,47 +434,73 @@ def get_memory( standard_scene_data = coerce_scene_data(scene_data, type) return self._read_memory(standard_scene_data, type, info, mode) - @staticmethod - def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]: - """Parse index-keyed JSON from hallucination filter response. - Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } - Returns (success, parsed_dict) with int keys. - """ + def rewrite_memories( + self, messages: list[dict], memory_list: list[TextualMemoryItem], user_only: bool = True + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) + if user_only: + template = PROMPT_MAPPING["rewrite_user_only"] + filtered_messages = [m for m in messages if m.get("role") != "assistant"] + if len(filtered_messages) < 1: + return memory_list + else: + template = PROMPT_MAPPING["rewrite"] + filtered_messages = messages + if len(filtered_messages) < 2: + return memory_list + + prompt_args = { + "messages_inline": "\n".join( + [f"- [{message['role']}]: {message['content']}" for message in filtered_messages] + ), + "memories_inline": json.dumps( + {idx: mem.memory for idx, mem in enumerate(memory_list)}, + ensure_ascii=False, + indent=2, + ), + } + prompt = template.format(**prompt_args) + + # Optionally run filter and parse the output try: - data = json.loads(text) - except Exception: - return False, {} + raw = self.llm.generate([{"role": "user", "content": prompt}]) + success, parsed = parse_rewritten_response(raw) + logger.info( + f"[rewrite_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" + ) + if success: + logger.info(f"Rewrite filter result: {parsed}") - if not isinstance(data, dict): - return False, {} + new_memory_list = [] + for mem_idx, content in parsed.items(): + if mem_idx < 0 or mem_idx >= len(memory_list): + logger.warning( + f"[rewrite_memories] Invalid memory index {mem_idx} for memory_list {len(memory_list)}, skipping." 
+ ) + continue - result: dict[int, dict] = {} - for k, v in data.items(): - try: - idx = int(k) - except Exception: - # allow integer keys as-is - if isinstance(k, int): - idx = k - else: - continue - if not isinstance(v, dict): - continue - need_rewrite = v.get("need_rewrite") - rewritten = v.get("rewritten", "") - reason = v.get("reason", "") - if ( - isinstance(need_rewrite, bool) - and isinstance(rewritten, str) - and isinstance(reason, str) - ): - result[idx] = { - "need_rewrite": need_rewrite, - "rewritten": rewritten, - "reason": reason, - } + need_rewrite = content.get("need_rewrite", False) + rewritten_text = content.get("rewritten", "") + reason = content.get("reason", "") + original_text = memory_list[mem_idx].memory + + # Replace memory text with rewritten content when rewrite is needed + if need_rewrite and isinstance(rewritten_text, str): + logger.info( + f"[rewrite_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + ) + if len(rewritten_text.strip()) != 0: + memory_list[mem_idx].memory = rewritten_text + new_memory_list.append(memory_list[mem_idx]) + else: + new_memory_list.append(memory_list[mem_idx]) + return new_memory_list + else: + logger.warning("Rewrite filter parsing failed or returned empty result.") + except Exception as e: + logger.error(f"Rewrite filter execution error: {e}", stack_info=True) - return (len(result) > 0), result + return memory_list def filter_hallucination_in_memories( self, messages: list[dict], memory_list: list[TextualMemoryItem] @@ -520,32 +524,32 @@ def filter_hallucination_in_memories( # Optionally run filter and parse the output try: raw = self.llm.generate([{"role": "user", "content": prompt}]) - success, parsed = self._parse_hallucination_filter_response(raw) + success, parsed = parse_keep_filter_response(raw) logger.info( f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" ) if success: logger.info(f"Hallucination filter result: {parsed}") - assert len(parsed) == len(memory_list) - for mem_idx, content in parsed.items(): - need_rewrite = content.get("need_rewrite", False) - rewritten_text = content.get("rewritten", "") - reason = content.get("reason", "") - # Replace memory text with rewritten content when rewrite is needed - if ( - need_rewrite - and isinstance(rewritten_text, str) - and len(rewritten_text.strip()) > 0 - ): - original_text = memory_list[mem_idx].memory + filtered_list = [] + for mem_idx, mem in enumerate(memory_list): + content = parsed.get(mem_idx) + if not content: + logger.warning(f"No verdict for memory {mem_idx}, keeping it.") + filtered_list.append(mem) + continue + keep = content.get("keep", True) + reason = content.get("reason", "") + + if keep: + filtered_list.append(mem) + else: logger.info( - f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + f"[filter_hallucination_in_memories] Dropping memory index={mem_idx}, reason='{reason}', memory='{mem.memory}'" ) - memory_list[mem_idx].memory = rewritten_text - return memory_list + return filtered_list else: logger.warning("Hallucination filter parsing failed or returned empty result.") except Exception as e: @@ -553,6 +557,103 @@ def filter_hallucination_in_memories( return memory_list + def add_before_search( + self, + messages: list[dict], + memory_list: 
list[TextualMemoryItem], + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) + template = PROMPT_MAPPING["add_before_search"] + + if not self.searcher: + try: + from memos.mem_reader.utils import init_searcher + + self.searcher = init_searcher(self.llm, self.embedder) + except Exception as e: + logger.error(f"[add_before_search] Failed to init searcher: {e}") + return memory_list + + # 1. Gather candidates and search for related memories + candidates_data = [] + for idx, mem in enumerate(memory_list): + try: + related_memories = self.searcher.search( + query=mem.memory, top_k=3, mode="fast", info={"user_id": "", "session_id": ""} + ) + related_text = "None" + if related_memories: + related_text = "\n".join([f"- {r.memory}" for r in related_memories]) + + candidates_data.append( + {"idx": idx, "new_memory": mem.memory, "related_memories": related_text} + ) + except Exception as e: + logger.error(f"[add_before_search] Search error for memory '{mem.memory}': {e}") + # If search fails, we can either skip this check or treat related as empty + candidates_data.append( + { + "idx": idx, + "new_memory": mem.memory, + "related_memories": "None (Search Failed)", + } + ) + + if not candidates_data: + return memory_list + + # 2. Build Prompt + messages_inline = "\n".join( + [ + f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}" + for message in messages + ] + ) + + candidates_inline_dict = { + str(item["idx"]): { + "new_memory": item["new_memory"], + "related_memories": item["related_memories"], + } + for item in candidates_data + } + + candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2) + + prompt = template.format( + messages_inline=messages_inline, candidates_inline=candidates_inline + ) + + # 3. Call LLM + try: + raw = self.llm.generate([{"role": "user", "content": prompt}]) + success, parsed_result = parse_keep_filter_response(raw) + + if not success: + logger.warning("[add_before_search] Failed to parse LLM response, keeping all.") + return memory_list + + # 4. 
Filter + filtered_list = [] + for idx, mem in enumerate(memory_list): + res = parsed_result.get(idx) + if not res: + filtered_list.append(mem) + continue + + if res.get("keep", True): + filtered_list.append(mem) + else: + logger.info( + f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'" + ) + + return filtered_list + + except Exception as e: + logger.error(f"[add_before_search] LLM execution error: {e}") + return memory_list + def _read_memory( self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" ) -> list[list[TextualMemoryItem]]: @@ -606,29 +707,27 @@ def _read_memory( for group_id in range(len(memory_list)): try: - revised_memory_list = self.filter_hallucination_in_memories( + original_memory_group = copy.deepcopy(memory_list[group_id]) + serialized_origin_memories = json.dumps( + [one.memory for one in original_memory_group], indent=2 + ) + revised_memory_list = self.rewrite_memories( messages=combined_messages, - memory_list=memory_list[group_id], + memory_list=original_memory_group, + user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower() + == "true", + ) + serialized_revised_memories = json.dumps( + [one.memory for one in revised_memory_list], indent=2 ) - if len(revised_memory_list) != len(memory_list[group_id]): - original_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in memory_list[group_id] - ] - filtered_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in revised_memory_list - ] - logger.error( - f"Length mismatch after hallucination filtering for group_id={group_id}: " - f"original={len(memory_list[group_id])}, filtered={len(revised_memory_list)}" - f"\noriginal_memory_list(serialized): {original_serialized}" - f"\nfiltered_memory_list(serialized): {filtered_serialized}" - f"\nmessages: {combined_messages}" - f"\nSkipping update and keeping original memory." 
+ if serialized_origin_memories != serialized_revised_memories: + memory_list[group_id] = revised_memory_list + logger.info( + f"[SIMPLE_STRUCT_ADD_FILTER] Modified the list for group_id={group_id}: " + f"\noriginal={serialized_origin_memories}," + f"\nrevised={serialized_revised_memories}" ) - continue - memory_list[group_id] = revised_memory_list + except Exception as e: group_serialized = [ one.memory if hasattr(one, "memory") else str(one) @@ -847,7 +946,7 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): info, source_info_list, self.llm, - self.parse_json_result, + parse_json_result, self.embedder, ): idx for idx, msg in enumerate(messages) @@ -870,44 +969,3 @@ def _process_transfer_doc_data( self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None ): raise NotImplementedError - - def parse_json_result(self, response_text: str) -> dict: - s = (response_text or "").strip() - - m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) - s = (m.group(1) if m else s.replace("```", "")).strip() - - i = s.find("{") - if i == -1: - return {} - s = s[i:].strip() - - try: - return json.loads(s) - except json.JSONDecodeError: - pass - - j = max(s.rfind("}"), s.rfind("]")) - if j != -1: - try: - return json.loads(s[: j + 1]) - except json.JSONDecodeError: - pass - - def _cheap_close(t: str) -> str: - t += "}" * max(0, t.count("{") - t.count("}")) - t += "]" * max(0, t.count("[") - t.count("]")) - return t - - t = _cheap_close(s) - try: - return json.loads(t) - except json.JSONDecodeError as e: - if "Invalid \\escape" in str(e): - s = s.replace("\\", "\\\\") - return json.loads(s) - logger.error( - f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ - json: {s}" - ) - return {} diff --git a/src/memos/mem_reader/utils.py b/src/memos/mem_reader/utils.py new file mode 100644 index 000000000..843345ec4 --- /dev/null +++ b/src/memos/mem_reader/utils.py @@ -0,0 +1,210 @@ +import json +import os +import re + +from typing import Any + +from memos import log +from memos.api.config import APIConfig +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.factory import RerankerFactory + + +logger = log.get_logger(__name__) + +try: + import tiktoken + + try: + _ENC = tiktoken.encoding_for_model("gpt-4o-mini") + except Exception: + _ENC = tiktoken.get_encoding("cl100k_base") + + def count_tokens_text(s: str) -> int: + return len(_ENC.encode(s or "", disallowed_special=())) +except Exception: + # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars + def count_tokens_text(s: str) -> int: + if not s: + return 0 + zh_chars = re.findall(r"[\u4e00-\u9fff]", s) + zh = len(zh_chars) + rest = len(s) - zh + return zh + max(1, rest // 4) + + +def derive_key(text: str, max_len: int = 80) -> str: + """default key when without LLM: first max_len words""" + if not text: + return "" + sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] + return (sent[:max_len]).strip() + + +def parse_json_result(response_text: str) -> dict: + s = (response_text or "").strip() + + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() + + try: + return json.loads(s) + except json.JSONDecodeError: + pass + + j = max(s.rfind("}"), 
s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass + + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t + + t = _cheap_close(s) + try: + return json.loads(t) + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + return json.loads(s) + logger.error( + f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ + json: {s}" + ) + return {} + + +def parse_rewritten_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from hallucination filter response. + Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. + """ + try: + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I) + s = (m.group(1) if m else text).strip() + data = json.loads(s) + except Exception: + return False, {} + + if not isinstance(data, dict): + return False, {} + + result: dict[int, dict] = {} + for k, v in data.items(): + try: + idx = int(k) + except Exception: + # allow integer keys as-is + if isinstance(k, int): + idx = k + else: + continue + if not isinstance(v, dict): + continue + need_rewrite = v.get("need_rewrite") + rewritten = v.get("rewritten", "") + reason = v.get("reason", "") + if ( + isinstance(need_rewrite, bool) + and isinstance(rewritten, str) + and isinstance(reason, str) + ): + result[idx] = { + "need_rewrite": need_rewrite, + "rewritten": rewritten, + "reason": reason, + } + + return (len(result) > 0), result + + +def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from keep filter response. + Expected shape: { "0": {"keep": bool, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. 
+ """ + try: + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I) + s = (m.group(1) if m else text).strip() + data = json.loads(s) + except Exception: + return False, {} + + if not isinstance(data, dict): + return False, {} + + result: dict[int, dict] = {} + for k, v in data.items(): + try: + idx = int(k) + except Exception: + if isinstance(k, int): + idx = k + else: + continue + if not isinstance(v, dict): + continue + keep = v.get("keep") + reason = v.get("reason", "") + if isinstance(keep, bool): + result[idx] = { + "keep": keep, + "reason": reason, + } + return (len(result) > 0), result + + +def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: + graph_db_backend_map = { + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), + "nebular": APIConfig.get_nebular_config(user_id=user_id), + "polardb": APIConfig.get_polardb_config(user_id=user_id), + } + + graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + return GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + +def build_reranker_config() -> dict[str, Any]: + return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + +def init_searcher(llm, embedder) -> Searcher: + """Initialize a Searcher instance for SimpleStructMemReader.""" + + # Build configs + graph_db_config = build_graph_db_config() + reranker_config = build_reranker_config() + + # Create instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + reranker = RerankerFactory.from_config(reranker_config) + + # Create Searcher + searcher = Searcher( + dispatcher_llm=llm, + graph_store=graph_db, + embedder=embedder, + reranker=reranker, + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + ) + + return searcher diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index f4ad9fe48..06910ba17 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,3 +1,5 @@ +import os + from pathlib import Path @@ -21,7 +23,7 @@ DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1 DEFAULT_TOP_K = 5 DEFAULT_CONTEXT_WINDOW_SIZE = 5 -DEFAULT_USE_REDIS_QUEUE = True +DEFAULT_USE_REDIS_QUEUE = os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true" DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 20 DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1 diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 1c57f18f0..7923b3750 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -679,11 +679,6 @@ def _batch_claim_pending_messages( if not self._redis_conn or not claims_spec: return [] - # Ensure consumer groups exist to avoid NOGROUP errors during batch claim - for stream_key, _need_count, _label in claims_spec: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) - pipe = self._redis_conn.pipeline(transaction=False) for stream_key, need_count, label in claims_spec: pipe.xautoclaim( @@ -696,26 +691,42 @@ def _batch_claim_pending_messages( justid=False, ) - results = [] try: - results = pipe.execute() - except Exception: - # Fallback: attempt sequential xautoclaim for 
robustness - for stream_key, need_count, label in claims_spec: - try: - self._ensure_consumer_group(stream_key=stream_key) - res = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), - start_id="0-0", - count=need_count, - justid=False, - ) - results.append(res) - except Exception: - continue + # Execute with raise_on_error=False so we get exceptions in the results list + # instead of aborting the whole batch. + results = pipe.execute(raise_on_error=False) + except Exception as e: + logger.error(f"Pipeline execution critical failure: {e}") + results = [e] * len(claims_spec) + + # Handle individual failures (e.g. NOGROUP) by retrying just that stream + final_results = [] + for i, res in enumerate(results): + if isinstance(res, Exception): + err_msg = str(res).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + stream_key, need_count, label = claims_spec[i] + try: + self._ensure_consumer_group(stream_key=stream_key) + retry_res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + final_results.append(retry_res) + except Exception as retry_err: + logger.warning(f"Retry xautoclaim failed for {stream_key}: {retry_err}") + final_results.append(None) + else: + final_results.append(None) + else: + final_results.append(res) + + results = final_results claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( @@ -1189,9 +1200,7 @@ def _update_stream_cache_with_log( self._stream_keys_cache = active_stream_keys self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) - logger.info( - f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', " - f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, " - f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, " - f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}" - ) + logger.info( + f"Refreshed stream keys cache: {cache_count} active keys, " + f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." 
+ ) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 144bfad7f..3404c6d4c 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -69,6 +69,8 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: @@ -92,6 +94,8 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index fef3ee6c0..40971c77e 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -622,23 +622,56 @@ 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。""" -SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. **Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. 
No extra text, explanations, markdown, or fields. +""" + +SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. +Note: The provided messages contain only user messages. The assistant's responses are intentionally omitted, not because the assistant didn't answer, but to focus strictly on validating memories against user input. + Rules: 1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. -2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, emotional labels, summaries, or generalizations. -3. **Ambiguity Elimination**: - - Replace vague pronouns (e.g., “he”, “it”, “they”) with clear, specific entities **only if** the messages identify them. - - Convert relative time expressions (e.g., “yesterday”) to absolute dates **only if** the messages provide enough temporal context. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. 4. **Hallucination Removal**: - - If a memory contains **any content not verbatim or directly implied by the user**, it must be rewritten. - - Do **not** rephrase inferences as facts. Instead, either: - - Remove the unsupported part and retain only the grounded core, or - - If the entire memory is ungrounded, mark it for rewrite and make the lack of user support explicit. +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. 5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. Inputs: messages: @@ -651,16 +684,115 @@ - Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. - Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} - The "reason" must be brief and precise, e.g.: - - "contains unsupported inference" - - "vague pronoun with no referent in messages" - - "relative time resolved to 2025-12-16" + - "contains unsupported inference ...." - "fully grounded and concise" Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. 
""" +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. **Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + +SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +You are a strict memory validator. +Your task is to identify and delete hallucinated memories that are not explicitly stated by the user in the provided messages. + +Rules: +1. **User-Only Origin**: Verify facts against USER messages ONLY. If the Assistant repeats a User fact, it is VALID. If the Assistant introduces a new detail (e.g., 'philanthropy') that the User did not explicitly confirm, it is INVALID. +2. **No Inference Allowed**: Do NOT keep memories based on implication, emotion, preference, or generalization. Only verbatim or direct restatements of user-provided facts are valid. However, minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +3. **Hallucination = Deletion**: If a memory contains any detail not directly expressed by the user, mark it for deletion. +4. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Examples: +Messages: +- [user]: I love coding in Python. +- [assistant]: That's great! 
I assume you also contribute to open source projects? +Memory: User enjoys Python and contributes to open source. +Result: {{"keep": false, "reason": "User never stated they contribute to open source; this came from Assistant's assumption."}} + +Messages: +- [user]: I am tired. +- [assistant]: I hear you are tired. Rest is important. +Memory: User stated they are tired. +Result: {{"keep": true, "reason": "Direct restatement of user input, even if Assistant repeated it."}} + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching the input memory indices. +- Each value must be: {{ "keep": boolean, "reason": string }} +- "keep": true only if the memory is a direct reflection of the user's explicit words. +- "reason": brief, factual, and cites missing or unsupported content. + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + + +SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT = """ +You are a memory manager. +Your task is to decide if a new memory should be added to the long-term memory, given a list of existing related memories. + +Rules: +1. **Redundancy Check**: If the new memory is completely redundant, already known, or covered by the existing memories, discard it. +2. **New Information**: If the new memory provides new information, details, or updates compared to the existing memories, keep it. +3. **Contradiction**: If the new memory contradicts existing memories but seems valid/newer, keep it (updates). +4. **Context Check**: Use the provided conversation messages to verify if the new memory is grounded in the user's explicit statements. + +Inputs: +Messages: +{messages_inline} + +Candidate Memories (to be evaluated): +{candidates_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching the input candidate memory indices. +- Each value must be: {{ "keep": boolean, "reason": string }} +- "keep": true if the memory should be added. +- "reason": brief explanation. + +Important: Output **only** the JSON. No extra text. 
+""" # Prompt mapping for specialized tasks (e.g., hallucination filtering) PROMPT_MAPPING = { "hallucination_filter": SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT, + "rewrite": SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT, + "rewrite_user_only": SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT, + "add_before_search": SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT, } From 7f39e7ecc052d2e85e7bbeb2ca73f586db143875 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 20:04:56 +0800 Subject: [PATCH 09/21] address the incompatible issue of local scheduler --- src/memos/mem_scheduler/base_scheduler.py | 20 +++-- .../task_schedule_modules/local_queue.py | 75 +++++++++++++++++-- .../task_schedule_modules/task_queue.py | 23 +----- .../mem_scheduler/utils/status_tracker.py | 26 ++++++- tests/test_local_queue_full.py | 54 +++++++++++++ 5 files changed, 164 insertions(+), 34 deletions(-) create mode 100644 tests/test_local_queue_full.py diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 1e0ecaadb..728203f5b 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1009,14 +1009,24 @@ def _monitor_loop(self): q_sizes = self.memos_message_queue.qsize() for stream_key, queue_length in q_sizes.items(): - # Expected format: "memos:stream:{user_id}:{mem_cube_id}" or "{user_id}" + # Skip aggregate keys like 'total_size' + if stream_key == "total_size": + continue + + # Key format: ...:{user_id}:{mem_cube_id}:{task_label} + # We want to extract user_id, which is the 3rd component from the end. parts = stream_key.split(":") if len(parts) >= 3: - user_id = parts[2] - self.metrics.update_queue_length(queue_length, user_id) - elif not self.use_redis_queue: # local queue - user_id = stream_key + user_id = parts[-3] self.metrics.update_queue_length(queue_length, user_id) + else: + # Fallback for unexpected key formats (e.g. legacy or testing) + # Try to use the key itself if it looks like a user_id (no colons) + # or just log a warning? + # For now, let's assume if it's not total_size and short, it might be a direct user_id key + # (though that shouldn't happen with current queue implementations) + if ":" not in stream_key: + self.metrics.update_queue_length(queue_length, stream_key) except Exception as e: logger.error(f"Error in metrics monitor loop: {e}", exc_info=True) diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index 69cfc0af9..32d79cef3 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -62,7 +62,7 @@ def put( Exception: Any underlying error during queue.put() operation. """ stream_key = self.get_stream_key( - user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) message.stream_key = stream_key @@ -108,35 +108,95 @@ def get( ) return res - def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + def get_nowait( + self, stream_key: str, batch_size: int | None = None + ) -> list[ScheduleMessageItem]: """ - Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size). + Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). Returns immediately with available messages or an empty list if queue is empty. 
         Args:
+            stream_key (str): The stream/queue identifier.
             batch_size (int | None): Number of messages to retrieve in a batch.
                                      If None, retrieves one message.

         Returns:
             List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty.
         """
-        logger.debug(f"get_nowait() called with batch_size: {batch_size}")
-        return self.get(block=False, batch_size=batch_size)
+        logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}")
+        return self.get(stream_key=stream_key, block=False, batch_size=batch_size)
+
+    def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
+        """
+        Get messages from all streams in round-robin or sequential fashion.
+        Equivalent to SchedulerRedisQueue.get_messages.
+        """
+        messages = []
+        # Snapshot keys to avoid runtime modification issues
+        stream_keys = list(self.queue_streams.keys())
+
+        # Simple greedy strategy: iterate the streams in order and take whatever is
+        # available (non-blocking) until batch_size messages have been collected.
+        # The Redis implementation balances reads across streams; for the local
+        # queue this simpler approach is sufficient.
+
+        for stream_key in stream_keys:
+            if len(messages) >= batch_size:
+                break
+
+            needed = batch_size - len(messages)
+            # Use get_nowait to avoid blocking
+            fetched = self.get_nowait(stream_key=stream_key, batch_size=needed)
+            messages.extend(fetched)
+
+        return messages

     def qsize(self) -> dict:
         """
         Return the current size of all internal queues as a dictionary.

         Each key is the stream name, and each value is the number of messages in that queue.
+        Also includes 'total_size'.

         Returns:
             Dict[str, int]: Mapping from stream name to current queue size.
         """
         sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()}
+        total_size = sum(sizes.values())
+        sizes["total_size"] = total_size
         logger.debug(f"Current queue sizes: {sizes}")
         return sizes

+    def size(self) -> int:
+        """
+        Get the current size of the queue (total message count).
+        Compatible with SchedulerRedisQueue.
+        """
+        return self.unfinished_tasks
+
+    def empty(self) -> bool:
+        """
+        Check if the queue is empty.
+        Compatible with SchedulerRedisQueue.
+        """
+        return self.size() == 0
+
+    def full(self) -> bool:
+        """
+        Check if the queue is full.
+        Compatible with SchedulerRedisQueue.
+
+        Returns True if all internal queues are full.
+        If there are no queues, returns False.
+        """
+        if not self.queue_streams:
+            return False
+
+        return all(queue.full() for queue in self.queue_streams.values())
+
     def clear(self) -> None:
         for queue in self.queue_streams.values():
             queue.clear()
@@ -151,6 +211,9 @@ def unfinished_tasks(self) -> int:
         Returns:
             int: Sum of all message counts in all internal queues.
         """
-        total = sum(self.qsize().values())
+        # qsize() also reports an aggregate "total_size" entry, so summing
+        # qsize().values() would double count; sum the per-stream queues directly.
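+        # For example, with two streams sized {"s1": 2, "s2": 1}, qsize() returns
+        # {"s1": 2, "s2": 1, "total_size": 3}; summing qsize().values() would give
+        # 6 instead of the true total of 3.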
+        total = sum(queue.qsize() for queue in self.queue_streams.values())
         logger.debug(f"Total unfinished tasks across all queues: {total}")
         return total
diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
index c20243242..b49db2b36 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
@@ -153,28 +153,7 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
         )

     def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
-        if isinstance(self.memos_message_queue, SchedulerRedisQueue):
-            return self.memos_message_queue.get_messages(batch_size=batch_size)
-        stream_keys = self.get_stream_keys()
-
-        if len(stream_keys) == 0:
-            return []
-
-        messages: list[ScheduleMessageItem] = []
-
-        for stream_key in stream_keys:
-            fetched = self.memos_message_queue.get(
-                stream_key=stream_key,
-                block=False,
-                batch_size=batch_size,
-            )
-
-            messages.extend(fetched)
-        if len(messages) > 0:
-            logger.debug(
-                f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
-            )
-        return messages
+        return self.memos_message_queue.get_messages(batch_size=batch_size)

     def clear(self):
         self.memos_message_queue.clear()
diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py
index d8c8d2cee..2a995b239 100644
--- a/src/memos/mem_scheduler/utils/status_tracker.py
+++ b/src/memos/mem_scheduler/utils/status_tracker.py
@@ -13,7 +13,7 @@ class TaskStatusTracker:

     @require_python_package(import_name="redis", install_command="pip install redis")
-    def __init__(self, redis_client: "redis.Redis"):
+    def __init__(self, redis_client: "redis.Redis | None"):
         self.redis = redis_client

     def _get_key(self, user_id: str) -> str:
@@ -41,6 +41,9 @@ def task_submitted(
             mem_cube_id: Memory cube identifier
             business_task_id: Optional business-level task ID (one task_id can have multiple item_ids)
         """
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         payload = {
             "status": "waiting",
@@ -61,6 +64,9 @@ def task_submitted(
         self.redis.expire(key, timedelta(days=7))

     def task_started(self, task_id: str, user_id: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -77,6 +83,9 @@ def task_started(self, task_id: str, user_id: str):
         self.redis.expire(key, timedelta(days=7))

     def task_completed(self, task_id: str, user_id: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -91,6 +100,9 @@ def task_completed(self, task_id: str, user_id: str):
         self.redis.expire(key, timedelta(days=7))

     def task_failed(self, task_id: str, user_id: str, error_message: str):
+        if not self.redis:
+            return
+
         key = self._get_key(user_id)
         existing_data_json = self.redis.hget(key, task_id)
         if not existing_data_json:
@@ -108,11 +120,17 @@ def task_failed(self, task_id: str, user_id: str, error_message: str):
         self.redis.expire(key, timedelta(days=7))

     def get_task_status(self, task_id: str, user_id: str) -> dict | None:
+        if not self.redis:
+            return None
+
         key = self._get_key(user_id)
         data = self.redis.hget(key, task_id)
         return json.loads(data) if data else None

     def get_all_tasks_for_user(self, user_id: str) -> dict[str, dict]:
+        if not self.redis:
+            return {}
+
         key = self._get_key(user_id)
         all_tasks = self.redis.hgetall(key)
         return {tid: json.loads(t_data) for tid, t_data in all_tasks.items()}
@@ -132,6 +150,9 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) ->
         - If any item is 'failed' → 'failed'
         Returns None if task_id not found.
         """
+        if not self.redis:
+            return None
+
         # Get all item_ids for this task_id
         task_items_key = self._get_task_items_key(user_id, business_task_id)
         item_ids = self.redis.smembers(task_items_key)
@@ -180,6 +201,9 @@ def get_all_tasks_global(self) -> dict[str, dict[str, dict]]:
         Returns:
             dict: {user_id: {task_id: task_data, ...}, ...}
         """
+        if not self.redis:
+            return {}
+
         all_users_tasks = {}
         cursor: int | str = 0
         while True:
diff --git a/tests/test_local_queue_full.py b/tests/test_local_queue_full.py
new file mode 100644
index 000000000..6c523046a
--- /dev/null
+++ b/tests/test_local_queue_full.py
@@ -0,0 +1,54 @@
+import unittest
+
+from datetime import datetime, timezone
+
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
+
+
+class TestLocalQueueFull(unittest.TestCase):
+    def test_full_behavior(self):
+        # Create a queue with very small maxsize for testing
+        lq = SchedulerLocalQueue(maxsize=1)
+
+        # Initially empty
+        self.assertFalse(lq.full())
+
+        # Add message to stream 1
+        msg1 = ScheduleMessageItem(
+            user_id="u1",
+            mem_cube_id="c1",
+            label="l1",
+            content="m1",
+            timestamp=datetime.now(timezone.utc),
+        )
+        lq.put(msg1)
+
+        # Now stream 1 is full (maxsize=1).
+        # Since it's the only stream, and it's full, lq.full() should be True.
+        self.assertTrue(lq.full())
+
+        # Add message to stream 2
+        msg2 = ScheduleMessageItem(
+            user_id="u2",
+            mem_cube_id="c2",
+            label="l2",
+            content="m2",
+            timestamp=datetime.now(timezone.utc),
+        )
+        lq.put(msg2)
+
+        # Now both stream 1 and stream 2 are full. lq.full() should be True.
+        self.assertTrue(lq.full())
+
+        # Remove message from stream 1
+        stream1_key = lq.get_stream_key("u1", "c1", "l1")
+        lq.get(stream1_key)
+
+        # Now stream 1 is empty, stream 2 is full.
+        # "all streams are full" is False.
+        self.assertFalse(lq.full())
+
+
+if __name__ == "__main__":
+    unittest.main()

From 3fe9cb09b4f1864db8225be7e64b0959e50c358f Mon Sep 17 00:00:00 2001
From: chentang
Date: Wed, 24 Dec 2025 20:25:07 +0800
Subject: [PATCH 10/21] feat(scheduler): optimize redis queue consumer group management

- Proactively ensure consumer groups exist in '_refresh_stream_keys' for newly discovered streams.
- Remove redundant consumer group checks in '_read_new_messages_batch' to improve read performance.
- Clean up 'seen_streams' cache when streams are deleted to ensure correct group recreation.
- This change reduces unnecessary Redis calls during high-frequency polling.
---
 .../task_schedule_modules/redis_queue.py | 28 +++++++++++++++----
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index 7923b3750..2f4318003 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -5,7 +5,6 @@
 the local memos_message_queue functionality in BaseScheduler.
""" -import contextlib import os import re import threading @@ -201,6 +200,20 @@ def _refresh_stream_keys( recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, now_sec=now_sec, ) + + # Ensure consumer groups for newly discovered active streams + with self._stream_keys_lock: + # Identify keys we haven't seen yet + new_streams = [k for k in active_stream_keys if k not in self.seen_streams] + + # Create groups outside the lock to avoid blocking + for key in new_streams: + self._ensure_consumer_group(key) + + if new_streams: + with self._stream_keys_lock: + self.seen_streams.update(new_streams) + deleted_count = self._delete_streams(keys_to_delete) self._update_stream_cache_with_log( stream_key_prefix=stream_key_prefix, @@ -560,10 +573,7 @@ def _read_new_messages_batch( return {} # Pre-ensure consumer groups to avoid NOGROUP during batch reads - for stream_key in stream_keys: - with contextlib.suppress(Exception): - self._ensure_consumer_group(stream_key=stream_key) - + # (Optimization: rely on put() and _refresh_stream_keys() to ensure groups) pipe = self._redis_conn.pipeline(transaction=False) for stream_key in stream_keys: pipe.xreadgroup( @@ -1170,10 +1180,14 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: del_pipe.delete(key) del_pipe.execute() deleted_count = len(keys_to_delete) - # Clean up empty-tracking state for deleted keys + # Clean up empty-tracking state and seen_streams for deleted keys with self._empty_stream_seen_lock: for key in keys_to_delete: self._empty_stream_seen_times.pop(key, None) + + with self._stream_keys_lock: + for key in keys_to_delete: + self.seen_streams.discard(key) except Exception: for key in keys_to_delete: try: @@ -1181,6 +1195,8 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: deleted_count += 1 with self._empty_stream_seen_lock: self._empty_stream_seen_times.pop(key, None) + with self._stream_keys_lock: + self.seen_streams.discard(key) except Exception: pass return deleted_count From b35096fa61d3d3aeea3297354b0d10a78916a0f8 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 24 Dec 2025 20:34:05 +0800 Subject: [PATCH 11/21] fix(tests): resolve AttributeError in SimpleStructMemReader tests - Import 'parse_json_result' from 'memos.mem_reader.utils' instead of accessing it as an instance attribute. - Fixes 'AttributeError: 'SimpleStructMemReader' object has no attribute 'parse_json_result'' in 'test_parse_json_result_success' and 'test_parse_json_result_failure'. - Remove incorrect mock assignment of 'parse_json_result' in 'test_process_chat_data'. 
---
 tests/mem_reader/test_simple_structure.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py
index f81356886..fd07fbf41 100644
--- a/tests/mem_reader/test_simple_structure.py
+++ b/tests/mem_reader/test_simple_structure.py
@@ -1,4 +1,3 @@
-import json
 import unittest

 from unittest.mock import MagicMock, patch
@@ -8,6 +7,7 @@
 from memos.embedders.factory import EmbedderFactory
 from memos.llms.factory import LLMFactory
 from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.mem_reader.utils import parse_json_result
 from memos.memories.textual.item import TextualMemoryItem
@@ -57,7 +57,6 @@ def test_process_chat_data(self):
             '"summary": "Tom is currently focused on managing a new project with a tight schedule."}'
         )
         self.reader.llm.generate.return_value = mock_response
-        self.reader.parse_json_result = lambda x: json.loads(x)

         result = self.reader._process_chat_data(scene_data_info, info)
@@ -105,7 +104,7 @@ def test_get_scene_data_info_with_chat(self):
     def test_parse_json_result_success(self):
         """Test successful JSON parsing."""
         raw_response = '{"summary": "Test summary", "tags": ["test"]}'
-        result = self.reader.parse_json_result(raw_response)
+        result = parse_json_result(raw_response)

         self.assertIsInstance(result, dict)
         self.assertIn("summary", result)
@@ -113,7 +112,7 @@ def test_parse_json_result_failure(self):
         """Test failure in JSON parsing."""
         raw_response = "Invalid JSON string"
-        result = self.reader.parse_json_result(raw_response)
+        result = parse_json_result(raw_response)

         self.assertEqual(result, {})

From 8943ba8b437d9b0f2bfe3ec4e93901c36b976314 Mon Sep 17 00:00:00 2001
From: chentang
Date: Wed, 24 Dec 2025 20:49:32 +0800
Subject: [PATCH 12/21] fix(mem_reader): pass info dict to add_before_search for correct user_id usage

- Update 'add_before_search' signature in 'SimpleStructMemReader' to accept 'info' dict.
- Pass 'info' (containing 'user_id' and 'session_id') to 'self.searcher.search' instead of using empty strings.
- Add 'test_add_before_search' to 'TestSimpleStructMemReader' to verify the fix and ensure 'searcher.search' receives the correct 'info'.
- This ensures that memory searches are scoped to the correct user and session.
---
 src/memos/mem_reader/simple_struct.py     |  3 +-
 tests/mem_reader/test_simple_structure.py | 92 +++++++++++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 866b6d988..18bad7ab7 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -561,6 +561,7 @@ def add_before_search(
         self,
         messages: list[dict],
         memory_list: list[TextualMemoryItem],
+        info: dict[str, Any],
     ) -> list[TextualMemoryItem]:
         # Build input objects with memory text and metadata (timestamps, sources, etc.)
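+        # Flow sketch: gather up to top_k related memories for each candidate, then
+        # ask the LLM (via the "add_before_search" prompt) for a keep/drop decision.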
         template = PROMPT_MAPPING["add_before_search"]
@@ -579,7 +580,7 @@ def add_before_search(
         for idx, mem in enumerate(memory_list):
             try:
                 related_memories = self.searcher.search(
-                    query=mem.memory, top_k=3, mode="fast", info={"user_id": "", "session_id": ""}
+                    query=mem.memory, top_k=3, mode="fast", info=info
                 )
                 related_text = "None"
                 if related_memories:
diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py
index fd07fbf41..987ff25ae 100644
--- a/tests/mem_reader/test_simple_structure.py
+++ b/tests/mem_reader/test_simple_structure.py
@@ -116,6 +116,98 @@ def test_parse_json_result_failure(self):

         self.assertEqual(result, {})

+    def test_add_before_search(self):
+        """Test add_before_search method."""
+        import json
+
+        from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
+
+        # Mock searcher
+        self.reader.searcher = MagicMock()
+        self.reader.searcher.search.return_value = [
+            TextualMemoryItem(
+                memory="Related memory 1",
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id="user1",
+                    session_id="session1",
+                    memory_type="LongTermMemory",
+                    status="activated",
+                    tags=[],
+                    key="key1",
+                    embedding=[0.1],
+                    usage=[],
+                    sources=[],
+                    background="",
+                    confidence=0.99,
+                    type="fact",
+                    info={},
+                ),
+            )
+        ]
+
+        # Mock LLM response for filter
+        # The method expects a JSON response with keep/drop decisions
+        mock_response = json.dumps(
+            {
+                "0": {"keep": True, "reason": "Relevant"},
+                "1": {"keep": False, "reason": "Duplicate"},
+            }
+        )
+        self.reader.llm.generate.return_value = mock_response
+
+        messages = [{"role": "user", "content": "test message"}]
+        memory_list = [
+            TextualMemoryItem(
+                memory="Mem 1",
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id="user1",
+                    session_id="session1",
+                    memory_type="LongTermMemory",
+                    status="activated",
+                    tags=[],
+                    key="key1",
+                    embedding=[0.1],
+                    usage=[],
+                    sources=[],
+                    background="",
+                    confidence=0.99,
+                    type="fact",
+                    info={},
+                ),
+            ),
+            TextualMemoryItem(
+                memory="Mem 2",
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id="user1",
+                    session_id="session1",
+                    memory_type="LongTermMemory",
+                    status="activated",
+                    tags=[],
+                    key="key2",
+                    embedding=[0.1],
+                    usage=[],
+                    sources=[],
+                    background="",
+                    confidence=0.99,
+                    type="fact",
+                    info={},
+                ),
+            ),
+        ]
+        info = {"user_id": "user1", "session_id": "session1"}
+
+        # Call the method
+        result = self.reader.add_before_search(messages, memory_list, info)
+
+        # Assertions
+        # Check if searcher.search was called with correct info
+        self.reader.searcher.search.assert_called_with(
+            query="Mem 2", top_k=3, mode="fast", info=info
+        )
+        # Check result
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0].memory, "Mem 1")
+

 if __name__ == "__main__":
     unittest.main()

From 78a43275f5d9550cda6514ea51b05ee64417d979 Mon Sep 17 00:00:00 2001
From: chentang
Date: Wed, 24 Dec 2025 21:25:35 +0800
Subject: [PATCH 13/21] refactor add_before_search from mem_reader to SingleCubeView

---
 src/memos/mem_reader/simple_struct.py   |   3 +-
 src/memos/mem_reader/utils.py           |  53 -------------
 src/memos/multi_mem_cube/single_cube.py | 101 ++++++++++++++++++++++++
 3 files changed, 103 insertions(+), 54 deletions(-)

diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 18bad7ab7..fdd109079 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -561,6 +561,7 @@ def add_before_search(
         self,
         messages: list[dict],
         memory_list: list[TextualMemoryItem],
+        user_name: str,
         info: dict[str, Any],
     ) -> list[TextualMemoryItem]:
         # Build input objects with memory text and metadata (timestamps, sources, etc.)
@@ -579,7 +581,7 @@ def add_before_search(
         for idx, mem in enumerate(memory_list):
             try:
                 related_memories = self.searcher.search(
-                    query=mem.memory, top_k=3, mode="fast", info=info
+                    query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info
                 )
                 related_text = "None"
                 if related_memories:
diff --git a/src/memos/mem_reader/utils.py b/src/memos/mem_reader/utils.py
index 843345ec4..4e5a78af2 100644
--- a/src/memos/mem_reader/utils.py
+++ b/src/memos/mem_reader/utils.py
@@ -1,16 +1,7 @@
 import json
-import os
 import re

-from typing import Any
-
 from memos import log
-from memos.api.config import APIConfig
-from memos.configs.graph_db import GraphDBConfigFactory
-from memos.configs.reranker import RerankerConfigFactory
-from memos.graph_dbs.factory import GraphStoreFactory
-from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
-from memos.reranker.factory import RerankerFactory


 logger = log.get_logger(__name__)
@@ -164,47 +155,3 @@ def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
             "reason": reason,
         }
     return (len(result) > 0), result
-
-
-def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
-    graph_db_backend_map = {
-        "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
-        "neo4j": APIConfig.get_neo4j_config(user_id=user_id),
-        "nebular": APIConfig.get_nebular_config(user_id=user_id),
-        "polardb": APIConfig.get_polardb_config(user_id=user_id),
-    }
-
-    graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
-    return GraphDBConfigFactory.model_validate(
-        {
-            "backend": graph_db_backend,
-            "config": graph_db_backend_map[graph_db_backend],
-        }
-    )
-
-
-def build_reranker_config() -> dict[str, Any]:
-    return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
-
-
-def init_searcher(llm, embedder) -> Searcher:
-    """Initialize a Searcher instance for SimpleStructMemReader."""
-
-    # Build configs
-    graph_db_config = build_graph_db_config()
-    reranker_config = build_reranker_config()
-
-    # Create instances
-    graph_db = GraphStoreFactory.from_config(graph_db_config)
-    reranker = RerankerFactory.from_config(reranker_config)
-
-    # Create Searcher
-    searcher = Searcher(
-        dispatcher_llm=llm,
-        graph_store=graph_db,
-        embedder=embedder,
-        reranker=reranker,
-        manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
-    )
-
-    return searcher
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index 57f2cdba1..ab3d0ce03 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -15,6 +15,7 @@
 )
 from memos.context.context import ContextThreadPoolExecutor
 from memos.log import get_logger
+from memos.mem_reader.utils import parse_keep_filter_response
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
 from memos.mem_scheduler.schemas.task_schemas import (
     ADD_TASK_LABEL,
@@ -23,6 +24,7 @@
     PREF_ADD_TASK_LABEL,
 )
 from memos.multi_mem_cube.views import MemCubeView
+from memos.templates.mem_reader_prompts import PROMPT_MAPPING
 from memos.types.general_types import (
     FINE_STRATEGY,
     FineStrategy,
@@ -41,6 +43,7 @@
     from memos.mem_cube.navie import NaiveMemCube
     from memos.mem_reader.simple_struct import SimpleStructMemReader
     from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
+    from memos.memories.textual.item import TextualMemoryItem


 @dataclass
@@ -631,6 +634,104 @@ def _process_pref_mem(
             for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
         ]

+    def add_before_search(
+        self,
+        messages: list[dict],
+        memory_list: list[TextualMemoryItem],
+        user_name: str,
+        info: dict[str, Any],
+    ) -> list[TextualMemoryItem]:
+        # Build input objects with memory text and metadata (timestamps, sources, etc.)
+        template = PROMPT_MAPPING["add_before_search"]
+
+        if not self.searcher:
+            self.logger.warning("[add_before_search] Searcher is not initialized, skipping check.")
+            return memory_list
+
+        # 1. Gather candidates and search for related memories
+        candidates_data = []
+        for idx, mem in enumerate(memory_list):
+            try:
+                related_memories = self.searcher.search(
+                    query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info
+                )
+                related_text = "None"
+                if related_memories:
+                    related_text = "\n".join([f"- {r.memory}" for r in related_memories])
+
+                candidates_data.append(
+                    {"idx": idx, "new_memory": mem.memory, "related_memories": related_text}
+                )
+            except Exception as e:
+                self.logger.error(
+                    f"[add_before_search] Search error for memory '{mem.memory}': {e}"
+                )
+                # If search fails, we can either skip this check or treat related as empty
+                candidates_data.append(
+                    {
+                        "idx": idx,
+                        "new_memory": mem.memory,
+                        "related_memories": "None (Search Failed)",
+                    }
+                )
+
+        if not candidates_data:
+            return memory_list
+
+        # 2. Build Prompt
+        messages_inline = "\n".join(
+            [
+                f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}"
+                for message in messages
+            ]
+        )
+
+        candidates_inline_dict = {
+            str(item["idx"]): {
+                "new_memory": item["new_memory"],
+                "related_memories": item["related_memories"],
+            }
+            for item in candidates_data
+        }
+
+        candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2)
+
+        prompt = template.format(
+            messages_inline=messages_inline, candidates_inline=candidates_inline
+        )
+
+        # 3. Call LLM
+        try:
+            raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}])
+            success, parsed_result = parse_keep_filter_response(raw)
+
+            if not success:
+                self.logger.warning(
+                    "[add_before_search] Failed to parse LLM response, keeping all."
+                )
+                return memory_list
+
+            # 4. Filter
+            filtered_list = []
+            for idx, mem in enumerate(memory_list):
+                res = parsed_result.get(idx)
+                if not res:
+                    filtered_list.append(mem)
+                    continue
+
+                if res.get("keep", True):
+                    filtered_list.append(mem)
+                else:
+                    self.logger.info(
+                        f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'"
+                    )
+
+            return filtered_list
+
+        except Exception as e:
+            self.logger.error(f"[add_before_search] LLM execution error: {e}")
+            return memory_list
+
     def _process_text_mem(
         self,
         add_req: APIADDRequest,

From a5fc4c09c94c4ded8d153cf34f5c09dc19cc979a Mon Sep 17 00:00:00 2001
From: chentang
Date: Wed, 24 Dec 2025 21:36:07 +0800
Subject: [PATCH 14/21] address bugs

---
 src/memos/mem_reader/simple_struct.py     | 99 -----------------
 tests/mem_reader/test_simple_structure.py | 92 ---------------
 2 files changed, 191 deletions(-)

diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index fdd109079..70472958e 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -557,105 +557,6 @@ def filter_hallucination_in_memories(

         return memory_list

-    def add_before_search(
-        self,
-        messages: list[dict],
-        memory_list: list[TextualMemoryItem],
-        user_name: str,
-        info: dict[str, Any],
-    ) -> list[TextualMemoryItem]:
-        # Build input objects with memory text and metadata (timestamps, sources, etc.)
-        template = PROMPT_MAPPING["add_before_search"]
-
-        if not self.searcher:
-            try:
-                from memos.mem_reader.utils import init_searcher
-
-                self.searcher = init_searcher(self.llm, self.embedder)
-            except Exception as e:
-                logger.error(f"[add_before_search] Failed to init searcher: {e}")
-                return memory_list
-
-        # 1. Gather candidates and search for related memories
-        candidates_data = []
-        for idx, mem in enumerate(memory_list):
-            try:
-                related_memories = self.searcher.search(
-                    query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info
-                )
-                related_text = "None"
-                if related_memories:
-                    related_text = "\n".join([f"- {r.memory}" for r in related_memories])
-
-                candidates_data.append(
-                    {"idx": idx, "new_memory": mem.memory, "related_memories": related_text}
-                )
-            except Exception as e:
-                logger.error(f"[add_before_search] Search error for memory '{mem.memory}': {e}")
-                # If search fails, we can either skip this check or treat related as empty
-                candidates_data.append(
-                    {
-                        "idx": idx,
-                        "new_memory": mem.memory,
-                        "related_memories": "None (Search Failed)",
-                    }
-                )
-
-        if not candidates_data:
-            return memory_list
-
-        # 2. Build Prompt
-        messages_inline = "\n".join(
-            [
-                f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}"
-                for message in messages
-            ]
-        )
-
-        candidates_inline_dict = {
-            str(item["idx"]): {
-                "new_memory": item["new_memory"],
-                "related_memories": item["related_memories"],
-            }
-            for item in candidates_data
-        }
-
-        candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2)
-
-        prompt = template.format(
-            messages_inline=messages_inline, candidates_inline=candidates_inline
-        )
-
-        # 3. Call LLM
-        try:
-            raw = self.llm.generate([{"role": "user", "content": prompt}])
-            success, parsed_result = parse_keep_filter_response(raw)
-
-            if not success:
-                logger.warning("[add_before_search] Failed to parse LLM response, keeping all.")
-                return memory_list
-
-            # 4. Filter
-            filtered_list = []
-            for idx, mem in enumerate(memory_list):
-                res = parsed_result.get(idx)
-                if not res:
-                    filtered_list.append(mem)
-                    continue
-
-                if res.get("keep", True):
-                    filtered_list.append(mem)
-                else:
-                    logger.info(
-                        f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'"
-                    )
-
-            return filtered_list
-
-        except Exception as e:
-            logger.error(f"[add_before_search] LLM execution error: {e}")
-            return memory_list
-
     def _read_memory(
         self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine"
     ) -> list[list[TextualMemoryItem]]:
diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py
index 987ff25ae..fd07fbf41 100644
--- a/tests/mem_reader/test_simple_structure.py
+++ b/tests/mem_reader/test_simple_structure.py
@@ -116,6 +116,98 @@ def test_parse_json_result_failure(self):

         self.assertEqual(result, {})

-    def test_add_before_search(self):
-        """Test add_before_search method."""
-        import json
-
-        from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
-
-        # Mock searcher
-        self.reader.searcher = MagicMock()
-        self.reader.searcher.search.return_value = [
-            TextualMemoryItem(
-                memory="Related memory 1",
-                metadata=TreeNodeTextualMemoryMetadata(
-                    user_id="user1",
-                    session_id="session1",
-                    memory_type="LongTermMemory",
-                    status="activated",
-                    tags=[],
-                    key="key1",
-                    embedding=[0.1],
-                    usage=[],
-                    sources=[],
-                    background="",
-                    confidence=0.99,
-                    type="fact",
-                    info={},
-                ),
-            )
-        ]
-
-        # Mock LLM response for filter
-        # The method expects a JSON response with keep/drop decisions
-        mock_response = json.dumps(
-            {
-                "0": {"keep": True, "reason": "Relevant"},
-                "1": {"keep": False, "reason": "Duplicate"},
-            }
-        )
-        self.reader.llm.generate.return_value = mock_response
-
-        messages = [{"role": "user", "content": "test message"}]
-        memory_list = [
-            TextualMemoryItem(
-                memory="Mem 1",
-                metadata=TreeNodeTextualMemoryMetadata(
-                    user_id="user1",
-                    session_id="session1",
-                    memory_type="LongTermMemory",
-                    status="activated",
-                    tags=[],
-                    key="key1",
-                    embedding=[0.1],
-                    usage=[],
-                    sources=[],
-                    background="",
-                    confidence=0.99,
-                    type="fact",
-                    info={},
-                ),
-            ),
-            TextualMemoryItem(
-                memory="Mem 2",
-                metadata=TreeNodeTextualMemoryMetadata(
-                    user_id="user1",
-                    session_id="session1",
-                    memory_type="LongTermMemory",
-                    status="activated",
-                    tags=[],
-                    key="key2",
-                    embedding=[0.1],
-                    usage=[],
-                    sources=[],
-                    background="",
-                    confidence=0.99,
-                    type="fact",
-                    info={},
-                ),
-            ),
-        ]
-        info = {"user_id": "user1", "session_id": "session1"}
-
-        # Call the method
-        result = self.reader.add_before_search(messages, memory_list, info)
-
-        # Assertions
-        # Check if searcher.search was called with correct info
-        self.reader.searcher.search.assert_called_with(
-            query="Mem 2", top_k=3, mode="fast", info=info
-        )
-        # Check result
-        self.assertEqual(len(result), 1)
-        self.assertEqual(result[0].memory, "Mem 1")
-

 if __name__ == "__main__":
     unittest.main()

From 45224ddb6880fee5dbe6e5580b47357c2bac25d7 Mon Sep 17 00:00:00 2001
From: chentang
Date: Thu, 25 Dec 2025 11:45:01 +0800
Subject: [PATCH 15/21] fix: fix the qsize bug of task queue, and accept change from hotfix/scheduler

---
 src/memos/mem_scheduler/base_scheduler.py |   3 +
 .../task_schedule_modules/dispatcher.py    |   3 -
 .../task_schedule_modules/local_queue.py   | 111 +++++++++++-------
 3 files changed, 73 insertions(+), 44 deletions(-)

diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 728203f5b..a2621eefc 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -1008,6 +1008,9 @@ def _monitor_loop(self):
             try:
                 q_sizes = self.memos_message_queue.qsize()

+                if not isinstance(q_sizes, dict):
+                    continue
+
                 for stream_key, queue_length in q_sizes.items():
                     # Skip aggregate keys like 'total_size'
                     if stream_key == "total_size":
diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index 35df3db64..e2c1621d4 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -128,9 +128,6 @@ def status_tracker(self) -> TaskStatusTracker | None:
         if self._status_tracker is None:
             try:
                 self._status_tracker = TaskStatusTracker(self.redis)
-                # Propagate to submodules when created lazily
-                if self.dispatcher:
-                    self.dispatcher.status_tracker = self._status_tracker
                 if self.memos_message_queue:
                     self.memos_message_queue.set_status_tracker(self._status_tracker)
             except Exception as e:
diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
index eae70f8ef..791cedf41 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
@@ -4,9 +4,18 @@
 the local memos_message_queue functionality in BaseScheduler.
 """

+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX
+from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
+from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule


@@ -16,26 +25,38 @@ class SchedulerLocalQueue(RedisSchedulerModule):

     def __init__(
         self,
-        maxsize: int,
+        maxsize: int = 0,
+        stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX,
+        orchestrator: SchedulerOrchestrator | None = None,
+        status_tracker: TaskStatusTracker | None = None,
     ):
         """
         Initialize the SchedulerLocalQueue with a maximum queue size limit.
+        Arguments match SchedulerRedisQueue for compatibility.

         Args:
-            maxsize (int): Maximum number of messages allowed
-                           in each individual queue.
-                           If exceeded, subsequent puts will block
-                           or raise an exception based on `block` parameter.
+            maxsize (int): Maximum number of messages allowed in each individual queue.
+            stream_key_prefix (str): Prefix for stream keys (simulated).
+            orchestrator: SchedulerOrchestrator instance (ignored).
+            status_tracker: TaskStatusTracker instance (ignored).
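+
+        Example (usage sketch):
+            queue = SchedulerLocalQueue(maxsize=100)
+            queue.put(message)  # routed to a per-(user_id, mem_cube_id, label) stream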
""" super().__init__() - self.stream_key_prefix = "local_queue" + self.stream_key_prefix = stream_key_prefix or "local_queue" self.max_internal_message_queue_size = maxsize + # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem] self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {} + + self.orchestrator = orchestrator + self.status_tracker = status_tracker + + self._is_listening = False + self._message_handler: Callable[[ScheduleMessageItem], None] | None = None + logger.info( - f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}" + f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}" ) def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: @@ -86,7 +107,7 @@ def get( stream_key: str, block: bool = True, timeout: float | None = None, - batch_size: int | None = None, + batch_size: int | None = 1, ) -> list[ScheduleMessageItem]: if batch_size is not None and batch_size <= 0: logger.warning( @@ -99,18 +120,19 @@ def get( logger.error(f"Stream {stream_key} does not exist when trying to get messages.") return [] + # Ensure we always request a batch so we get a list back + effective_batch_size = batch_size if batch_size is not None else 1 + # Note: Assumes custom Queue implementation supports batch_size parameter res = self.queue_streams[stream_key].get( - block=block, timeout=timeout, batch_size=batch_size + block=block, timeout=timeout, batch_size=effective_batch_size ) logger.debug( f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" ) return res - def get_nowait( - self, stream_key: str, batch_size: int | None = None - ) -> list[ScheduleMessageItem]: + def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]: """ Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). @@ -170,35 +192,13 @@ def qsize(self) -> dict: logger.debug(f"Current queue sizes: {sizes}") return sizes - def size(self) -> int: - """ - Get the current size of the queue (total message count). - Compatible with SchedulerRedisQueue. - """ - return self.unfinished_tasks - - def empty(self) -> bool: - """ - Check if the queue is empty. - Compatible with SchedulerRedisQueue. - """ - return self.size() == 0 - - def full(self) -> bool: - """ - Check if the queue is full. - Compatible with SchedulerRedisQueue. - """ - # Local queue limits are per-stream (max_internal_message_queue_size). - # It is considered full only if all streams are full. - if not self.queue_streams: - return False - - return all(queue.full() for queue in self.queue_streams.values()) - - def clear(self) -> None: - for queue in self.queue_streams.values(): - queue.clear() + def clear(self, stream_key: str | None = None) -> None: + if stream_key: + if stream_key in self.queue_streams: + self.queue_streams[stream_key].clear() + else: + for queue in self.queue_streams.values(): + queue.clear() @property def unfinished_tasks(self) -> int: @@ -216,3 +216,32 @@ def unfinished_tasks(self) -> int: total = sum(queue.qsize() for queue in self.queue_streams.values()) logger.debug(f"Total unfinished tasks across all queues: {total}") return total + + def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: + """ + Return list of active stream keys. 
+ """ + prefix = stream_key_prefix or self.stream_key_prefix + return [k for k in self.queue_streams if k.startswith(prefix)] + + def size(self) -> int: + """ + Total size of all queues. + """ + return sum(q.qsize() for q in self.queue_streams.values()) + + def empty(self) -> bool: + """ + Check if all queues are empty. + """ + return self.size() == 0 + + def full(self) -> bool: + """ + Check if any queue is full (approximate). + """ + if self.max_internal_message_queue_size <= 0: + return False + return any( + q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values() + ) From f3c4f6ce4f7dc05686ddd94f98a1832a9bd2e08a Mon Sep 17 00:00:00 2001 From: chentang Date: Fri, 26 Dec 2025 17:33:13 +0800 Subject: [PATCH 16/21] fix: address some issues to run old scheduler example and kv cache example --- .../config/mem_scheduler/mem_cube_config.yaml | 21 ++ .../memos_config_w_scheduler.yaml | 12 +- .../mem_scheduler/quick_start_examples.py | 240 ++++++++++++++++++ src/memos/llms/hf.py | 13 +- src/memos/mem_os/core.py | 6 +- src/memos/mem_os/main.py | 2 +- src/memos/mem_os/utils/format_utils.py | 82 ++++-- src/memos/mem_reader/simple_struct.py | 2 +- src/memos/mem_scheduler/base_scheduler.py | 39 ++- .../general_modules/scheduler_logger.py | 30 ++- src/memos/mem_scheduler/general_scheduler.py | 65 +++-- .../mem_scheduler/monitors/general_monitor.py | 22 +- .../task_schedule_modules/redis_queue.py | 2 +- src/memos/memories/activation/kv.py | 32 ++- 14 files changed, 485 insertions(+), 83 deletions(-) create mode 100644 examples/data/config/mem_scheduler/mem_cube_config.yaml create mode 100644 examples/mem_scheduler/quick_start_examples.py diff --git a/examples/data/config/mem_scheduler/mem_cube_config.yaml b/examples/data/config/mem_scheduler/mem_cube_config.yaml new file mode 100644 index 000000000..398d8dbb3 --- /dev/null +++ b/examples/data/config/mem_scheduler/mem_cube_config.yaml @@ -0,0 +1,21 @@ +user_id: "user_test" +cube_id: "user_test/mem_cube_naive" +text_mem: + backend: "naive_text" + config: + extractor_llm: + backend: "huggingface_singleton" + config: + model_name_or_path: "Qwen/Qwen3-0.6B" + temperature: 0.1 + max_tokens: 1024 +act_mem: + backend: "kv_cache" + config: + memory_filename: "activation_memory.pickle" + extractor_llm: + backend: "huggingface_singleton" + config: + model_name_or_path: "Qwen/Qwen3-0.6B" + temperature: 0.8 + max_tokens: 1024 diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml index bd9910300..a5e91dc4e 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml @@ -10,16 +10,12 @@ mem_reader: backend: "simple_struct" config: llm: - backend: "openai" + backend: "huggingface_singleton" config: - model_name_or_path: "gpt-4o-mini" - temperature: 0.8 - max_tokens: 4096 - top_p: 0.9 - top_k: 50 + model_name_or_path: "Qwen/Qwen3-1.7B" + temperature: 0.1 remove_think_prefix: true - api_key: "sk-xxxxxx" - api_base: "https://api.openai.com/v1" + max_tokens: 4096 embedder: backend: "ollama" config: diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py new file mode 100644 index 000000000..c4142d2cd --- /dev/null +++ b/examples/mem_scheduler/quick_start_examples.py @@ -0,0 +1,240 @@ +import json +import shutil +import sys +import uuid + +from pathlib import Path + +from transformers import DynamicCache + +from 
+from memos.configs.mem_os import MOSConfig
+from memos.configs.memory import MemoryConfigFactory
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+    ANSWER_TASK_LABEL,
+    MEM_UPDATE_TASK_LABEL,
+    QUERY_TASK_LABEL,
+)
+from memos.mem_scheduler.utils.misc_utils import parse_yaml
+from memos.memories.activation.item import KVCacheItem
+from memos.memories.factory import MemoryFactory
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR))  # Enable execution from any working directory
+
+
+def get_cache_info(cache):
+    if not cache:
+        return None
+
+    num_layers = 0
+    total_size_bytes = 0
+
+    if hasattr(cache, "layers"):
+        num_layers = len(cache.layers)
+        for layer in cache.layers:
+            if hasattr(layer, "key_cache") and layer.key_cache is not None:
+                total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size()
+            if hasattr(layer, "value_cache") and layer.value_cache is not None:
+                total_size_bytes += layer.value_cache.nelement() * layer.value_cache.element_size()
+
+            if hasattr(layer, "keys") and layer.keys is not None:
+                total_size_bytes += layer.keys.nelement() * layer.keys.element_size()
+            if hasattr(layer, "values") and layer.values is not None:
+                total_size_bytes += layer.values.nelement() * layer.values.element_size()
+
+    elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
+        num_layers = len(cache.key_cache)
+        for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
+            if k is not None:
+                total_size_bytes += k.nelement() * k.element_size()
+            if v is not None:
+                total_size_bytes += v.nelement() * v.element_size()
+
+    return {
+        "num_layers": num_layers,
+        "size_bytes": total_size_bytes,
+        "size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
+    }
+
+
+def serialize_item(obj):
+    if isinstance(obj, list):
+        return [serialize_item(x) for x in obj]
+
+    if isinstance(obj, KVCacheItem):
+        return {
+            "id": obj.id,
+            "metadata": obj.metadata,
+            "records": obj.records.model_dump()
+            if hasattr(obj.records, "model_dump")
+            else obj.records,
+            "memory": get_cache_info(obj.memory),
+        }
+
+    if isinstance(obj, DynamicCache):
+        return get_cache_info(obj)
+
+    return str(obj)
+
+
+def kv_cache_only():
+    # Create a config for KVCacheMemory (HuggingFace backend)
+    config = MemoryConfigFactory(
+        backend="kv_cache",
+        config={
+            "extractor_llm": {
+                "backend": "huggingface",
+                "config": {
+                    "model_name_or_path": "Qwen/Qwen3-0.6B",
+                    "max_tokens": 32,
+                    "add_generation_prompt": True,
+                    "remove_think_prefix": True,
+                },
+            },
+        },
+    )
+
+    # Instantiate KVCacheMemory
+    kv_mem = MemoryFactory.from_config(config)
+
+    # Extract a KVCacheItem (DynamicCache)
+    prompt = [
+        {"role": "user", "content": "What is MemOS?"},
+        {"role": "assistant", "content": "MemOS is a memory operating system for LLMs."},
+    ]
+    print("===== Extract KVCacheItem =====")
+    cache_item = kv_mem.extract(prompt)
+    print(json.dumps(serialize_item(cache_item), indent=2, default=str))
+
+    # Add the cache to memory
+    kv_mem.add([cache_item])
+    print("All caches:")
+    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+    # Fetch by ID
+    retrieved = kv_mem.get(cache_item.id)
+    print("Retrieved:")
+    print(json.dumps(serialize_item(retrieved), indent=2, default=str))
+
+    # Merge caches
+    item2 = kv_mem.extract([{"role": "user", "content": "Tell me a joke."}])
+    kv_mem.add([item2])
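+    # get_cache merges the stored caches for the given IDs into one DynamicCache
+    # (usage sketch of the API; the cache contents depend on the extractor model).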
+    merged = kv_mem.get_cache([cache_item.id, item2.id])
+    print("Merged cache:")
+    print(json.dumps(serialize_item(merged), indent=2, default=str))
+
+    # Delete one of them
+    kv_mem.delete([cache_item.id])
+    print("After delete:")
+    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+    # Dump and load caches
+    kv_mem.dump("tmp/kv_mem")
+    print("Dumped to tmp/kv_mem")
+    kv_mem.delete_all()
+    kv_mem.load("tmp/kv_mem")
+    print("Loaded caches:")
+    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+
+def run_scheduler_example():
+    # Load the main MOS config with MemScheduler enabled
+    config = parse_yaml(
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+    )
+    mos_config = MOSConfig(**config)
+    mos = MOS(mos_config)
+
+    # Create a dynamic user ID
+    user_id = str(uuid.uuid4())
+    mos.create_user(user_id=user_id)
+
+    # Create the MemCube config and dump it
+    config = GeneralMemCubeConfig.from_yaml_file(
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+    )
+    mem_cube_id = "mem_cube_5"
+    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+
+    # Remove the old directory if it exists
+    if Path(mem_cube_name_or_path).exists():
+        shutil.rmtree(mem_cube_name_or_path)
+        print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+
+    # Dump a new MemCube
+    mem_cube = GeneralMemCube(config)
+    mem_cube.dump(mem_cube_name_or_path)
+
+    # Register the MemCube for this user
+    mos.register_mem_cube(
+        mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
+    )
+
+    # Define custom scheduler handlers
+    def custom_query_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            print(f"\n[scheduler] user submitted a query: {msg.content}")
+            # Trigger mem_update manually
+            new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
+            mos.mem_scheduler.submit_messages([new_msg])
+
+    def custom_answer_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            print(f"\n[scheduler] LLM produced an answer: {msg.content}")
+
+    def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
+            if mem_cube and mem_cube.text_mem:
+                results = mem_cube.text_mem.search(msg.content, top_k=3)
+                for mem in results:
+                    print(
+                        f"\n[scheduler] transform {mem.metadata.type} to working memory: {mem.memory} "
+                    )
+
+    # Register custom handlers
+    mos.mem_scheduler.dispatcher.register_handlers(
+        {
+            QUERY_TASK_LABEL: custom_query_handler,
+            ANSWER_TASK_LABEL: custom_answer_handler,
+            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,
+        }
+    )
+
+    # Add messages
+    messages = [
+        {"role": "user", "content": "I like playing football."},
+        {"role": "assistant", "content": "I like playing football too."},
+    ]
+    mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)
+
+    # Chat loop: show TreeTextMemory nodes + KVCache
+    while True:
+        user_input = input("👤 [You] ").strip()
+        print()
+        response = mos.chat(user_input, user_id=user_id)
+        retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)
+
+        print(f"🤖 [Assistant] {response}")
+
+        # Show the node types stored in TreeTextMemory
+        text_memories = retrieved_memories["text_mem"][0]["memories"]
+        # Handle different memory structures (NaiveTextMemory returns list, TreeTextMemory returns dict with nodes)
+        if isinstance(text_memories, dict) and "nodes" in text_memories:
+            for node in text_memories["nodes"]:
+                mem_type = node["metadata"].get("memory_type", "Unknown")
+                print(f"[{mem_type}] {node['memory']}")
+        elif isinstance(text_memories, list):
+            for mem in text_memories:
+                # Naive memory items might not have memory_type metadata, or it might be different
+                print(f"[TextMemory] {mem.memory if hasattr(mem, 'memory') else mem}")
+
+
+if __name__ == "__main__":
+    run_scheduler_example()
diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index d46db7c9e..7dcf09940 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -1,6 +1,8 @@
 from collections.abc import Generator
 from typing import Any

+import torch
+
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -37,9 +39,14 @@ def __init__(self, config: HFLLMConfig):
             self.config.model_name_or_path = "Qwen/Qwen3-1.7B"

         # Initialize hf model
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
-        )
+        if torch.backends.mps.is_available():
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.config.model_name_or_path, torch_dtype="auto"
+            ).to("mps")
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
+            )
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.config.model_name_or_path, use_fast=True
         )
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index 1a88fa831..e7f01ec3e 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -311,7 +311,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
         past_key_values = None

         if self.config.enable_activation_memory:
-            if self.config.chat_model.backend != "huggingface":
+            if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]:
                 logger.error(
                     "Activation memory only used for huggingface backend. Skipping activation memory."
                 )
@@ -498,7 +498,9 @@ def register_mem_cube(
             existing_cube = self.user_manager.get_cube(mem_cube_id)

             # check the embedder is it consistent with MOSConfig
-            if self.config.mem_reader.config.embedder != (
+            if hasattr(
+                self.mem_cubes[mem_cube_id].text_mem.config, "embedder"
+            ) and self.config.mem_reader.config.embedder != (
                 cube_embedder := self.mem_cubes[mem_cube_id].text_mem.config.embedder
             ):
                 logger.warning(
diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py
index 0114fc0da..0dc6ab209 100644
--- a/src/memos/mem_os/main.py
+++ b/src/memos/mem_os/main.py
@@ -310,7 +310,7 @@ def _generate_enhanced_response_with_context(
         # Handle activation memory if enabled (same as core method)
         past_key_values = None
         if self.config.enable_activation_memory:
-            if self.config.chat_model.backend != "huggingface":
+            if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]:
                 logger.error(
                     "Activation memory only used for huggingface backend. Skipping activation memory."
                )
diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py
index 5fdb59058..f6e33bb31 100644
--- a/src/memos/mem_os/utils/format_utils.py
+++ b/src/memos/mem_os/utils/format_utils.py
@@ -1087,38 +1087,64 @@ def convert_activation_memory_to_serializable(
     serializable_items = []
 
     for item in act_mem_items:
+        key_layers = 0
+        val_layers = 0
+        device = "unknown"
+        dtype = "unknown"
+        key_shapes = []
+        value_shapes = []
+
+        if item.memory:
+            if hasattr(item.memory, "layers"):
+                key_layers = len(item.memory.layers)
+                val_layers = len(item.memory.layers)
+                if key_layers > 0:
+                    l0 = item.memory.layers[0]
+                    k0 = getattr(l0, "key_cache", getattr(l0, "keys", None))
+                    if k0 is not None:
+                        device = str(k0.device)
+                        dtype = str(k0.dtype)
+
+                for i, layer in enumerate(item.memory.layers):
+                    k = getattr(layer, "key_cache", getattr(layer, "keys", None))
+                    v = getattr(layer, "value_cache", getattr(layer, "values", None))
+                    if k is not None:
+                        key_shapes.append({"layer": i, "shape": list(k.shape)})
+                    if v is not None:
+                        value_shapes.append({"layer": i, "shape": list(v.shape)})
+
+            elif hasattr(item.memory, "key_cache"):
+                key_layers = len(item.memory.key_cache)
+                val_layers = len(item.memory.value_cache)
+                if key_layers > 0 and item.memory.key_cache[0] is not None:
+                    device = str(item.memory.key_cache[0].device)
+                    dtype = str(item.memory.key_cache[0].dtype)
+
+                for i, key_tensor in enumerate(item.memory.key_cache):
+                    if key_tensor is not None:
+                        key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
+
+                for i, val_tensor in enumerate(item.memory.value_cache):
+                    if val_tensor is not None:
+                        value_shapes.append({"layer": i, "shape": list(val_tensor.shape)})
+
         # Extract basic information that can be serialized
         serializable_item = {
             "id": item.id,
             "metadata": item.metadata,
             "memory_info": {
                 "type": "DynamicCache",
-                "key_cache_layers": len(item.memory.key_cache) if item.memory else 0,
-                "value_cache_layers": len(item.memory.value_cache) if item.memory else 0,
-                "device": str(item.memory.key_cache[0].device)
-                if item.memory and item.memory.key_cache
-                else "unknown",
-                "dtype": str(item.memory.key_cache[0].dtype)
-                if item.memory and item.memory.key_cache
-                else "unknown",
+                "key_cache_layers": key_layers,
+                "value_cache_layers": val_layers,
+                "device": device,
+                "dtype": dtype,
             },
         }
 
         # Add tensor shape information if available
-        if item.memory and item.memory.key_cache:
-            key_shapes = []
-            value_shapes = []
-
-            for i, key_tensor in enumerate(item.memory.key_cache):
-                if key_tensor is not None:
-                    key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
-
-                if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None:
-                    value_shapes.append(
-                        {"layer": i, "shape": list(item.memory.value_cache[i].shape)}
-                    )
-
+        if key_shapes:
             serializable_item["memory_info"]["key_shapes"] = key_shapes
+        if value_shapes:
             serializable_item["memory_info"]["value_shapes"] = value_shapes
 
         serializable_items.append(serializable_item)
@@ -1144,7 +1170,19 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[
     total_parameters = 0
 
     for item in act_mem_items:
-        if item.memory and item.memory.key_cache:
+        if not item.memory:
+            continue
+
+        if hasattr(item.memory, "layers"):
+            total_layers += len(item.memory.layers)
+            for layer in item.memory.layers:
+                k = getattr(layer, "key_cache", getattr(layer, "keys", None))
+                v = getattr(layer, "value_cache", getattr(layer, "values", None))
+                if k is not None:
+                    total_parameters += k.numel()
+                if v is not None:
+                    total_parameters += v.numel()
+        elif hasattr(item.memory, "key_cache"):
             total_layers += len(item.memory.key_cache)
 
             # Calculate approximate parameter count
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 70472958e..61a7d2b6d 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -618,7 +618,7 @@ def _read_memory(
                 messages=combined_messages,
                 memory_list=original_memory_group,
                 user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower()
-                == "true",
+                == "false",
             )
             serialized_revised_memories = json.dumps(
                 [one.memory for one in revised_memory_list], indent=2
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index a2621eefc..3f5c90b67 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -74,6 +74,7 @@
 from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
 from memos.memories.activation.kv import KVCacheMemory
 from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
+from memos.memories.textual.naive import NaiveTextMemory
 from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
 from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
 from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
@@ -198,13 +199,16 @@ def init_mem_cube(
             logger.error("mem_cube is None, cannot initialize", stack_info=True)
         self.mem_cube = mem_cube
         self.text_mem: TreeTextMemory = self.mem_cube.text_mem
-        self.reranker: HTTPBGEReranker = self.text_mem.reranker
+        self.reranker: HTTPBGEReranker = getattr(self.text_mem, "reranker", None)
         if searcher is None:
-            self.searcher: Searcher = self.text_mem.get_searcher(
-                manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
-                moscube=False,
-                process_llm=self.process_llm,
-            )
+            if hasattr(self.text_mem, "get_searcher"):
+                self.searcher: Searcher = self.text_mem.get_searcher(
+                    manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
+                    moscube=False,
+                    process_llm=self.process_llm,
+                )
+            else:
+                self.searcher = None
         else:
             self.searcher = searcher
         self.feedback_server = feedback_server
@@ -540,6 +544,29 @@ def replace_working_memory(
                 mem_cube=mem_cube,
                 log_func_callback=self._submit_web_logs,
             )
+        elif isinstance(text_mem_base, NaiveTextMemory):
+            # For NaiveTextMemory, we populate the monitors with the new candidates so activation memory can pick them up
+            logger.info(
+                f"NaiveTextMemory: Updating working memory monitors with {len(new_memory)} candidates."
+            )
+
+            # Use query keywords if available, otherwise just basic monitoring
+            query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
+            query_db_manager.sync_with_orm()
+            query_keywords = query_db_manager.obj.get_keywords_collections()
+
+            new_working_memory_monitors = self.transform_working_memories_to_monitors(
+                query_keywords=query_keywords,
+                memories=new_memory,
+            )
+
+            self.monitor.update_working_memory_monitors(
+                new_working_memory_monitors=new_working_memory_monitors,
+                user_id=user_id,
+                mem_cube_id=mem_cube_id,
+                mem_cube=mem_cube,
+            )
+            memories_with_new_order = new_memory
         else:
             logger.error("memory_base is not supported")
             memories_with_new_order = new_memory
diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
index 57d78676f..fd83ec86f 100644
--- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py
+++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
@@ -55,7 +55,11 @@ def create_autofilled_log_item(
                 "mem_cube is None — this should not happen in production!", stack_info=True
             )
         text_mem_base: TreeTextMemory = mem_cube.text_mem
-        current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id)
+
+        current_memory_sizes = {}
+        if hasattr(text_mem_base, "get_current_memory_size"):
+            current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id)
+
         current_memory_sizes = {
             "long_term_memory_size": current_memory_sizes.get("LongTermMemory", 0),
             "user_memory_size": current_memory_sizes.get("UserMemory", 0),
@@ -63,14 +67,32 @@
             "transformed_act_memory_size": NOT_INITIALIZED,
             "parameter_memory_size": NOT_INITIALIZED,
         }
+
         memory_capacities = {
-            "long_term_memory_capacity": text_mem_base.memory_manager.memory_size["LongTermMemory"],
-            "user_memory_capacity": text_mem_base.memory_manager.memory_size["UserMemory"],
-            "working_memory_capacity": text_mem_base.memory_manager.memory_size["WorkingMemory"],
+            "long_term_memory_capacity": 0,
+            "user_memory_capacity": 0,
+            "working_memory_capacity": 0,
             "transformed_act_memory_capacity": NOT_INITIALIZED,
             "parameter_memory_capacity": NOT_INITIALIZED,
         }
+
+        if hasattr(text_mem_base, "memory_manager") and hasattr(
+            text_mem_base.memory_manager, "memory_size"
+        ):
+            memory_capacities.update(
+                {
+                    "long_term_memory_capacity": text_mem_base.memory_manager.memory_size.get(
+                        "LongTermMemory", 0
+                    ),
+                    "user_memory_capacity": text_mem_base.memory_manager.memory_size.get(
+                        "UserMemory", 0
+                    ),
+                    "working_memory_capacity": text_mem_base.memory_manager.memory_size.get(
+                        "WorkingMemory", 0
+                    ),
+                }
+            )
+
         if hasattr(self, "monitor"):
             if (
                 user_id in self.monitor.activation_memory_monitors
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 86066f346..9b19e9ecb 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -34,6 +34,7 @@
     is_cloud_env,
 )
 from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.naive import NaiveTextMemory
 from memos.memories.textual.preference import PreferenceTextMemory
 from memos.memories.textual.tree import TreeTextMemory
 from memos.types import (
@@ -846,7 +847,9 @@ def _process_memories_with_reader(
                     memory_item = text_mem.get(mem_id, user_name=user_name)
                     memory_items.append(memory_item)
                 except Exception as e:
-                    logger.warning(f"Failed to get memory {mem_id}: {e}")
+                    logger.warning(
+                        f"[_process_memories_with_reader] Failed to get memory {mem_id}: {e}"
+                    )
                     continue
 
             if not memory_items:
@@ -1364,22 +1367,31 @@ def process_session_turn(
         text_mem_base = mem_cube.text_mem
 
         if not isinstance(text_mem_base, TreeTextMemory):
-            logger.error(
-                f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
-                f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
-                f"text_mem_base value: {text_mem_base}",
-                exc_info=True,
+            if isinstance(text_mem_base, NaiveTextMemory):
+                logger.debug(
+                    f"NaiveTextMemory used for mem_cube_id={mem_cube_id}, processing session turn with simple search."
+                )
+                # Treat NaiveTextMemory like TreeTextMemory, but with simpler logic:
+                # the retrieval below supplies "working memory" candidates for activation memory,
+                # since a NaiveTextMemory has no distinct "current working memory" to fetch.
+                cur_working_memory = []
+            else:
+                logger.warning(
+                    f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
+                    f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
+                    f"text_mem_base value: {text_mem_base}"
+                )
+                return [], []
+        else:
+            cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
+                user_name=mem_cube_id
             )
-            return
+            cur_working_memory = cur_working_memory[:top_k]
 
         logger.info(
             f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}"
         )
-        cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
-            user_name=mem_cube_id
-        )
-        cur_working_memory = cur_working_memory[:top_k]
         text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
         intent_result = self.monitor.detect_intent(
             q_list=queries, text_working_memory=text_working_memory
@@ -1419,15 +1431,28 @@
             )
             search_args = {}
 
-            results: list[TextualMemoryItem] = self.retriever.search(
-                query=item,
-                user_id=user_id,
-                mem_cube_id=mem_cube_id,
-                mem_cube=mem_cube,
-                top_k=k_per_evidence,
-                method=self.search_method,
-                search_args=search_args,
-            )
+            if isinstance(text_mem_base, NaiveTextMemory):
+                # SchedulerRetriever.search ultimately calls mem_cube.text_mem.search with
+                # tree-specific arguments; NaiveTextMemory.search only takes query and top_k,
+                # so call the naive memory directly instead of going through the retriever.
+                try:
+                    results = text_mem_base.search(query=item, top_k=k_per_evidence)
+                except Exception as e:
+                    logger.warning(f"NaiveTextMemory search failed: {e}")
+                    results = []
+            else:
+                results: list[TextualMemoryItem] = self.retriever.search(
+                    query=item,
+                    user_id=user_id,
+                    mem_cube_id=mem_cube_id,
+                    mem_cube=mem_cube,
+                    top_k=k_per_evidence,
+                    method=self.search_method,
+                    search_args=search_args,
+                )
 
             logger.info(
                 f"[process_session_turn] Search results for missing evidence '{item}': "
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index b097b1e2d..d75d6ee75 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -200,15 +200,19 @@ def update_working_memory_monitors(
         mem_cube_id: str,
         mem_cube: GeneralMemCube,
     ):
-        text_mem_base: TreeTextMemory = mem_cube.text_mem
-        assert isinstance(text_mem_base, TreeTextMemory)
-        self.working_mem_monitor_capacity = min(
-            DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT,
-            (
-                int(text_mem_base.memory_manager.memory_size["WorkingMemory"])
-                + self.partial_retention_number
-            ),
-        )
+        text_mem_base = mem_cube.text_mem
+
+        if isinstance(text_mem_base, TreeTextMemory):
+            self.working_mem_monitor_capacity = min(
+                DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT,
+                (
+                    int(text_mem_base.memory_manager.memory_size["WorkingMemory"])
+                    + self.partial_retention_number
+                ),
+            )
+        else:
+            # Fallback for NaiveTextMemory and others
+            self.working_mem_monitor_capacity = DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT
 
         # register monitors
         self.register_memory_manager_if_not_exists(
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index 2f4318003..941c52164 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -787,7 +787,7 @@ def qsize(self) -> dict:
             Total number of messages across all matching streams.
         """
         if not self._redis_conn:
-            return 0
+            return {}
 
         total_size = 0
         try:
diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py
index 98d611dbf..1981b958f 100644
--- a/src/memos/memories/activation/kv.py
+++ b/src/memos/memories/activation/kv.py
@@ -2,9 +2,7 @@
 import pickle
 
 from datetime import datetime
-from importlib.metadata import version
 
-from packaging.version import Version
 from transformers import DynamicCache
 
 from memos.configs.memory import KVCacheMemoryConfig
@@ -211,10 +209,24 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
             return caches[0]
 
         merged = DynamicCache()
-        num_layers = len(caches[0].key_cache)
 
-        if Version(version("transformers")) >= Version("4.54.0"):
-            merged.append_new_layers(num_layers - 1)
+        # Check for new structure (layers)
+        if hasattr(caches[0], "layers"):
+            num_layers = len(caches[0].layers)
+
+            # Ensure merged has a layers attribute and populate it
+            if not hasattr(merged, "layers"):
+                merged.layers = []
+
+            if num_layers > 0:
+                # Get the layer class from the first cache;
+                # we assume all caches use the same layer class
+                layer_cls = type(caches[0].layers[0])
+
+                # Populate merged.layers
+                while len(merged.layers) < num_layers:
+                    merged.layers.append(layer_cls())
+
             for layer in range(num_layers):
                 # gather all K and V for this layer
                 keys = [c.layers[layer].keys for c in caches]
                 vals = [c.layers[layer].values for c in caches]
 
                 merged.layers[layer].keys = torch.cat(keys, dim=-2)
                 merged.layers[layer].values = torch.cat(vals, dim=-2)
-        else:
+        # Check for old structure (key_cache)
+        elif hasattr(caches[0], "key_cache"):
+            num_layers = len(caches[0].key_cache)
+
             for layer in range(num_layers):
                 # gather all K and V for this layer
                 keys = [c.key_cache[layer] for c in caches]
                 vals = [c.value_cache[layer] for c in caches]
 
                 merged.key_cache.append(torch.cat(keys, dim=-2))
                 merged.value_cache.append(torch.cat(vals, dim=-2))
 
+        else:
+            raise AttributeError(
+                "DynamicCache object has neither 'layers' nor 'key_cache' attributes"
+            )
+
         return merged
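Note: the patch above repeats one compatibility pattern in several places (format_utils.py, kv.py): newer transformers releases expose DynamicCache contents per layer via cache.layers[i] with keys/values (or key_cache/value_cache) attributes, while older releases expose parallel cache.key_cache and cache.value_cache lists. A minimal version-agnostic accessor in the same spirit is sketched below; iter_kv_layers is a hypothetical helper, not part of this patch series:

    from transformers import DynamicCache


    def iter_kv_layers(cache: DynamicCache):
        """Yield (key, value) tensor pairs per layer across transformers versions."""
        if hasattr(cache, "layers"):  # newer releases: per-layer cache objects
            for layer in cache.layers:
                yield (
                    getattr(layer, "key_cache", getattr(layer, "keys", None)),
                    getattr(layer, "value_cache", getattr(layer, "values", None)),
                )
        elif hasattr(cache, "key_cache"):  # older releases: parallel lists
            yield from zip(cache.key_cache, cache.value_cache)
        else:
            raise AttributeError("DynamicCache has neither 'layers' nor 'key_cache'")

Centralizing the branch this way would let the serialization, summary, and concatenation code above share one code path instead of re-testing hasattr in each function.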
From e9b60db165c98e6a8f70602e2c838bdf004adc78 Mon Sep 17 00:00:00 2001
From: chentang
Date: Fri, 26 Dec 2025 17:57:04 +0800
Subject: [PATCH 17/21] fix: address the issue of Top-level import of unavailable module 'torch'

---
 src/memos/llms/hf.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index 7dcf09940..b5fc4ba13 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -1,16 +1,8 @@
 from collections.abc import Generator
 from typing import Any
 
-import torch
-
 from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
     DynamicCache,
-    LogitsProcessorList,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
 )
 
 from memos.configs.llm import HFLLMConfig
@@ -32,6 +24,17 @@ def __init__(self, config: HFLLMConfig):
         """
         Initialize the HFLLM model and tokenizer, and set up logits processors for sampling.
         """
+        import torch
+
+        from transformers import (
+            AutoModelForCausalLM,
+            AutoTokenizer,
+            LogitsProcessorList,
+            TemperatureLogitsWarper,
+            TopKLogitsWarper,
+            TopPLogitsWarper,
+        )
+
         self.config = config
 
         # Default model if not specified
@@ -362,6 +365,7 @@ def build_kv_cache(self, messages) -> DynamicCache:
             DynamicCache: The constructed KV cache object.
         """
         import torch
+        import transformers
 
         # Accept multiple input types and convert to standard chat messages
         if isinstance(messages, str):
@@ -398,7 +402,7 @@ def build_kv_cache(self, messages) -> DynamicCache:
 
         # Convert from legacy tuple format to DynamicCache if needed
         if isinstance(kv, tuple):
-            kv = DynamicCache.from_legacy_cache(kv)
+            kv = transformers.DynamicCache.from_legacy_cache(kv)
 
         # Handle compatibility between old and new transformers versions
         # In newer versions, DynamicCache uses 'layers' attribute

From c6bdb22e14a13481b85cc96166d06d630dfc3c39 Mon Sep 17 00:00:00 2001
From: chentang
Date: Fri, 26 Dec 2025 18:13:47 +0800
Subject: [PATCH 18/21] fix: resolve linting errors and make optional dependencies lazy loaded

- Fix ambiguous characters and commented-out code in
  examples/mem_scheduler/quick_start_examples.py
- Fix nested if statements in src/memos/mem_os/core.py
- Move torch and transformers imports to method scope in
  src/memos/llms/hf.py to support optional dependencies
- Update tests/llms/test_hf.py to patch transformers module directly
---
 tests/llms/test_hf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/llms/test_hf.py b/tests/llms/test_hf.py
index 595995ad1..375bf2247 100644
--- a/tests/llms/test_hf.py
+++ b/tests/llms/test_hf.py
@@ -11,8 +11,8 @@
 from memos.llms.hf import HFLLM
 
 
-@patch("memos.llms.hf.AutoModelForCausalLM", MagicMock())
-@patch("memos.llms.hf.AutoTokenizer", MagicMock())
+@patch("transformers.AutoModelForCausalLM", MagicMock())
+@patch("transformers.AutoTokenizer", MagicMock())
 class TestHFLLM(unittest.TestCase):
     def setUp(self):
         self.mock_inputs = MagicMock()
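Note: patches 17 and 18 together apply the standard lazy-import pattern for optional heavy dependencies: the module itself imports cheaply, and torch/transformers are only imported inside the methods that actually need them. A minimal sketch of the pattern, under the assumption that the heavy dependency may be absent; OptionalHFModel is a hypothetical class, not code from this series:

    class OptionalHFModel:
        """Importable even when torch/transformers are not installed."""

        def __init__(self, model_name: str) -> None:
            # Deferred imports: an ImportError is raised only if this class
            # is actually instantiated, not at module import time.
            import torch
            from transformers import AutoModelForCausalLM

            device = "mps" if torch.backends.mps.is_available() else "cpu"
            self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

One consequence, visible in patch 18: tests can no longer patch memos.llms.hf.AutoModelForCausalLM, because that name is no longer a module-level attribute; they must patch transformers.AutoModelForCausalLM at its source instead.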
From ad3620aaceb240403c5831a0d729222b1708669f Mon Sep 17 00:00:00 2001
From: chentang
Date: Mon, 29 Dec 2025 20:51:19 +0800
Subject: [PATCH 19/21] refactor: revise the rewrite prompt to make it better

---
 src/memos/templates/mem_reader_prompts.py | 35 ++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py
index 40971c77e..26795a2b1 100644
--- a/src/memos/templates/mem_reader_prompts.py
+++ b/src/memos/templates/mem_reader_prompts.py
@@ -622,7 +622,7 @@
 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。"""
 
-SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """
+SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """
 You are a strict, language-preserving memory validator and rewriter.
 
 Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content.
@@ -655,6 +655,39 @@
 Important: Output **only** the JSON. No extra text, explanations, markdown, or fields.
 """
 
+SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """
+You are a strict, language-preserving memory validator and rewriter.
+
+Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content.
+
+Rules:
+1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching.
+2. **Strict Factual Grounding**: Include only what is explicitly stated by the user in messages marked as [user]. Remove or flag anything not directly present in the user’s utterances—no assumptions, interpretations, predictions, generalizations, or content originating solely from [assistant].
+3. **Source Attribution Requirement**:
+   - Every memory must be clearly traceable to its source:
+     - If a fact appears **only in [assistant] messages** and **is not affirmed by [user]**, label it as “[assistant] memory”.
+     - If [assistant] states something and [user] explicitly contradicts or denies it, label it as “[assistant] memory, but [user] [brief quote or summary of denial]”.
+     - If a fact is stated by [user]—whether or not [assistant] also mentions it—it is attributed to “[user]” and may be retained without qualification.
+4. **Timestamp Exception**: Memories may include timestamps (e.g., "On December 19, 2026") derived from conversation metadata. If such a date likely reflects the conversation time (even if not in the `messages` list), do NOT treat it as hallucinated—but still attribute it to “[user]” only if the user mentioned or confirmed the date.
+
+Inputs:
+messages:
+{messages_inline}
+
+memories:
+{memories_inline}
+
+Output Format:
+- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices.
+- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }}
+- The "reason" must be brief and precise, e.g.:
+  - "contains unsupported inference from [assistant]"
+  - "[assistant] memory, but [user] said 'I don't have a dog'"
+  - "fully grounded in [user]"
+
+Important: Output **only** the JSON. No extra text, explanations, markdown, or fields.
+"""
+
 SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """
 You are a strict, language-preserving memory validator and rewriter.
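Note: the revised prompt pins down a machine-checkable contract: the model must return a bare JSON object keyed by stringified memory indices. A sketch of the consuming side, assuming the raw LLM reply is exactly that JSON object; parse_rewrite_response is illustrative, not the repo's actual parser:

    import json


    def parse_rewrite_response(raw: str, num_memories: int) -> dict[int, dict]:
        """Validate the {"0": {...}, "1": {...}} object the rewrite prompt specifies."""
        data = json.loads(raw)
        results: dict[int, dict] = {}
        for idx in range(num_memories):
            entry = data.get(str(idx))
            if not isinstance(entry, dict):
                continue  # contract violated; leave this memory unchanged
            if {"need_rewrite", "rewritten", "reason"} <= entry.keys():
                results[idx] = entry
        return results

Indices absent from the result can then fall back to the original memory text, matching the prompt's instruction that only grounded rewrites are applied.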
From 24752864e4870961fdc43fc814713725acec7432 Mon Sep 17 00:00:00 2001
From: chentang
Date: Tue, 30 Dec 2025 10:57:35 +0800
Subject: [PATCH 20/21] refactor: update examples

---
 examples/mem_scheduler/quick_start_examples.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py
index c4142d2cd..c71869e76 100644
--- a/examples/mem_scheduler/quick_start_examples.py
+++ b/examples/mem_scheduler/quick_start_examples.py
@@ -18,6 +18,7 @@
     MEM_UPDATE_TASK_LABEL,
     QUERY_TASK_LABEL,
 )
+from memos.mem_scheduler.utils.db_utils import get_utc_now
 from memos.mem_scheduler.utils.misc_utils import parse_yaml
 from memos.memories.activation.item import KVCacheItem
 from memos.memories.factory import MemoryFactory
@@ -186,17 +187,27 @@
 
     def custom_answer_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
+            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
+            kv_mem = mem_cube.act_mem
+            for cache_item in kv_mem.get_all():
+                print(
+                    f"[scheduler] act memory: {get_cache_info(cache_item.memory)} ({cache_item.records})"
+                )
             print(f"\n[scheduler] LLM returned an answer: {msg.content}")
 
     def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
             mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
+            kv_mem = mem_cube.act_mem
             if mem_cube and mem_cube.text_mem:
                 results = mem_cube.text_mem.search(msg.content, top_k=3)
                 for mem in results:
-                    print(
-                        f"\n[scheduler] transform {mem.metadata.type} to working memory: {mem.memory} "
-                    )
+                    print(f"\n[scheduler] searched memories: {mem.memory}")
+
+                    cache_item = kv_mem.extract(mem.memory)
+                    cache_item.records.text_memories = [mem.memory]
+                    cache_item.records.timestamp = get_utc_now()
+                    kv_mem.add([cache_item])
 
     # Register custom handlers
     mos.mem_scheduler.dispatcher.register_handlers(
@@ -237,4 +248,6 @@
 
 
 if __name__ == "__main__":
+    kv_cache_only()
+
     run_scheduler_example()

From a196dcbc0eb1b9da7ed21231cc0f2a048ab0d86b Mon Sep 17 00:00:00 2001
From: chentang
Date: Tue, 30 Dec 2025 15:12:40 +0800
Subject: [PATCH 21/21] refactor: update examples for scheduler

---
 .../mem_scheduler/quick_start_examples.py | 139 ++++++++++++------
 1 file changed, 98 insertions(+), 41 deletions(-)

diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py
index c71869e76..fbfef4d76 100644
--- a/examples/mem_scheduler/quick_start_examples.py
+++ b/examples/mem_scheduler/quick_start_examples.py
@@ -145,106 +145,163 @@
 
 
 def run_scheduler_example():
-    # Load the main MOS config with MemScheduler enabled
-    config = parse_yaml(
-        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
-    )
+    # Load the main MOS (Memory-Oriented System) config file with MemScheduler enabled
+    config = parse_yaml("./examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml")
+    # Pass the parsed config dict to the MOSConfig constructor to build the config object
     mos_config = MOSConfig(**config)
+    # Initialize the MOS instance from the config object
    mos = MOS(mos_config)
 
-    # Create a dynamic user ID
+    # Generate a unique, dynamic user ID (UUID4)
     user_id = str(uuid.uuid4())
+    # Create an account for this user in MOS
     mos.create_user(user_id=user_id)
 
-    # Create the MemCube config and dump it
+    # Load the general MemCube config from its YAML file
     config = GeneralMemCubeConfig.from_yaml_file(
-        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+        "./examples/data/config/mem_scheduler/mem_cube_config.yaml"
     )
+    # Define the unique identifier of the MemCube
     mem_cube_id = "mem_cube_5"
-    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+    # Define the local storage path of the MemCube (the path embeds the user ID and MemCube ID)
+    mem_cube_name_or_path = f"./outputs/mem_scheduler/{user_id}/{mem_cube_id}"
 
-    # Remove the old directory if it exists
+    # If the path already exists, remove the old directory first
     if Path(mem_cube_name_or_path).exists():
         shutil.rmtree(mem_cube_name_or_path)
-        print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+        print(f"{mem_cube_name_or_path} was not empty and has been removed.")
 
-    # Dump a fresh MemCube
+    # Create a new MemCube instance from the loaded config
     mem_cube = GeneralMemCube(config)
+    # Serialize the MemCube instance and save it to the given path
     mem_cube.dump(mem_cube_name_or_path)
 
-    # Register the MemCube for this user
+    # Register this MemCube for the current user in MOS
     mos.register_mem_cube(
         mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
     )
 
+    # Define a helper that reports memory information for a cache (e.g., a KV cache)
+    def get_cache_info(cache):
+        # Return None directly for an empty cache
+        if not cache:
+            return None
+
+        num_layers = 0  # number of cached layers
+        total_size_bytes = 0  # total size in bytes
+
+        # Case 1: the cache exposes a `layers` attribute (e.g., newer HuggingFace cache format)
+        if hasattr(cache, "layers"):
+            num_layers = len(cache.layers)
+            for layer in cache.layers:
+                # Count the memory used by key_cache (if present)
+                if hasattr(layer, "key_cache") and layer.key_cache is not None:
+                    total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size()
+                # Count the memory used by value_cache (if present)
+                if hasattr(layer, "value_cache") and layer.value_cache is not None:
+                    total_size_bytes += (
+                        layer.value_cache.nelement() * layer.value_cache.element_size()
+                    )
+
+                # Also handle alternative attribute names (keys/values)
+                if hasattr(layer, "keys") and layer.keys is not None:
+                    total_size_bytes += layer.keys.nelement() * layer.keys.element_size()
+                if hasattr(layer, "values") and layer.values is not None:
+                    total_size_bytes += layer.values.nelement() * layer.values.element_size()
+
+        # Case 2: the cache holds key_cache and value_cache lists directly (older or custom formats)
+        elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
+            num_layers = len(cache.key_cache)
+            for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
+                if k is not None:
+                    total_size_bytes += k.nelement() * k.element_size()
+                if v is not None:
+                    total_size_bytes += v.nelement() * v.element_size()
+
+        # Return structured cache info: layer count, byte count, and a human-readable MB figure
+        return {
+            "num_layers": num_layers,
+            "size_bytes": total_size_bytes,
+            "size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
+        }
+
-    # Define custom scheduler handlers
+    # Define the custom query handler
     def custom_query_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
+            # Print the user's input
             print(f"\n[scheduler] User submitted a query: {msg.content}")
-            # Trigger mem_update manually
+            # Manually build a new message labeled MEM_UPDATE to trigger a memory update
             new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
+            # Submit the message to the scheduler
             mos.mem_scheduler.submit_messages([new_msg])
 
+    # Define the custom answer handler
     def custom_answer_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
-            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
-            kv_mem = mem_cube.act_mem
-            for cache_item in kv_mem.get_all():
-                print(
-                    f"[scheduler] act memory: {get_cache_info(cache_item.memory)} ({cache_item.records})"
-                )
+            # Print the LLM's reply
             print(f"\n[scheduler] LLM returned an answer: {msg.content}")
 
+    # Define the custom memory-update handler
    def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
             mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
             kv_mem = mem_cube.act_mem
+            # If this MemCube has textual memory configured (TreeTextMemory / NaiveTextMemory)
             if mem_cube and mem_cube.text_mem:
+                # Search the textual memory for entries related to the current content (top_k=3)
                 results = mem_cube.text_mem.search(msg.content, top_k=3)
                 for mem in results:
-                    print(f"\n[scheduler] searched memories: {mem.memory}")
+                    print(f"\n[scheduler] Retrieved memory: {mem.memory}")
+                    print("\n[scheduler] Converting to activation memory ...")
+                    # Extract the corresponding KV cache item from the textual memory
                     cache_item = kv_mem.extract(mem.memory)
+                    # Attach metadata
                     cache_item.records.text_memories = [mem.memory]
                     cache_item.records.timestamp = get_utc_now()
+                    # Add the cache item to activation memory
                     kv_mem.add([cache_item])
+                    print("\n[scheduler] Done!")
 
-    # Register custom handlers
+    # Register the three custom handlers with the scheduler's dispatcher, one per task label
     mos.mem_scheduler.dispatcher.register_handlers(
         {
-            QUERY_TASK_LABEL: custom_query_handler,
-            ANSWER_TASK_LABEL: custom_answer_handler,
-            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,
+            QUERY_TASK_LABEL: custom_query_handler,  # query task
+            ANSWER_TASK_LABEL: custom_answer_handler,  # answer task
+            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,  # memory-update task
         }
     )
 
-    # Add seed messages
+    # Seed the system with two test messages (a user/assistant exchange)
     messages = [
         {"role": "user", "content": "I like playing football."},
         {"role": "assistant", "content": "I like playing football too."},
     ]
     mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)
 
-    # Chat loop: show TreeTextMemory nodes + KV cache
+    # Enter the chat loop: show the TreeTextMemory node structure + the KV cache state
     while True:
+        # Read user input and strip surrounding whitespace
         user_input = input("👤 [You] ").strip()
         print()
+        # Ask MOS for a chat response
         response = mos.chat(user_input, user_id=user_id)
+        # Fetch all memories currently held in this user's MemCube
         retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)
 
+        # Print the assistant's reply
         print(f"🤖 [Assistant] {response}")
 
-        # Show each node type stored in TreeTextMemory
-        text_memories = retrieved_memories["text_mem"][0]["memories"]
-        # Handle different memory structures (NaiveTextMemory returns a list, TreeTextMemory returns a dict with nodes)
-        if isinstance(text_memories, dict) and "nodes" in text_memories:
-            for node in text_memories["nodes"]:
-                mem_type = node["metadata"].get("memory_type", "Unknown")
-                print(f"[{mem_type}] {node['memory']}")
-        elif isinstance(text_memories, list):
-            for mem in text_memories:
-                # Naive memory items may not carry memory_type metadata, or may use a different schema
-                print(f"[TextMemory] {mem.memory if hasattr(mem, 'memory') else mem}")
+        # Get the textual part of the memories (TreeTextMemory)
+        memories = retrieved_memories["text_mem"][0]["memories"]
+        for mem in memories:
+            print(f"[TextMemory] {mem.memory}")
+
+        # Get the corresponding MemCube and its activation memory (KV cache)
+        mem_cube = mos.mem_scheduler.mem_cube
+        kv_mem = mem_cube.act_mem
+        # Iterate over all activation memory items and print their cache info and records
+        for cache_item in kv_mem.get_all():
+            print(f"[ActivationMemory] {get_cache_info(cache_item.memory)} (records: {cache_item.records})")
 
 
 if __name__ == "__main__":