From 11b748fab0c699656ae05a3b337a436516d12b47 Mon Sep 17 00:00:00 2001
From: pursues <15180521816@163.com>
Date: Wed, 24 Dec 2025 21:50:01 +0800
Subject: [PATCH 01/48] update requirements (#772)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* docker start
* docker start
* update config
* code checks
* test_start_api
* test_start_api
* test_start_api
* fix docker start
* update start_api
* update start_api
* update start_api
* code checks
* update start_api
* update .env.example
* back start_api
* update
* update Dockerfile
* update requirements
* back Dockerfile
* update Dockerfile
* add

---------

Co-authored-by: yjy
---
 docker/.env.example          |  20 +++-
 docker/Dockerfile            |   2 +-
 docker/requirements-full.txt | 186 +++++++++++++++++++++++++++++++
 docker/requirements.txt      | 205 ++++++++++++++---------------------
 4 files changed, 290 insertions(+), 123 deletions(-)
 create mode 100644 docker/requirements-full.txt

diff --git a/docker/.env.example b/docker/.env.example
index 85d9080a5..ca3abde94 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -20,7 +20,7 @@ MOS_TOP_K=50
 ## Chat LLM (main dialogue)
 MOS_CHAT_MODEL=gpt-4o-mini
 MOS_CHAT_TEMPERATURE=0.8
-MOS_MAX_TOKENS=8000
+MOS_MAX_TOKENS=2048
 MOS_TOP_P=0.9
 MOS_CHAT_MODEL_PROVIDER=openai # openai | huggingface | vllm
 MOS_MODEL_SCHEMA=memos.configs.llm.VLLMLLMConfig # vllm only: config class path; keep default unless you extend it
@@ -51,9 +51,18 @@ MOS_RERANKER_HEADERS_EXTRA= # extra headers, JSON string, e.g. {"A
 MOS_RERANKER_STRATEGY=single_turn
 MOS_RERANK_SOURCE= # optional rerank scope, e.g., history/stream/custom
+
+# External Services (for evaluation scripts)
+ZEP_API_KEY=your_zep_api_key_here
+MEM0_API_KEY=your_mem0_api_key_here
+MODEL=gpt-4o-mini
+EMBEDDING_MODEL=nomic-embed-text:latest
+
 ## Internet search & preference memory
 ENABLE_INTERNET=false
 BOCHA_API_KEY= # required if ENABLE_INTERNET=true
+XINYU_API_KEY=
+XINYU_SEARCH_ENGINE_ID=
 SEARCH_MODE=fast # fast | fine | mixture
 FAST_GRAPH=false
 BM25_CALL=false
@@ -121,6 +130,7 @@ POLAR_DB_USER=root
 POLAR_DB_PASSWORD=123456
 POLAR_DB_DB_NAME=shared_memos_db
 POLAR_DB_USE_MULTI_DB=false
+POLARDB_POOL_MAX_CONN=100

 ## Redis (scheduler queue) — fill only if you want scheduler queues in Redis; otherwise in-memory queue is used
 REDIS_HOST=localhost # global Redis endpoint (preferred over MEMSCHEDULER_*)
@@ -170,3 +180,11 @@ OSS_PUBLIC_BASE_URL=
 ## SDK / external client
 MEMOS_API_KEY=
 MEMOS_BASE_URL=https://memos.memtensor.cn/api/openmem/v1
+
+CHAT_MODEL_LIST='[{
+ "backend": "deepseek",
+ "api_base": "http://localhost:1234",
+ "api_key": "your-api-key",
+ "model_name_or_path": "deepseek-r1",
+ "support_models": ["deepseek-r1"]
+}]'
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 29636881c..13fb477d9 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -32,4 +32,4 @@ ENV PYTHONPATH=/app/src
 EXPOSE 8000

 # Start the docker
-CMD ["uvicorn", "memos.api.product_api:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
+CMD ["uvicorn", "memos.api.server_api:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
\ No newline at end of file
diff --git a/docker/requirements-full.txt b/docker/requirements-full.txt
new file mode 100644
index 000000000..538f5e578
--- /dev/null
+++ b/docker/requirements-full.txt
@@ -0,0 +1,186 @@
+# Generated from poetry.lock - Main dependencies
+# This file contains all transitive dependencies for the production build.
+ +annotated-types==0.7.0 +anyio==4.9.0 +async-timeout==5.0.1 +attrs==25.3.0 +authlib==1.6.0 +beautifulsoup4==4.13.4 +cachetools==6.2.1 +certifi==2025.7.14 +cffi==1.17.1 +charset-normalizer==3.4.2 +chonkie==1.1.1 +click==8.2.1 +cobble==0.1.4 +colorama==0.4.6 +coloredlogs==15.0.1 +concurrent-log-handler==0.9.28 +cryptography==45.0.5 +cyclopts==3.22.2 +datasketch==1.6.5 +defusedxml==0.7.1 +distro==1.9.0 +dnspython==2.7.0 +docstring-parser==0.16 +docutils==0.21.2 +email-validator==2.2.0 +et-xmlfile==2.0.0 +exceptiongroup==1.3.0 +fastapi==0.115.14 +fastapi-cli==0.0.8 +fastapi-cloud-cli==0.1.4 +fastmcp==2.10.5 +filelock==3.18.0 +flatbuffers==25.2.10 +fsspec==2025.7.0 +greenlet==3.2.3 +grpcio==1.73.1 +h11==0.16.0 +h2==4.2.0 +hf-xet==1.1.5 +hpack==4.1.0 +httpcore==1.0.9 +httptools==0.6.4 +httpx==0.28.1 +httpx-sse==0.4.1 +huggingface-hub==0.33.4 +humanfriendly==10.0 +hyperframe==6.1.0 +idna==3.10 +itsdangerous==2.2.0 +jieba==0.42 +jinja2==3.1.6 +jiter==0.10.0 +joblib==1.5.1 +jsonpatch==1.33 +jsonpointer==3.0.0 +jsonschema==4.24.1 +jsonschema-specifications==2025.4.1 +langchain-core==1.1.0 +langchain-text-splitters==1.0.0 +langsmith==0.4.7 +lxml==6.0.0 +magika==0.6.2 +mammoth==1.9.1 +markdown-it-py==3.0.0 +markdownify==1.1.0 +markitdown==0.1.2 +markupsafe==3.0.2 +mcp==1.12.0 +mdurl==0.1.2 +mpmath==1.3.0 +neo4j==5.28.1 +networkx==3.5 +nltk==3.9.1 +numpy==2.3.1 +nvidia-cublas-cu12==12.6.4.1 +nvidia-cuda-cupti-cu12==12.6.80 +nvidia-cuda-nvrtc-cu12==12.6.77 +nvidia-cuda-runtime-cu12==12.6.77 +nvidia-cudnn-cu12==9.5.1.17 +nvidia-cufft-cu12==11.3.0.4 +nvidia-cufile-cu12==1.11.1.6 +nvidia-curand-cu12==10.3.7.77 +nvidia-cusolver-cu12==11.7.1.2 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cusparselt-cu12==0.6.3 +nvidia-nccl-cu12==2.26.2 +nvidia-nvjitlink-cu12==12.6.85 +nvidia-nvtx-cu12==12.6.77 +ollama==0.4.9 +onnxruntime==1.22.1 +openai==1.97.0 +openapi-pydantic==0.5.1 +openpyxl==3.1.5 +orjson==3.11.0 +packaging==25.0 +pandas==2.3.1 +pdfminer-six==20250506 +pika==1.3.2 +pillow==11.3.0 +portalocker==2.10.1 +prometheus-client==0.23.1 +protobuf==6.31.1 +pycparser==2.22 +pydantic==2.11.7 +pydantic-core==2.33.2 +pydantic-extra-types==2.10.5 +pydantic-settings==2.10.1 +pygments==2.19.2 +pymilvus==2.6.2 +pymysql==1.1.2 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +python-multipart==0.0.20 +python-pptx==1.0.2 +pytz==2025.2 +pyyaml==6.0.2 +qdrant-client==1.14.3 +rake-nltk==1.0.6 +rank-bm25==0.2.2 +redis==6.2.0 +referencing==0.36.2 +regex==2024.11.6 +requests==2.32.4 +requests-toolbelt==1.0.0 +rich==14.0.0 +rich-rst==1.3.1 +rich-toolkit==0.14.8 +rignore==0.6.2 +rpds-py==0.26.0 +safetensors==0.5.3 +schedule==1.2.2 +scikit-learn==1.7.0 +scipy==1.16.0 +sentence-transformers==4.1.0 +sentry-sdk==2.33.0 +setuptools==80.9.0 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.7 +sqlalchemy==2.0.41 +sse-starlette==2.4.1 +starlette==0.46.2 +sympy==1.14.0 +tenacity==9.1.2 +threadpoolctl==3.6.0 +tokenizers==0.21.2 +torch +tqdm==4.67.1 +transformers==4.53.2 +triton==3.5.0 +typer==0.16.0 +typing-extensions +typing-inspection==0.4.1 +tzdata==2025.2 +ujson==5.10.0 +urllib3==2.5.0 +uvicorn==0.35.0 +uvloop==0.21.0 +volcengine-python-sdk==4.0.6 +watchfiles==1.1.0 +websockets==15.0.1 +xlrd==2.0.2 +xlsxwriter==3.2.5 +zstandard==0.23.0 +prometheus_client==0.23.1 +beartype==0.22.5 +diskcache==5.6.3 +iniconfig==2.3.0 +jaraco.classes==3.4.0 +jaraco.context==6.0.1 +jaraco.functools==4.3.0 +keyring==25.6.0 +more-itertools==10.8.0 +pathable==0.4.4 +pathvalidate==3.3.1 +platformdirs==4.5.0 +pluggy==1.6.0 +psycopg2-binary==2.9.9 
+py-key-value-aio==0.2.8 +py-key-value-shared==0.2.8 +PyJWT==2.10.1 +pytest==9.0.2 \ No newline at end of file diff --git a/docker/requirements.txt b/docker/requirements.txt index 8890ce679..738a53920 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -1,6 +1,3 @@ -# Docker optimized requirements - Core dependencies only -# Excludes Windows-specific and heavy GPU packages for faster builds - annotated-types==0.7.0 anyio==4.9.0 async-timeout==5.0.1 @@ -20,146 +17,112 @@ cryptography==45.0.5 cyclopts==3.22.2 defusedxml==0.7.1 distro==1.9.0 -dnspython==2.7.0 -docstring-parser==0.16 -docutils==0.21.2 -email-validator==2.2.0 -et-xmlfile==2.0.0 +dnspython==2.8.0 +docstring_parser==0.17.0 +docutils==0.22.3 +email-validator==2.3.0 exceptiongroup==1.3.0 -fastapi-cli==0.0.8 -fastapi-cloud-cli==0.1.4 fastapi==0.115.14 -fastmcp==2.10.5 -filelock==3.18.0 -flatbuffers==25.2.10 -fsspec==2025.7.0 -greenlet==3.2.3 -grpcio==1.73.1 +fastapi-cli==0.0.16 +fastapi-cloud-cli==0.3.1 +fastmcp==2.13.0.2 +filelock==3.20.0 +fsspec==2025.10.0 +grpcio==1.76.0 +neo4j==5.28.1 h11==0.16.0 -h2==4.2.0 -hf-xet==1.1.5 -hpack==4.1.0 +hf-xet==1.2.0 httpcore==1.0.9 -httptools==0.6.4 -httpx-sse==0.4.1 +httptools==0.7.1 httpx==0.28.1 -huggingface-hub==0.33.4 -humanfriendly==10.0 -hyperframe==6.1.0 -idna==3.10 +httpx-sse==0.4.3 +huggingface-hub==0.36.0 +idna==3.11 +iniconfig==2.3.0 itsdangerous==2.2.0 -jinja2==3.1.6 -jiter==0.10.0 -joblib==1.5.1 -jsonschema-specifications==2025.4.1 -jsonschema==4.24.1 -lxml==6.0.0 -magika==0.6.2 -mammoth==1.9.1 -markdown-it-py==3.0.0 -markdownify==1.1.0 -markitdown==0.1.2 -markupsafe==3.0.2 -mcp==1.12.0 +jaraco.classes==3.4.0 +jaraco.context==6.0.1 +jaraco.functools==4.3.0 +jieba==0.42 +Jinja2==3.1.6 +jiter==0.12.0 +joblib==1.5.2 +jsonschema==4.25.1 +jsonschema-path==0.3.4 +jsonschema-specifications==2025.9.1 +keyring==25.6.0 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +mcp==1.21.1 mdurl==0.1.2 -mpmath==1.3.0 -neo4j==5.28.1 -networkx==3.5 -numpy==2.3.1 -# NVIDIA CUDA packages excluded for lighter Docker images -# If GPU support is needed, uncomment relevant packages below: -# nvidia-cublas-cu12==12.6.4.1 -# nvidia-cuda-cupti-cu12==12.6.80 -# nvidia-cuda-nvrtc-cu12==12.6.77 -# nvidia-cuda-runtime-cu12==12.6.77 -# nvidia-cudnn-cu12==9.5.1.17 -# nvidia-cufft-cu12==11.3.0.4 -# nvidia-cufile-cu12==1.11.1.6 -# nvidia-curand-cu12==10.3.7.77 -# nvidia-cusolver-cu12==11.7.1.2 -# nvidia-cusparse-cu12==12.5.4.2 -# nvidia-cusparselt-cu12==0.6.3 -# nvidia-nccl-cu12==2.26.2 -# nvidia-nvjitlink-cu12==12.6.85 -# nvidia-nvtx-cu12==12.6.77 +more-itertools==10.8.0 +numpy==2.3.4 ollama==0.4.9 -onnxruntime==1.22.1 -openai==1.97.0 +openai==1.109.1 openapi-pydantic==0.5.1 -openpyxl==3.1.5 -orjson==3.11.0 +orjson==3.11.4 packaging==25.0 -pandas==2.3.1 -pdfminer-six==20250506 +pandas==2.3.3 +pathable==0.4.4 +pathvalidate==3.3.1 pika==1.3.2 -pillow==11.3.0 -portalocker==2.10.1 -protobuf==6.31.1 -pycparser==2.22 -pydantic-core==2.33.2 -pydantic-extra-types==2.10.5 -pydantic-settings==2.10.1 -pydantic==2.11.7 -pygments==2.19.2 -pymysql==1.1.1 -pyperclip==1.9.0 -# Windows-specific packages excluded: -# pyreadline3==3.5.4 # Windows only -# pywin32==311 # Windows only +platformdirs==4.5.0 +pluggy==1.6.0 +portalocker==3.2.0 +prometheus_client==0.23.1 +protobuf==6.33.1 +psycopg2-binary==2.9.9 +py-key-value-aio==0.2.8 +py-key-value-shared==0.2.8 +pycparser==2.23 +pydantic==2.12.4 +pydantic-extra-types==2.10.6 +pydantic-settings==2.12.0 +pydantic_core==2.41.5 +Pygments==2.19.2 +PyJWT==2.10.1 +pymilvus==2.6.5 
+PyMySQL==1.1.2 +pyperclip==1.11.0 +pytest==9.0.2 python-dateutil==2.9.0.post0 -python-dotenv==1.1.1 +python-dotenv==1.2.1 python-multipart==0.0.20 -python-pptx==1.0.2 pytz==2025.2 -pyyaml==6.0.2 -qdrant-client==1.14.3 -redis==6.2.0 +PyYAML==6.0.3 +qdrant-client +redis==6.4.0 referencing==0.36.2 -regex==2024.11.6 -requests==2.32.4 -rich-rst==1.3.1 -rich-toolkit==0.14.8 -rich==14.0.0 -rignore==0.6.2 -rpds-py==0.26.0 -safetensors==0.5.3 -schedule==1.2.2 -scikit-learn==1.7.0 -scipy==1.16.0 -sentence-transformers==4.1.0 -sentry-sdk==2.33.0 +regex==2025.11.3 +requests==2.32.5 +rich==14.2.0 +rich-rst==1.3.2 +rich-toolkit==0.15.1 +rignore==0.7.6 +rpds-py==0.28.0 +safetensors==0.6.2 +scikit-learn==1.7.2 +scipy==1.16.3 +sentry-sdk==2.44.0 setuptools==80.9.0 shellingham==1.5.4 six==1.17.0 sniffio==1.3.1 -soupsieve==2.7 -sqlalchemy==2.0.41 -sse-starlette==2.4.1 +SQLAlchemy==2.0.44 +sse-starlette==3.0.3 starlette==0.46.2 -sympy==1.14.0 tenacity==9.1.2 threadpoolctl==3.6.0 -tokenizers==0.21.2 -# Torch excluded for lighter Docker images (very large package ~2GB) -# If needed for ML/AI features, uncomment: -# torch==2.7.1 -# triton==3.3.1 +tokenizers==0.22.1 tqdm==4.67.1 -transformers==4.53.2 -typer==0.16.0 -typing-extensions==4.14.1 -typing-inspection==0.4.1 +transformers==4.57.1 +typer==0.20.0 +typing-inspection==0.4.2 +typing_extensions==4.15.0 tzdata==2025.2 -ujson==5.10.0 +ujson==5.11.0 urllib3==2.5.0 -uvicorn==0.35.0 -uvloop==0.21.0 -volcengine-python-sdk==4.0.6 -watchfiles==1.1.0 -websockets==15.0.1 -xlrd==2.0.2 -xlsxwriter==3.2.5 -prometheus-client==0.23.1 -pymilvus==2.5.12 -nltk==3.9.1 -rake-nltk==1.0.6 +uvicorn==0.38.0 +uvloop==0.22.1 +watchfiles==1.1.1 +websockets==15.0.1 \ No newline at end of file From b27beadd82a88d2cf794ba6d2024a078d5822d08 Mon Sep 17 00:00:00 2001 From: Travis Tang Date: Wed, 24 Dec 2025 21:51:25 +0800 Subject: [PATCH 02/48] update scheduler and add operation for dehallucination (#769) * fix bugs: try to fix bugs in _submit_web_logs * fix bugs: try to address bugs * fix bugs * refactor: modify examples * revise add operation and fix an unbelievable bug * address the bug issues * the doc file has a format problem which has been fixed in this commit * add a range of new feats for the add operation * address the incompatible issue of local scheduler * feat(scheduler): optimize redis queue consumer group management - Proactively ensure consumer groups exist in '_refresh_stream_keys' for newly discovered streams. - Remove redundant consumer group checks in '_read_new_messages_batch' to improve read performance. - Clean up 'seen_streams' cache when streams are deleted to ensure correct group recreation. - This change reduces unnecessary Redis calls during high-frequency polling. * fix(tests): resolve AttributeError in SimpleStructMemReader tests - Import 'parse_json_result' from 'memos.mem_reader.utils' instead of accessing it as an instance attribute. - Fixes 'AttributeError: 'SimpleStructMemReader' object has no attribute 'parse_json_result'' in 'test_parse_json_result_success' and 'test_parse_json_result_failure'. - Remove incorrect mock assignment of 'parse_json_result' in 'test_process_chat_data'. * fix(mem_reader): pass info dict to add_before_search for correct user_id usage - Update 'add_before_search' signature in 'SimpleStructMemReader' to accept 'info' dict. - Pass 'info' (containing 'user_id' and 'session_id') to 'self.searcher.search' instead of using empty strings. 
- Add 'test_add_before_search' to 'TestSimpleStructMemReader' to verify the fix and ensure 'searcher.search' receives the correct 'info'. - This ensures that memory searches are scoped to the correct user and session. * refactor add_before_search from mem_reader to SingleCubeView * address bugs --- docs/README.md | 2 +- ..._rerun.py => scheduler_for_async_tasks.py} | 6 +- src/memos/api/config.py | 21 +- src/memos/llms/openai.py | 6 +- src/memos/mem_reader/simple_struct.py | 257 ++++++++---------- src/memos/mem_reader/utils.py | 157 +++++++++++ src/memos/mem_scheduler/base_scheduler.py | 20 +- .../mem_scheduler/schemas/general_schemas.py | 4 +- .../task_schedule_modules/local_queue.py | 74 ++++- .../task_schedule_modules/redis_queue.py | 97 ++++--- .../task_schedule_modules/task_queue.py | 23 +- .../mem_scheduler/utils/status_tracker.py | 26 +- .../textual/prefer_text_memory/extractor.py | 4 + src/memos/multi_mem_cube/single_cube.py | 101 +++++++ src/memos/templates/mem_reader_prompts.py | 156 ++++++++++- tests/mem_reader/test_simple_structure.py | 7 +- 16 files changed, 713 insertions(+), 248 deletions(-) rename examples/mem_scheduler/{task_stop_rerun.py => scheduler_for_async_tasks.py} (98%) create mode 100644 src/memos/mem_reader/utils.py diff --git a/docs/README.md b/docs/README.md index bf5fea70d..8be17ffb7 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,3 +1,3 @@ All documentation has been moved to a separate repository: https://github.com/MemTensor/MemOS-Docs. Please edit documentation there. -所有文档已迁移至独立仓库:https://github.com/MemTensor/MemOS-Docs。请在该仓库中编辑文档。 +所有文档已迁移至独立仓库 https://github.com/MemTensor/MemOS-Docs 。请在该仓库中编辑文档。 diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/scheduler_for_async_tasks.py similarity index 98% rename from examples/mem_scheduler/task_stop_rerun.py rename to examples/mem_scheduler/scheduler_for_async_tasks.py index b5e62ff8f..a767b57c4 100644 --- a/examples/mem_scheduler/task_stop_rerun.py +++ b/examples/mem_scheduler/scheduler_for_async_tasks.py @@ -25,7 +25,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]): task_id = str(msg.item_id) file_path = tmp_dir / f"{task_id}.txt" try: - sleep(1) + sleep(5) file_path.write_text(f"Task {task_id} processed.\n") print(f"writing {file_path} done") except Exception as e: @@ -58,7 +58,7 @@ def submit_tasks(): mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) # 10s to restart -mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 10_000 +mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 5_000 tmp_dir = Path("./tmp") tmp_dir.mkdir(exist_ok=True) @@ -88,6 +88,6 @@ def submit_tasks(): print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))})") # 7. 
Stop the scheduler +sleep(20) print("Stopping the scheduler...") -sleep(5) mem_scheduler.stop() diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 48a16a6e2..7298658ff 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -7,16 +7,19 @@ import re import time -from typing import Any +from typing import TYPE_CHECKING, Any import requests from dotenv import load_dotenv -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig from memos.context.context import ContextThread -from memos.mem_cube.general import GeneralMemCube + + +if TYPE_CHECKING: + from memos.configs.mem_cube import GeneralMemCubeConfig + from memos.configs.mem_os import MOSConfig + from memos.mem_cube.general import GeneralMemCube # Load environment variables @@ -805,8 +808,12 @@ def get_start_default_config() -> dict[str, Any]: return config @staticmethod - def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]: + def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "GeneralMemCube"]: """Create configuration for a specific user.""" + from memos.configs.mem_cube import GeneralMemCubeConfig + from memos.configs.mem_os import MOSConfig + from memos.mem_cube.general import GeneralMemCube + openai_config = APIConfig.get_openai_config() qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() @@ -933,12 +940,14 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General return default_config, default_mem_cube @staticmethod - def get_default_cube_config() -> GeneralMemCubeConfig | None: + def get_default_cube_config() -> "GeneralMemCubeConfig | None": """Get default cube configuration for product initialization. Returns: GeneralMemCubeConfig | None: Default cube configuration if enabled, None otherwise. 
""" + from memos.configs.mem_cube import GeneralMemCubeConfig + if not APIConfig.is_default_cube_config_enabled(): return None diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 563b8723e..ea488329d 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -59,8 +59,8 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) if reasoning_content: - return reasoning_content + response_content - return response_content + return reasoning_content + (response_content or "") + return response_content or "" @timed_with_status( log_prefix="OpenAI LLM Stream", @@ -151,7 +151,7 @@ def generate(self, messages: MessageList, **kwargs) -> str: if self.config.remove_think_prefix: return remove_thinking_tags(response_content) else: - return response_content + return response_content or "" def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from Azure OpenAI LLM with optional reasoning support.""" diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index b870bf70a..70472958e 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -2,7 +2,6 @@ import copy import json import os -import re import traceback from abc import ABC @@ -18,6 +17,13 @@ from memos.llms.factory import LLMFactory from memos.mem_reader.base import BaseMemReader from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang +from memos.mem_reader.utils import ( + count_tokens_text, + derive_key, + parse_json_result, + parse_keep_filter_response, + parse_rewritten_response, +) from memos.memories.textual.item import ( SourceMessage, TextualMemoryItem, @@ -89,27 +95,6 @@ def from_config(_config): } -try: - import tiktoken - - try: - _ENC = tiktoken.encoding_for_model("gpt-4o-mini") - except Exception: - _ENC = tiktoken.get_encoding("cl100k_base") - - def _count_tokens_text(s: str) -> int: - return len(_ENC.encode(s or "", disallowed_special=())) -except Exception: - # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars - def _count_tokens_text(s: str) -> int: - if not s: - return 0 - zh_chars = re.findall(r"[\u4e00-\u9fff]", s) - zh = len(zh_chars) - rest = len(s) - zh - return zh + max(1, rest // 4) - - def _build_node(idx, message, info, source_info, llm, parse_json_result, embedder): # generate try: @@ -172,14 +157,6 @@ def _build_node(idx, message, info, source_info, llm, parse_json_result, embedde return None -def _derive_key(text: str, max_len: int = 80) -> str: - """default key when without LLM: first max_len words""" - if not text: - return "" - sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] - return (sent[:max_len]).strip() - - class SimpleStructMemReader(BaseMemReader, ABC): """Naive implementation of MemReader.""" @@ -197,7 +174,8 @@ def __init__(self, config: SimpleStructMemReaderConfig): self.memory_max_length = 8000 # Use token-based windowing; default to ~5000 tokens if not configured self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024) - self._count_tokens = _count_tokens_text + self._count_tokens = count_tokens_text + self.searcher = None def _make_memory_item( self, @@ -224,7 +202,7 @@ def _make_memory_item( memory_type=memory_type, status="activated", tags=tags or [], - key=key if key is not None else _derive_key(value), + key=key if key is not None else derive_key(value), 
embedding=self.embedder.embed([value])[0], usage=[], sources=sources or [], @@ -254,7 +232,7 @@ def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) - response_json = self.parse_json_result(response_text) + response_json = parse_json_result(response_text) except Exception as e: logger.error(f"[LLM] Exception during chat generation: {e}") response_json = { @@ -456,47 +434,73 @@ def get_memory( standard_scene_data = coerce_scene_data(scene_data, type) return self._read_memory(standard_scene_data, type, info, mode) - @staticmethod - def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]: - """Parse index-keyed JSON from hallucination filter response. - Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } - Returns (success, parsed_dict) with int keys. - """ + def rewrite_memories( + self, messages: list[dict], memory_list: list[TextualMemoryItem], user_only: bool = True + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) + if user_only: + template = PROMPT_MAPPING["rewrite_user_only"] + filtered_messages = [m for m in messages if m.get("role") != "assistant"] + if len(filtered_messages) < 1: + return memory_list + else: + template = PROMPT_MAPPING["rewrite"] + filtered_messages = messages + if len(filtered_messages) < 2: + return memory_list + + prompt_args = { + "messages_inline": "\n".join( + [f"- [{message['role']}]: {message['content']}" for message in filtered_messages] + ), + "memories_inline": json.dumps( + {idx: mem.memory for idx, mem in enumerate(memory_list)}, + ensure_ascii=False, + indent=2, + ), + } + prompt = template.format(**prompt_args) + + # Optionally run filter and parse the output try: - data = json.loads(text) - except Exception: - return False, {} + raw = self.llm.generate([{"role": "user", "content": prompt}]) + success, parsed = parse_rewritten_response(raw) + logger.info( + f"[rewrite_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" + ) + if success: + logger.info(f"Rewrite filter result: {parsed}") - if not isinstance(data, dict): - return False, {} + new_memory_list = [] + for mem_idx, content in parsed.items(): + if mem_idx < 0 or mem_idx >= len(memory_list): + logger.warning( + f"[rewrite_memories] Invalid memory index {mem_idx} for memory_list {len(memory_list)}, skipping." 
+ ) + continue - result: dict[int, dict] = {} - for k, v in data.items(): - try: - idx = int(k) - except Exception: - # allow integer keys as-is - if isinstance(k, int): - idx = k - else: - continue - if not isinstance(v, dict): - continue - need_rewrite = v.get("need_rewrite") - rewritten = v.get("rewritten", "") - reason = v.get("reason", "") - if ( - isinstance(need_rewrite, bool) - and isinstance(rewritten, str) - and isinstance(reason, str) - ): - result[idx] = { - "need_rewrite": need_rewrite, - "rewritten": rewritten, - "reason": reason, - } + need_rewrite = content.get("need_rewrite", False) + rewritten_text = content.get("rewritten", "") + reason = content.get("reason", "") + original_text = memory_list[mem_idx].memory - return (len(result) > 0), result + # Replace memory text with rewritten content when rewrite is needed + if need_rewrite and isinstance(rewritten_text, str): + logger.info( + f"[rewrite_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + ) + if len(rewritten_text.strip()) != 0: + memory_list[mem_idx].memory = rewritten_text + new_memory_list.append(memory_list[mem_idx]) + else: + new_memory_list.append(memory_list[mem_idx]) + return new_memory_list + else: + logger.warning("Rewrite filter parsing failed or returned empty result.") + except Exception as e: + logger.error(f"Rewrite filter execution error: {e}", stack_info=True) + + return memory_list def filter_hallucination_in_memories( self, messages: list[dict], memory_list: list[TextualMemoryItem] @@ -520,32 +524,32 @@ def filter_hallucination_in_memories( # Optionally run filter and parse the output try: raw = self.llm.generate([{"role": "user", "content": prompt}]) - success, parsed = self._parse_hallucination_filter_response(raw) + success, parsed = parse_keep_filter_response(raw) logger.info( f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}" ) if success: logger.info(f"Hallucination filter result: {parsed}") - assert len(parsed) == len(memory_list) - for mem_idx, content in parsed.items(): - need_rewrite = content.get("need_rewrite", False) - rewritten_text = content.get("rewritten", "") - reason = content.get("reason", "") - # Replace memory text with rewritten content when rewrite is needed - if ( - need_rewrite - and isinstance(rewritten_text, str) - and len(rewritten_text.strip()) > 0 - ): - original_text = memory_list[mem_idx].memory + filtered_list = [] + for mem_idx, mem in enumerate(memory_list): + content = parsed.get(mem_idx) + if not content: + logger.warning(f"No verdict for memory {mem_idx}, keeping it.") + filtered_list.append(mem) + continue + + keep = content.get("keep", True) + reason = content.get("reason", "") + if keep: + filtered_list.append(mem) + else: logger.info( - f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'" + f"[filter_hallucination_in_memories] Dropping memory index={mem_idx}, reason='{reason}', memory='{mem.memory}'" ) - memory_list[mem_idx].memory = rewritten_text - return memory_list + return filtered_list else: logger.warning("Hallucination filter parsing failed or returned empty result.") except Exception as e: @@ -606,29 +610,27 @@ def _read_memory( for group_id in range(len(memory_list)): try: - revised_memory_list = self.filter_hallucination_in_memories( + 
original_memory_group = copy.deepcopy(memory_list[group_id]) + serialized_origin_memories = json.dumps( + [one.memory for one in original_memory_group], indent=2 + ) + revised_memory_list = self.rewrite_memories( messages=combined_messages, - memory_list=memory_list[group_id], + memory_list=original_memory_group, + user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower() + == "true", ) - if len(revised_memory_list) != len(memory_list[group_id]): - original_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in memory_list[group_id] - ] - filtered_serialized = [ - one.memory if hasattr(one, "memory") else str(one) - for one in revised_memory_list - ] - logger.error( - f"Length mismatch after hallucination filtering for group_id={group_id}: " - f"original={len(memory_list[group_id])}, filtered={len(revised_memory_list)}" - f"\noriginal_memory_list(serialized): {original_serialized}" - f"\nfiltered_memory_list(serialized): {filtered_serialized}" - f"\nmessages: {combined_messages}" - f"\nSkipping update and keeping original memory." + serialized_revised_memories = json.dumps( + [one.memory for one in revised_memory_list], indent=2 + ) + if serialized_origin_memories != serialized_revised_memories: + memory_list[group_id] = revised_memory_list + logger.info( + f"[SIMPLE_STRUCT_ADD_FILTER] Modified the list for group_id={group_id}: " + f"\noriginal={serialized_origin_memories}," + f"\nrevised={serialized_revised_memories}" ) - continue - memory_list[group_id] = revised_memory_list + except Exception as e: group_serialized = [ one.memory if hasattr(one, "memory") else str(one) @@ -847,7 +849,7 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): info, source_info_list, self.llm, - self.parse_json_result, + parse_json_result, self.embedder, ): idx for idx, msg in enumerate(messages) @@ -870,44 +872,3 @@ def _process_transfer_doc_data( self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None ): raise NotImplementedError - - def parse_json_result(self, response_text: str) -> dict: - s = (response_text or "").strip() - - m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) - s = (m.group(1) if m else s.replace("```", "")).strip() - - i = s.find("{") - if i == -1: - return {} - s = s[i:].strip() - - try: - return json.loads(s) - except json.JSONDecodeError: - pass - - j = max(s.rfind("}"), s.rfind("]")) - if j != -1: - try: - return json.loads(s[: j + 1]) - except json.JSONDecodeError: - pass - - def _cheap_close(t: str) -> str: - t += "}" * max(0, t.count("{") - t.count("}")) - t += "]" * max(0, t.count("[") - t.count("]")) - return t - - t = _cheap_close(s) - try: - return json.loads(t) - except json.JSONDecodeError as e: - if "Invalid \\escape" in str(e): - s = s.replace("\\", "\\\\") - return json.loads(s) - logger.error( - f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ - json: {s}" - ) - return {} diff --git a/src/memos/mem_reader/utils.py b/src/memos/mem_reader/utils.py new file mode 100644 index 000000000..4e5a78af2 --- /dev/null +++ b/src/memos/mem_reader/utils.py @@ -0,0 +1,157 @@ +import json +import re + +from memos import log + + +logger = log.get_logger(__name__) + +try: + import tiktoken + + try: + _ENC = tiktoken.encoding_for_model("gpt-4o-mini") + except Exception: + _ENC = tiktoken.get_encoding("cl100k_base") + + def count_tokens_text(s: str) -> int: + return len(_ENC.encode(s or "", disallowed_special=())) +except Exception: + # Heuristic fallback: zh chars ~1 token, others ~1 token 
per ~4 chars + def count_tokens_text(s: str) -> int: + if not s: + return 0 + zh_chars = re.findall(r"[\u4e00-\u9fff]", s) + zh = len(zh_chars) + rest = len(s) - zh + return zh + max(1, rest // 4) + + +def derive_key(text: str, max_len: int = 80) -> str: + """default key when without LLM: first max_len words""" + if not text: + return "" + sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0] + return (sent[:max_len]).strip() + + +def parse_json_result(response_text: str) -> dict: + s = (response_text or "").strip() + + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I) + s = (m.group(1) if m else s.replace("```", "")).strip() + + i = s.find("{") + if i == -1: + return {} + s = s[i:].strip() + + try: + return json.loads(s) + except json.JSONDecodeError: + pass + + j = max(s.rfind("}"), s.rfind("]")) + if j != -1: + try: + return json.loads(s[: j + 1]) + except json.JSONDecodeError: + pass + + def _cheap_close(t: str) -> str: + t += "}" * max(0, t.count("{") - t.count("}")) + t += "]" * max(0, t.count("[") - t.count("]")) + return t + + t = _cheap_close(s) + try: + return json.loads(t) + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + s = s.replace("\\", "\\\\") + return json.loads(s) + logger.error( + f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \ + json: {s}" + ) + return {} + + +def parse_rewritten_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from hallucination filter response. + Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. + """ + try: + m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I) + s = (m.group(1) if m else text).strip() + data = json.loads(s) + except Exception: + return False, {} + + if not isinstance(data, dict): + return False, {} + + result: dict[int, dict] = {} + for k, v in data.items(): + try: + idx = int(k) + except Exception: + # allow integer keys as-is + if isinstance(k, int): + idx = k + else: + continue + if not isinstance(v, dict): + continue + need_rewrite = v.get("need_rewrite") + rewritten = v.get("rewritten", "") + reason = v.get("reason", "") + if ( + isinstance(need_rewrite, bool) + and isinstance(rewritten, str) + and isinstance(reason, str) + ): + result[idx] = { + "need_rewrite": need_rewrite, + "rewritten": rewritten, + "reason": reason, + } + + return (len(result) > 0), result + + +def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]: + """Parse index-keyed JSON from keep filter response. + Expected shape: { "0": {"keep": bool, "reason": str}, ... } + Returns (success, parsed_dict) with int keys. 
+    """
+    try:
+        m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I)
+        s = (m.group(1) if m else text).strip()
+        data = json.loads(s)
+    except Exception:
+        return False, {}
+
+    if not isinstance(data, dict):
+        return False, {}
+
+    result: dict[int, dict] = {}
+    for k, v in data.items():
+        try:
+            idx = int(k)
+        except Exception:
+            if isinstance(k, int):
+                idx = k
+            else:
+                continue
+        if not isinstance(v, dict):
+            continue
+        keep = v.get("keep")
+        reason = v.get("reason", "")
+        if isinstance(keep, bool):
+            result[idx] = {
+                "keep": keep,
+                "reason": reason,
+            }
+    return (len(result) > 0), result
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 1e0ecaadb..728203f5b 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -1009,14 +1009,24 @@ def _monitor_loop(self):
                 q_sizes = self.memos_message_queue.qsize()

                 for stream_key, queue_length in q_sizes.items():
-                    # Expected format: "memos:stream:{user_id}:{mem_cube_id}" or "{user_id}"
+                    # Skip aggregate keys like 'total_size'
+                    if stream_key == "total_size":
+                        continue
+
+                    # Key format: ...:{user_id}:{mem_cube_id}:{task_label}
+                    # We want to extract user_id, which is the 3rd component from the end.
                     parts = stream_key.split(":")
                     if len(parts) >= 3:
-                        user_id = parts[2]
-                        self.metrics.update_queue_length(queue_length, user_id)
-                    elif not self.use_redis_queue:  # local queue
-                        user_id = stream_key
+                        user_id = parts[-3]
                         self.metrics.update_queue_length(queue_length, user_id)
+                    else:
+                        # Fallback for unexpected key formats (e.g. legacy or test keys).
+                        # If the key contains no colons, treat it as a bare user_id and
+                        # record the queue length under that id; keys in any other
+                        # unrecognized format are ignored, since current queue
+                        # implementations should not produce them.
+                        if ":" not in stream_key:
+                            self.metrics.update_queue_length(queue_length, stream_key)

             except Exception as e:
                 logger.error(f"Error in metrics monitor loop: {e}", exc_info=True)
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index f4ad9fe48..06910ba17 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -1,3 +1,5 @@
+import os
+
 from pathlib import Path
@@ -21,7 +23,7 @@
 DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1
 DEFAULT_TOP_K = 5
 DEFAULT_CONTEXT_WINDOW_SIZE = 5
-DEFAULT_USE_REDIS_QUEUE = True
+DEFAULT_USE_REDIS_QUEUE = os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true"
 DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30
 DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 20
 DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1
diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
index 69cfc0af9..eae70f8ef 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
@@ -62,7 +62,7 @@ def put(
             Exception: Any underlying error during queue.put() operation.
""" stream_key = self.get_stream_key( - user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.task_label + user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label ) message.stream_key = stream_key @@ -108,35 +108,94 @@ def get( ) return res - def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + def get_nowait( + self, stream_key: str, batch_size: int | None = None + ) -> list[ScheduleMessageItem]: """ - Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size). + Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). Returns immediately with available messages or an empty list if queue is empty. Args: + stream_key (str): The stream/queue identifier. batch_size (int | None): Number of messages to retrieve in a batch. If None, retrieves one message. Returns: List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty. """ - logger.debug(f"get_nowait() called with batch_size: {batch_size}") - return self.get(block=False, batch_size=batch_size) + logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}") + return self.get(stream_key=stream_key, block=False, batch_size=batch_size) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + """ + Get messages from all streams in round-robin or sequential fashion. + Equivalent to SchedulerRedisQueue.get_messages. + """ + messages = [] + # Snapshot keys to avoid runtime modification issues + stream_keys = list(self.queue_streams.keys()) + + # Simple strategy: try to get up to batch_size messages across all streams + # We can just iterate and collect. + + # Calculate how many to get per stream to be fair? + # Or just greedy? Redis implementation uses a complex logic. + # For local, let's keep it simple: just iterate and take what's available (non-blocking) + + for stream_key in stream_keys: + if len(messages) >= batch_size: + break + + needed = batch_size - len(messages) + # Use get_nowait to avoid blocking + fetched = self.get_nowait(stream_key=stream_key, batch_size=needed) + messages.extend(fetched) + + return messages def qsize(self) -> dict: """ Return the current size of all internal queues as a dictionary. Each key is the stream name, and each value is the number of messages in that queue. + Also includes 'total_size'. Returns: Dict[str, int]: Mapping from stream name to current queue size. """ sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()} + total_size = sum(sizes.values()) + sizes["total_size"] = total_size logger.debug(f"Current queue sizes: {sizes}") return sizes + def size(self) -> int: + """ + Get the current size of the queue (total message count). + Compatible with SchedulerRedisQueue. + """ + return self.unfinished_tasks + + def empty(self) -> bool: + """ + Check if the queue is empty. + Compatible with SchedulerRedisQueue. + """ + return self.size() == 0 + + def full(self) -> bool: + """ + Check if the queue is full. + Compatible with SchedulerRedisQueue. + """ + # Local queue limits are per-stream (max_internal_message_queue_size). + # It is considered full only if all streams are full. 
+        if not self.queue_streams:
+            return False
+
+        return all(queue.full() for queue in self.queue_streams.values())
+
     def clear(self) -> None:
         for queue in self.queue_streams.values():
             queue.clear()
@@ -151,6 +210,9 @@ def unfinished_tasks(self) -> int:
         Returns:
             int: Sum of all message counts in all internal queues.
         """
-        total = sum(self.qsize().values())
+        # qsize() now includes an aggregate "total_size" entry, so summing its
+        # values would double count. Sum the per-stream queues directly from
+        # queue_streams instead.
+        total = sum(queue.qsize() for queue in self.queue_streams.values())
         logger.debug(f"Total unfinished tasks across all queues: {total}")
         return total
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index 1c57f18f0..2f4318003 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -5,7 +5,6 @@
 the local memos_message_queue functionality in BaseScheduler.
 """

-import contextlib
 import os
 import re
 import threading
@@ -201,6 +200,20 @@ def _refresh_stream_keys(
             recent_seconds=DEFAULT_STREAM_RECENT_ACTIVE_SECONDS,
             now_sec=now_sec,
         )
+
+        # Ensure consumer groups for newly discovered active streams
+        with self._stream_keys_lock:
+            # Identify keys we haven't seen yet
+            new_streams = [k for k in active_stream_keys if k not in self.seen_streams]
+
+        # Create groups outside the lock to avoid blocking
+        for key in new_streams:
+            self._ensure_consumer_group(key)
+
+        if new_streams:
+            with self._stream_keys_lock:
+                self.seen_streams.update(new_streams)
+
         deleted_count = self._delete_streams(keys_to_delete)
         self._update_stream_cache_with_log(
             stream_key_prefix=stream_key_prefix,
@@ -560,10 +573,7 @@ def _read_new_messages_batch(
             return {}

         # Pre-ensure consumer groups to avoid NOGROUP during batch reads
-        for stream_key in stream_keys:
-            with contextlib.suppress(Exception):
-                self._ensure_consumer_group(stream_key=stream_key)
-
+        # (Optimization: rely on put() and _refresh_stream_keys() to ensure groups)
         pipe = self._redis_conn.pipeline(transaction=False)
         for stream_key in stream_keys:
             pipe.xreadgroup(
@@ -679,11 +689,6 @@ def _batch_claim_pending_messages(
         if not self._redis_conn or not claims_spec:
             return []

-        # Ensure consumer groups exist to avoid NOGROUP errors during batch claim
-        for stream_key, _need_count, _label in claims_spec:
-            with contextlib.suppress(Exception):
-                self._ensure_consumer_group(stream_key=stream_key)
-
         pipe = self._redis_conn.pipeline(transaction=False)
         for stream_key, need_count, label in claims_spec:
             pipe.xautoclaim(
@@ -696,26 +701,42 @@ def _batch_claim_pending_messages(
                 justid=False,
             )

-        results = []
         try:
-            results = pipe.execute()
-        except Exception:
-            # Fallback: attempt sequential xautoclaim for robustness
-            for stream_key, need_count, label in claims_spec:
-                try:
-                    self._ensure_consumer_group(stream_key=stream_key)
-                    res = self._redis_conn.xautoclaim(
-                        name=stream_key,
-                        groupname=self.consumer_group,
-                        consumername=self.consumer_name,
-                        min_idle_time=self.orchestrator.get_task_idle_min(task_label=label),
-                        start_id="0-0",
-                        count=need_count,
-                        justid=False,
-                    )
-                    results.append(res)
-                except Exception:
-                    continue
+            # Execute with raise_on_error=False so we get exceptions in the results list
+            # instead of aborting the whole batch.
+ results = pipe.execute(raise_on_error=False) + except Exception as e: + logger.error(f"Pipeline execution critical failure: {e}") + results = [e] * len(claims_spec) + + # Handle individual failures (e.g. NOGROUP) by retrying just that stream + final_results = [] + for i, res in enumerate(results): + if isinstance(res, Exception): + err_msg = str(res).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + stream_key, need_count, label = claims_spec[i] + try: + self._ensure_consumer_group(stream_key=stream_key) + retry_res = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=self.orchestrator.get_task_idle_min(task_label=label), + start_id="0-0", + count=need_count, + justid=False, + ) + final_results.append(retry_res) + except Exception as retry_err: + logger.warning(f"Retry xautoclaim failed for {stream_key}: {retry_err}") + final_results.append(None) + else: + final_results.append(None) + else: + final_results.append(res) + + results = final_results claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] for (stream_key, _need_count, _label), claimed_result in zip( @@ -1159,10 +1180,14 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: del_pipe.delete(key) del_pipe.execute() deleted_count = len(keys_to_delete) - # Clean up empty-tracking state for deleted keys + # Clean up empty-tracking state and seen_streams for deleted keys with self._empty_stream_seen_lock: for key in keys_to_delete: self._empty_stream_seen_times.pop(key, None) + + with self._stream_keys_lock: + for key in keys_to_delete: + self.seen_streams.discard(key) except Exception: for key in keys_to_delete: try: @@ -1170,6 +1195,8 @@ def _delete_streams(self, keys_to_delete: list[str]) -> int: deleted_count += 1 with self._empty_stream_seen_lock: self._empty_stream_seen_times.pop(key, None) + with self._stream_keys_lock: + self.seen_streams.discard(key) except Exception: pass return deleted_count @@ -1189,9 +1216,7 @@ def _update_stream_cache_with_log( self._stream_keys_cache = active_stream_keys self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) - logger.info( - f"[REDIS_QUEUE] Stream keys refresh: prefix='{stream_key_prefix}', " - f"total={len(candidate_keys)}, active={len(active_stream_keys)}, cached={cache_count}, " - f"active_threshold_sec={int(active_threshold_sec)}, deleted={deleted_count}, " - f"inactive_threshold_sec={int(DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS)}" - ) + logger.info( + f"Refreshed stream keys cache: {cache_count} active keys, " + f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." 
+ ) diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index c20243242..b49db2b36 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -153,28 +153,7 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt ) def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: - if isinstance(self.memos_message_queue, SchedulerRedisQueue): - return self.memos_message_queue.get_messages(batch_size=batch_size) - stream_keys = self.get_stream_keys() - - if len(stream_keys) == 0: - return [] - - messages: list[ScheduleMessageItem] = [] - - for stream_key in stream_keys: - fetched = self.memos_message_queue.get( - stream_key=stream_key, - block=False, - batch_size=batch_size, - ) - - messages.extend(fetched) - if len(messages) > 0: - logger.debug( - f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}" - ) - return messages + return self.memos_message_queue.get_messages(batch_size=batch_size) def clear(self): self.memos_message_queue.clear() diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py index d8c8d2cee..2a995b239 100644 --- a/src/memos/mem_scheduler/utils/status_tracker.py +++ b/src/memos/mem_scheduler/utils/status_tracker.py @@ -13,7 +13,7 @@ class TaskStatusTracker: @require_python_package(import_name="redis", install_command="pip install redis") - def __init__(self, redis_client: "redis.Redis"): + def __init__(self, redis_client: "redis.Redis | None"): self.redis = redis_client def _get_key(self, user_id: str) -> str: @@ -41,6 +41,9 @@ def task_submitted( mem_cube_id: Memory cube identifier business_task_id: Optional business-level task ID (one task_id can have multiple item_ids) """ + if not self.redis: + return + key = self._get_key(user_id) payload = { "status": "waiting", @@ -61,6 +64,9 @@ def task_submitted( self.redis.expire(key, timedelta(days=7)) def task_started(self, task_id: str, user_id: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -77,6 +83,9 @@ def task_started(self, task_id: str, user_id: str): self.redis.expire(key, timedelta(days=7)) def task_completed(self, task_id: str, user_id: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -91,6 +100,9 @@ def task_completed(self, task_id: str, user_id: str): self.redis.expire(key, timedelta(days=7)) def task_failed(self, task_id: str, user_id: str, error_message: str): + if not self.redis: + return + key = self._get_key(user_id) existing_data_json = self.redis.hget(key, task_id) if not existing_data_json: @@ -108,11 +120,17 @@ def task_failed(self, task_id: str, user_id: str, error_message: str): self.redis.expire(key, timedelta(days=7)) def get_task_status(self, task_id: str, user_id: str) -> dict | None: + if not self.redis: + return None + key = self._get_key(user_id) data = self.redis.hget(key, task_id) return json.loads(data) if data else None def get_all_tasks_for_user(self, user_id: str) -> dict[str, dict]: + if not self.redis: + return {} + key = self._get_key(user_id) all_tasks = self.redis.hgetall(key) return {tid: json.loads(t_data) for tid, t_data in all_tasks.items()} @@ -132,6 +150,9 @@ def get_task_status_by_business_id(self, 
business_task_id: str, user_id: str) -> - If any item is 'failed' → 'failed' Returns None if task_id not found. """ + if not self.redis: + return None + # Get all item_ids for this task_id task_items_key = self._get_task_items_key(user_id, business_task_id) item_ids = self.redis.smembers(task_items_key) @@ -180,6 +201,9 @@ def get_all_tasks_global(self) -> dict[str, dict[str, dict]]: Returns: dict: {user_id: {task_id: task_data, ...}, ...} """ + if not self.redis: + return {} + all_users_tasks = {} cursor: int | str = 0 while True: diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 144bfad7f..3404c6d4c 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -69,6 +69,8 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: @@ -92,6 +94,8 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + if not response: + return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) for d in result: diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 57f2cdba1..ab3d0ce03 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -15,6 +15,7 @@ ) from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger +from memos.mem_reader.utils import parse_keep_filter_response from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( ADD_TASK_LABEL, @@ -23,6 +24,7 @@ PREF_ADD_TASK_LABEL, ) from memos.multi_mem_cube.views import MemCubeView +from memos.templates.mem_reader_prompts import PROMPT_MAPPING from memos.types.general_types import ( FINE_STRATEGY, FineStrategy, @@ -41,6 +43,7 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler + from memos.memories.textual.item import TextualMemoryItem @dataclass @@ -631,6 +634,104 @@ def _process_pref_mem( for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) ] + def add_before_search( + self, + messages: list[dict], + memory_list: list[TextualMemoryItem], + user_name: str, + info: dict[str, Any], + ) -> list[TextualMemoryItem]: + # Build input objects with memory text and metadata (timestamps, sources, etc.) + template = PROMPT_MAPPING["add_before_search"] + + if not self.searcher: + self.logger.warning("[add_before_search] Searcher is not initialized, skipping check.") + return memory_list + + # 1. 
Gather candidates and search for related memories + candidates_data = [] + for idx, mem in enumerate(memory_list): + try: + related_memories = self.searcher.search( + query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info + ) + related_text = "None" + if related_memories: + related_text = "\n".join([f"- {r.memory}" for r in related_memories]) + + candidates_data.append( + {"idx": idx, "new_memory": mem.memory, "related_memories": related_text} + ) + except Exception as e: + self.logger.error( + f"[add_before_search] Search error for memory '{mem.memory}': {e}" + ) + # If search fails, we can either skip this check or treat related as empty + candidates_data.append( + { + "idx": idx, + "new_memory": mem.memory, + "related_memories": "None (Search Failed)", + } + ) + + if not candidates_data: + return memory_list + + # 2. Build Prompt + messages_inline = "\n".join( + [ + f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}" + for message in messages + ] + ) + + candidates_inline_dict = { + str(item["idx"]): { + "new_memory": item["new_memory"], + "related_memories": item["related_memories"], + } + for item in candidates_data + } + + candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2) + + prompt = template.format( + messages_inline=messages_inline, candidates_inline=candidates_inline + ) + + # 3. Call LLM + try: + raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}]) + success, parsed_result = parse_keep_filter_response(raw) + + if not success: + self.logger.warning( + "[add_before_search] Failed to parse LLM response, keeping all." + ) + return memory_list + + # 4. Filter + filtered_list = [] + for idx, mem in enumerate(memory_list): + res = parsed_result.get(idx) + if not res: + filtered_list.append(mem) + continue + + if res.get("keep", True): + filtered_list.append(mem) + else: + self.logger.info( + f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'" + ) + + return filtered_list + + except Exception as e: + self.logger.error(f"[add_before_search] LLM execution error: {e}") + return memory_list + def _process_text_mem( self, add_req: APIADDRequest, diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index fef3ee6c0..40971c77e 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -622,23 +622,56 @@ 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。""" -SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. 
**Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + +SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. +Note: The provided messages contain only user messages. The assistant's responses are intentionally omitted, not because the assistant didn't answer, but to focus strictly on validating memories against user input. + Rules: 1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. -2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, emotional labels, summaries, or generalizations. -3. **Ambiguity Elimination**: - - Replace vague pronouns (e.g., “he”, “it”, “they”) with clear, specific entities **only if** the messages identify them. - - Convert relative time expressions (e.g., “yesterday”) to absolute dates **only if** the messages provide enough temporal context. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. 4. **Hallucination Removal**: - - If a memory contains **any content not verbatim or directly implied by the user**, it must be rewritten. - - Do **not** rephrase inferences as facts. Instead, either: - - Remove the unsupported part and retain only the grounded core, or - - If the entire memory is ungrounded, mark it for rewrite and make the lack of user support explicit. +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. 
+- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. 5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. Inputs: messages: @@ -651,16 +684,115 @@ - Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. - Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} - The "reason" must be brief and precise, e.g.: - - "contains unsupported inference" - - "vague pronoun with no referent in messages" - - "relative time resolved to 2025-12-16" + - "contains unsupported inference ...." - "fully grounded and concise" Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +4. **Hallucination Removal**: +- If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten. +- **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific. +- Do **not** rephrase inferences as facts. Instead, either: +- Remove the unsupported part and retain only the grounded core. +5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged. +6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference ...." + - "fully grounded and concise" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. 
+""" + +SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ +You are a strict memory validator. +Your task is to identify and delete hallucinated memories that are not explicitly stated by the user in the provided messages. + +Rules: +1. **User-Only Origin**: Verify facts against USER messages ONLY. If the Assistant repeats a User fact, it is VALID. If the Assistant introduces a new detail (e.g., 'philanthropy') that the User did not explicitly confirm, it is INVALID. +2. **No Inference Allowed**: Do NOT keep memories based on implication, emotion, preference, or generalization. Only verbatim or direct restatements of user-provided facts are valid. However, minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. +3. **Hallucination = Deletion**: If a memory contains any detail not directly expressed by the user, mark it for deletion. +4. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + +Examples: +Messages: +- [user]: I love coding in Python. +- [assistant]: That's great! I assume you also contribute to open source projects? +Memory: User enjoys Python and contributes to open source. +Result: {{"keep": false, "reason": "User never stated they contribute to open source; this came from Assistant's assumption."}} + +Messages: +- [user]: I am tired. +- [assistant]: I hear you are tired. Rest is important. +Memory: User stated they are tired. +Result: {{"keep": true, "reason": "Direct restatement of user input, even if Assistant repeated it."}} + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching the input memory indices. +- Each value must be: {{ "keep": boolean, "reason": string }} +- "keep": true only if the memory is a direct reflection of the user's explicit words. +- "reason": brief, factual, and cites missing or unsupported content. + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + + +SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT = """ +You are a memory manager. +Your task is to decide if a new memory should be added to the long-term memory, given a list of existing related memories. + +Rules: +1. **Redundancy Check**: If the new memory is completely redundant, already known, or covered by the existing memories, discard it. +2. **New Information**: If the new memory provides new information, details, or updates compared to the existing memories, keep it. +3. **Contradiction**: If the new memory contradicts existing memories but seems valid/newer, keep it (updates). +4. **Context Check**: Use the provided conversation messages to verify if the new memory is grounded in the user's explicit statements. + +Inputs: +Messages: +{messages_inline} + +Candidate Memories (to be evaluated): +{candidates_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching the input candidate memory indices. +- Each value must be: {{ "keep": boolean, "reason": string }} +- "keep": true if the memory should be added. +- "reason": brief explanation. + +Important: Output **only** the JSON. No extra text. 
+""" # Prompt mapping for specialized tasks (e.g., hallucination filtering) PROMPT_MAPPING = { "hallucination_filter": SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT, + "rewrite": SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT, + "rewrite_user_only": SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT, + "add_before_search": SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT, } diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py index f81356886..fd07fbf41 100644 --- a/tests/mem_reader/test_simple_structure.py +++ b/tests/mem_reader/test_simple_structure.py @@ -1,4 +1,3 @@ -import json import unittest from unittest.mock import MagicMock, patch @@ -8,6 +7,7 @@ from memos.embedders.factory import EmbedderFactory from memos.llms.factory import LLMFactory from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem @@ -57,7 +57,6 @@ def test_process_chat_data(self): '"summary": "Tom is currently focused on managing a new project with a tight schedule."}' ) self.reader.llm.generate.return_value = mock_response - self.reader.parse_json_result = lambda x: json.loads(x) result = self.reader._process_chat_data(scene_data_info, info) @@ -105,7 +104,7 @@ def test_get_scene_data_info_with_chat(self): def test_parse_json_result_success(self): """Test successful JSON parsing.""" raw_response = '{"summary": "Test summary", "tags": ["test"]}' - result = self.reader.parse_json_result(raw_response) + result = parse_json_result(raw_response) self.assertIsInstance(result, dict) self.assertIn("summary", result) @@ -113,7 +112,7 @@ def test_parse_json_result_success(self): def test_parse_json_result_failure(self): """Test failure in JSON parsing.""" raw_response = "Invalid JSON string" - result = self.reader.parse_json_result(raw_response) + result = parse_json_result(raw_response) self.assertEqual(result, {}) From a1746fb5dd5e5284a249368cc663c44d03e2ab7b Mon Sep 17 00:00:00 2001 From: zZhangSir <103892644+zZhangSir@users.noreply.github.com> Date: Wed, 24 Dec 2025 21:52:04 +0800 Subject: [PATCH 03/48] fix: update README.md (#774) Co-authored-by: Zehao Lin --- README.md | 129 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index 634b38dec..2f5422095 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,6 @@ MemOS is an open-source **Agent Memory framework** that empowers AI agents with **long-term memory, personality consistency, and contextual recall**. It enables agents to **remember past interactions**, **learn over time**, and **build evolving identities** across sessions. Designed for **AI companions, role-playing NPCs, and multi-agent systems**, MemOS provides a unified API for **memory representation, retrieval, and update** — making it the foundation for next-generation **memory-augmented AI agents**. - -🆕 **MemOS 2.0** introduces **knowledge base system**, **multi-modal memory** (images & documents), **tool memory** for Agent optimization, **memory feedback mechanism** for precise control, and **enterprise-grade architecture** with Redis Streams scheduler and advanced DB optimizations.
MemOS Banner

@@ -117,19 +115,7 @@ showcasing its capabilities in **information extraction**, **temporal and cross-
 - **Textual Memory**: For storing and retrieving unstructured or structured text knowledge.
 - **Activation Memory**: Caches key-value pairs (`KVCacheMemory`) to accelerate LLM inference and context reuse.
 - **Parametric Memory**: Stores model adaptation parameters (e.g., LoRA weights).
-  - **Tool Memory** 🆕: Records Agent tool call trajectories and experiences to improve planning capabilities.
-- **📚 Knowledge Base System** 🆕: Build multi-dimensional knowledge bases with automatic document/URL parsing, splitting, and cross-project sharing capabilities.
-- **🔧 Memory Controllability** 🆕:
-  - **Feedback Mechanism**: Use `add_feedback` API to correct, supplement, or replace existing memories with natural language.
-  - **Precise Deletion**: Delete specific memories by User ID or Memory ID via API or MCP tools.
-- **👁️ Multi-Modal Support** 🆕: Support for image understanding and memory, including chart parsing in documents.
-- **⚡ Advanced Architecture**:
-  - **DB Optimization**: Enhanced connection management and batch insertion for high-concurrency scenarios.
-  - **Advanced Retrieval**: Custom tag and info field filtering with complex logical operations.
-  - **Redis Streams Scheduler**: Multi-level queue architecture with intelligent orchestration for fair multi-tenant scheduling.
-  - **Stream & Non-Stream Chat**: Ready-to-use streaming and non-streaming chat interfaces.
 - **🔌 Extensible**: Easily extend and customize memory modules, data sources, and LLM integrations.
-- **🏂 Lightweight Deployment** 🆕: Support for quick mode and complete mode deployment options.

 ## 🚀 Getting Started
@@ -153,62 +139,103 @@ pip install -r ./docker/requirements.txt
 uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8001 --workers 8
 ```

-### Local SDK
-Here's a quick example of how to create a **`MemCube`**, load it from a directory, access its memories, and save it.
+### Interface SDK
+#### Here are quick examples showing how to use each SDK interface
+This interface is used to add messages, supporting multiple types of content and batch additions. MemOS will automatically parse the messages and handle memory for reference in subsequent conversations.
 ```python
-from memos.mem_cube.general import GeneralMemCube
+# Please make sure MemOS is installed (pip install MemoryOS -U)
+from memos.api.client import MemOSClient
+
+# Initialize the client using the API Key
+client = MemOSClient(api_key="YOUR_API_KEY")
+
+messages = [
+    {"role": "user", "content": "I have planned to travel to Guangzhou during the summer vacation. What chain hotels are available for accommodation?"},
+    {"role": "assistant", "content": "You can consider [7 Days, All Seasons, Hilton], and so on."},
+    {"role": "user", "content": "I'll choose 7 Days"},
+    {"role": "assistant", "content": "Okay, ask me if you have any other questions."}
+]
+user_id = "memos_user_123"
+conversation_id = "0610"
+res = client.add_message(messages=messages, user_id=user_id, conversation_id=conversation_id)
+
+print(f"result: {res}")
+```

-# Initialize a MemCube from a local directory
-mem_cube = GeneralMemCube.init_from_dir("examples/data/mem_cube_2")
+This interface is used to retrieve the memories of a specified user, returning the memory fragments most relevant to the input query for Agent use. The recalled memory fragments include 'factual memory', 'preference memory', and 'tool memory'.
+```python
+# Please make sure MemOS is installed (pip install MemoryOS -U)
+from memos.api.client import MemOSClient

-# Access and print all memories
-print("--- Textual Memories ---")
-for item in mem_cube.text_mem.get_all():
-    print(item)
+# Initialize the client using the API Key
+client = MemOSClient(api_key="YOUR_API_KEY")

-print("\n--- Activation Memories ---")
-for item in mem_cube.act_mem.get_all():
-    print(item)
+query = "I want to go out to play during National Day. Can you recommend a city I haven't been to and a hotel brand I haven't stayed at?"
+user_id = "memos_user_123"
+conversation_id = "0928"
+res = client.search_memory(query=query, user_id=user_id, conversation_id=conversation_id)

-# Save the MemCube to a new directory
-mem_cube.dump("tmp/mem_cube")
+print(f"result: {res}")
 ```

-**`MOS`** (Memory Operating System) is a higher-level orchestration layer that manages multiple MemCubes and provides a unified API for memory operations. Here's a quick example of how to use MOS:
-
+This interface is used to delete the memory of specified users and supports batch deletion.
 ```python
-from memos.configs.mem_os import MOSConfig
-from memos.mem_os.main import MOS
+# Please make sure MemOS is installed (pip install MemoryOS -U)
+from memos.api.client import MemOSClient
+# Initialize the client using the API Key
+client = MemOSClient(api_key="YOUR_API_KEY")
+
+user_ids = ["memos_user_123"]
+# Replace with the memory ID
+memory_ids = ["6b23b583-f4c4-4a8f-b345-58d0c48fea04"]
+res = client.delete_memory(user_ids=user_ids, memory_ids=memory_ids)
+
+print(f"result: {res}")
+```

-# init MOS
-mos_config = MOSConfig.from_json_file("examples/data/config/simple_memos_config.json")
-memory = MOS(mos_config)
+This interface is used to add feedback to messages in the current session, allowing MemOS to correct its memory based on user feedback.
+```python
+# Please make sure MemOS is installed (pip install MemoryOS -U)
+from memos.api.client import MemOSClient

-# create user
-user_id = "b41a34d5-5cae-4b46-8c49-d03794d206f5"
-memory.create_user(user_id=user_id)
+# Initialize the client using the API Key
+client = MemOSClient(api_key="YOUR_API_KEY")

-# register cube for user
-memory.register_mem_cube("examples/data/mem_cube_2", user_id=user_id)
+user_id = "memos_user_123"
+conversation_id = "memos_feedback_conv"
+feedback_content = "No, let's change it now to a meal allowance of 150 yuan per day and a lodging subsidy of 700 yuan per day for first-tier cities; for second- and third-tier cities, it remains the same as before."
+# Replace with the knowledgebase ID
+allow_knowledgebase_ids = ["basee5ec9050-c964-484f-abf1-ce3e8e2aa5b7"]

-# add memory for user
-memory.add(
-    messages=[
-        {"role": "user", "content": "I like playing football."},
-        {"role": "assistant", "content": "I like playing football too."},
-    ],
+res = client.add_feedback(
     user_id=user_id,
+    conversation_id=conversation_id,
+    feedback_content=feedback_content,
+    allow_knowledgebase_ids=allow_knowledgebase_ids
 )

-# Later, when you want to retrieve memory for user
-retrieved_memories = memory.search(query="What do you like?", user_id=user_id)
-# output text_memories: I like playing football, act_memories, para_memories
-print(f"text_memories: {retrieved_memories['text_mem']}")
+print(f"result: {res}")
 ```

-For more detailed examples, please check out the [`examples`](./examples) directory.
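As a brief aside on how these client calls fit together in an Agent loop, here is a minimal, non-authoritative sketch that feeds `search_memory` results into an LLM system prompt. The response shape (`res["data"]` as a list of fragments) is an assumption for illustration only; consult the API reference for the authoritative schema.

```python
from memos.api.client import MemOSClient

client = MemOSClient(api_key="YOUR_API_KEY")

res = client.search_memory(
    query="Recommend a hotel for my October trip",
    user_id="memos_user_123",
    conversation_id="1001",
)

# Assumed shape: a list of memory fragments under res["data"]; adjust to
# the real schema returned by your MemOS deployment.
fragments = res.get("data", []) if isinstance(res, dict) else []
memory_block = "\n".join(f"- {m}" for m in map(str, fragments))

system_prompt = (
    "You are a helpful assistant. The user's stored memories are:\n"
    f"{memory_block}\n"
    "Take these into account when answering."
)
print(system_prompt)
```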
+This interface is used to create a knowledgebase associated with a project.
+```python
+# Please make sure MemOS is installed (pip install MemoryOS -U)
+from memos.api.client import MemOSClient
+
+# Initialize the client using the API Key
+client = MemOSClient(api_key="YOUR_API_KEY")
+
+knowledgebase_name = "Financial Reimbursement Knowledge Base"
+knowledgebase_description = "A compilation of all knowledge related to the company's financial reimbursements."
+
+res = client.create_knowledgebase(
+    knowledgebase_name=knowledgebase_name,
+    knowledgebase_description=knowledgebase_description
+)
+print(f"result: {res}")
+```

 ## 📦 Installation

From fc70e9f127d514624e33f3419b356ee2b464ab38 Mon Sep 17 00:00:00 2001
From: zZhangSir <103892644+zZhangSir@users.noreply.github.com>
Date: Thu, 25 Dec 2025 11:25:47 +0800
Subject: [PATCH 04/48] Dev zhq new (#776)

* fix: update README.md

* fix: update README.md

---------

Co-authored-by: Zehao Lin
---
 README.md | 100 ++++++++++++++++++++++++++++--------------------------
 1 file changed, 51 insertions(+), 49 deletions(-)

diff --git a/README.md b/README.md
index 2f5422095..29a50c1da 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,57 @@ showcasing its capabilities in **information extraction**, **temporal and cross-
 - **Parametric Memory**: Stores model adaptation parameters (e.g., LoRA weights).
 - **🔌 Extensible**: Easily extend and customize memory modules, data sources, and LLM integrations.

+
+## 📦 Installation
+
+### Install via pip
+
+```bash
+pip install MemoryOS
+```
+
+### Optional Dependencies
+
+MemOS provides several optional dependency groups for different features. You can install them based on your needs.
+
+| Feature | Package Name |
+| --------------------- | ------------------------- |
+| Tree Memory | `MemoryOS[tree-mem]` |
+| Memory Reader | `MemoryOS[mem-reader]` |
+| Memory Scheduler | `MemoryOS[mem-scheduler]` |
+
+Example installation commands:
+
+```bash
+pip install MemoryOS[tree-mem]
+pip install MemoryOS[tree-mem,mem-reader]
+pip install MemoryOS[mem-scheduler]
+pip install MemoryOS[tree-mem,mem-reader,mem-scheduler]
+```
+
+### External Dependencies
+
+#### Ollama Support
+
+To use MemOS with [Ollama](https://ollama.com/), first install the Ollama CLI:
+
+```bash
+curl -fsSL https://ollama.com/install.sh | sh
+```
+
+#### Transformers Support
+
+To use functionalities based on the `transformers` library, ensure you have [PyTorch](https://pytorch.org/get-started/locally/) installed (CUDA version recommended for GPU acceleration).
+
+#### Download Examples
+
+To download example code, data and configurations, run the following command:
+
+```bash
+memos download_examples
+```
+
+
 ## 🚀 Getting Started

 ### ⭐️ MemOS online API
@@ -237,55 +288,6 @@ res = client.create_knowledgebase(
 print(f"result: {res}")
 ```

-## 📦 Installation
-
-### Install via pip
-
-```bash
-pip install MemoryOS
-```
-
-### Optional Dependencies
-
-MemOS provides several optional dependency groups for different features. You can install them based on your needs.
-
-| Feature | Package Name |
-| --------------------- | ------------------------- |
-| Tree Memory | `MemoryOS[tree-mem]` |
-| Memory Reader | `MemoryOS[mem-reader]` |
-| Memory Scheduler | `MemoryOS[mem-scheduler]` |
-
-Example installation commands:
-
-```bash
-pip install MemoryOS[tree-mem]
-pip install MemoryOS[tree-mem,mem-reader]
-pip install MemoryOS[mem-scheduler]
-pip install MemoryOS[tree-mem,mem-reader,mem-scheduler]
-```
-
-### External Dependencies
-
-#### Ollama Support
-
-To use MemOS with [Ollama](https://ollama.com/), first install the Ollama CLI:
-
-```bash
-curl -fsSL https://ollama.com/install.sh | sh
-```
-
-#### Transformers Support
-
-To use functionalities based on the `transformers` library, ensure you have [PyTorch](https://pytorch.org/get-started/locally/) installed (CUDA version recommended for GPU acceleration).
-
-#### Download Examples
-
-To download example code, data and configurations, run the following command:
-
-```bash
-memos download_examples
-```
-
 ## 💬 Community & Support

 Join our community to ask questions, share your projects, and connect with other developers.

From 3873adb3b5869704bb3cb2958aee1e2a58a6e7aa Mon Sep 17 00:00:00 2001
From: Hustzdy <67457465+wustzdy@users.noreply.github.com>
Date: Thu, 25 Dec 2025 14:14:19 +0800
Subject: [PATCH 05/48] feat: add export_graph data page (#778)

---
 src/memos/graph_dbs/polardb.py | 144 ++++++++++++++++++++++++++++-----
 1 file changed, 123 insertions(+), 21 deletions(-)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index c81e46804..1d19dc98d 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -2502,13 +2502,19 @@ def clear(self, user_name: str | None = None) -> None:

     @timed
     def export_graph(
-        self, include_embedding: bool = False, user_name: str | None = None
+        self,
+        include_embedding: bool = False,
+        user_name: str | None = None,
+        page: int = 1,
+        page_size: int = 10,
     ) -> dict[str, Any]:
         """
         Export all graph nodes and edges in a structured form.
         Args:
             include_embedding (bool): Whether to include the large embedding field.
             user_name (str, optional): User name for filtering in non-multi-db mode
+            page (int): Page number (starts from 1). Default is 1.
+            page_size (int): Number of items per page. Default is 10.

         Returns:
             {
                 "edges": [ { "source": ..., "target": ..., "type": ... }, ...
] } """ + logger.info( + f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, page: {page}, page_size: {page_size}" + ) user_name = user_name if user_name else self._get_config_value("user_name") + + # Validate pagination parameters + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + conn = None try: conn = self._get_connection() @@ -2526,14 +2542,18 @@ def export_graph( SELECT id, properties, embedding FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype + ORDER BY id + LIMIT {page_size} OFFSET {(page - 1) * page_size} """ else: node_query = f""" SELECT id, properties FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype + ORDER BY id + LIMIT {page_size} OFFSET {(page - 1) * page_size} """ - + logger.info(f"[export_graph nodes] Query: {node_query}") with conn.cursor() as cursor: cursor.execute(node_query) node_results = cursor.fetchall() @@ -2580,14 +2600,19 @@ def export_graph( try: conn = self._get_connection() # Export edges using cypher query + # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery edge_query = f""" - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}' - RETURN a.id AS source, b.id AS target, type(r) as edge - $$) AS (source agtype, target agtype, edge agtype) + SELECT source, target, edge FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}' + RETURN a.id AS source, b.id AS target, type(r) as edge + ORDER BY a.id, b.id + $$) AS (source agtype, target agtype, edge agtype) + ) AS edges + LIMIT {page_size} OFFSET {(page - 1) * page_size} """ - + logger.info(f"[export_graph edges] Query: {edge_query}") with conn.cursor() as cursor: cursor.execute(edge_query) edge_results = cursor.fetchall() @@ -4580,28 +4605,105 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {op_value}::agtype" ) elif op == "contains": - # Handle contains operator (for string fields only) - # Check if agtype contains value (using @> operator) - if not isinstance(op_value, str): - raise ValueError( - f"contains operator only supports string format. " - f"Use {{'{key}': {{'contains': '{op_value}'}}}} instead of {{'{key}': {{'contains': {op_value}}}}}" - ) + # Handle contains operator + # For array fields: check if array contains the value using @> operator + # For string fields: check if string contains the value using @> operator # Check if key starts with "info." prefix if key.startswith("info."): info_field = key[5:] # Remove "info." 
prefix - # String contains: use @> operator for agtype contains - escaped_value = escape_sql_string(op_value) + escaped_value = escape_sql_string(str(op_value)) + # For array fields, use @> with array format: '["value"]'::agtype + # For string fields, use @> with string format: '"value"'::agtype + # We'll use array format for contains to check if array contains the value condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '\"{escaped_value}\"'::agtype" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype" ) else: # Direct property access - # String contains: use @> operator for agtype contains - escaped_value = escape_sql_string(op_value) + escaped_value = escape_sql_string(str(op_value)) + # For array fields, use @> with array format condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '\"{escaped_value}\"'::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype" ) + elif op == "in": + # Handle in operator (for checking if field value is in a list) + # Supports array format: {"field": {"in": ["value1", "value2"]}} + if not isinstance(op_value, list): + raise ValueError( + f"in operator only supports array format. " + f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}" + ) + # Check if key starts with "info." prefix + if key.startswith("info."): + info_field = key[5:] # Remove "info." prefix + # Build OR conditions for nested properties + if len(op_value) == 0: + # Empty list means no match + condition_parts.append("false") + elif len(op_value) == 1: + # Single value, use equality + item = op_value[0] + if isinstance(item, str): + escaped_value = escape_sql_string(item) + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype" + ) + else: + # Multiple values, use OR conditions + or_conditions = [] + for item in op_value: + if isinstance(item, str): + escaped_value = escape_sql_string(item) + or_conditions.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" + ) + else: + or_conditions.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype" + ) + if or_conditions: + condition_parts.append( + f"({' OR '.join(or_conditions)})" + ) + else: + # Direct property access + # Build OR conditions + if len(op_value) == 0: + # Empty list means no match + condition_parts.append("false") + elif len(op_value) == 1: + # Single value, use equality + item = op_value[0] + if isinstance(item, str): + escaped_value = escape_sql_string(item) + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" + ) + else: + condition_parts.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype" + ) 
+ else: + # Multiple values, use OR conditions + or_conditions = [] + for item in op_value: + if isinstance(item, str): + escaped_value = escape_sql_string(item) + or_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" + ) + else: + or_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype" + ) + if or_conditions: + condition_parts.append( + f"({' OR '.join(or_conditions)})" + ) elif op == "like": # Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%') # Check if key starts with "info." prefix From 1ee536ae11d953bb53801de6e0b241f302b9f4a9 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Thu, 25 Dec 2025 14:35:27 +0800 Subject: [PATCH 06/48] fix: optimize Neo4j Community Edition support and enhance MCP environment loading (#754) - Update default Neo4j DB name to 'neo4j' for Community Edition compatibility - Add neo4j_auto_create and use_multi_db configuration options - Enhance MCP server to load all relevant environment variables from .env - Add defensive error handling in Neo4j driver for administrative commands on Community Edition - Update .env.example with detailed Neo4j setup instructions Co-authored-by: CaralHsi --- docker/.env.example | 4 +- src/memos/api/mcp_serve.py | 106 +++++++++++++++++++---- src/memos/graph_dbs/neo4j.py | 9 ++ src/memos/mem_os/utils/default_config.py | 11 ++- 4 files changed, 109 insertions(+), 21 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index ca3abde94..dc4252133 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -99,7 +99,9 @@ NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j* NEO4J_USER=neo4j # required when backend=neo4j* NEO4J_PASSWORD=12345678 # required when backend=neo4j* NEO4J_DB_NAME=neo4j # required for shared-db mode -MOS_NEO4J_SHARED_DB=false +MOS_NEO4J_SHARED_DB=true # if true, all users share one DB; if false, each user gets their own DB +NEO4J_AUTO_CREATE=false # [IMPORTANT] set to false for Neo4j Community Edition +NEO4J_USE_MULTI_DB=false # alternative to MOS_NEO4J_SHARED_DB (logic is inverse) QDRANT_HOST=localhost QDRANT_PORT=6333 # For Qdrant Cloud / remote endpoint (takes priority if set): diff --git a/src/memos/api/mcp_serve.py b/src/memos/api/mcp_serve.py index 9eb1e59d0..838c2a76a 100644 --- a/src/memos/api/mcp_serve.py +++ b/src/memos/api/mcp_serve.py @@ -16,14 +16,88 @@ def load_default_config(user_id="default_user"): + """ + Load MOS configuration from environment variables. + + IMPORTANT for Neo4j Community Edition: + Community Edition does not support administrative commands like 'CREATE DATABASE'. 
+ To avoid errors, ensure the following environment variables are set correctly: + - NEO4J_DB_NAME=neo4j (Must use the default database) + - NEO4J_AUTO_CREATE=false (Disable automatic database creation) + - NEO4J_USE_MULTI_DB=false (Disable multi-tenant database mode) + """ + # Define mapping between environment variables and configuration parameters + # We support both clean names and MOS_ prefixed names for compatibility + env_mapping = { + "OPENAI_API_KEY": "openai_api_key", + "OPENAI_API_BASE": "openai_api_base", + "MOS_TEXT_MEM_TYPE": "text_mem_type", + "NEO4J_URI": "neo4j_uri", + "NEO4J_USER": "neo4j_user", + "NEO4J_PASSWORD": "neo4j_password", + "NEO4J_DB_NAME": "neo4j_db_name", + "NEO4J_AUTO_CREATE": "neo4j_auto_create", + "NEO4J_USE_MULTI_DB": "use_multi_db", + "MOS_NEO4J_SHARED_DB": "mos_shared_db", # Special handle later + "MODEL_NAME": "model_name", + "MOS_CHAT_MODEL": "model_name", + "EMBEDDER_MODEL": "embedder_model", + "MOS_EMBEDDER_MODEL": "embedder_model", + "CHUNK_SIZE": "chunk_size", + "CHUNK_OVERLAP": "chunk_overlap", + "ENABLE_MEM_SCHEDULER": "enable_mem_scheduler", + "MOS_ENABLE_SCHEDULER": "enable_mem_scheduler", + "ENABLE_ACTIVATION_MEMORY": "enable_activation_memory", + "TEMPERATURE": "temperature", + "MOS_CHAT_TEMPERATURE": "temperature", + "MAX_TOKENS": "max_tokens", + "MOS_MAX_TOKENS": "max_tokens", + "TOP_P": "top_p", + "MOS_TOP_P": "top_p", + "TOP_K": "top_k", + "MOS_TOP_K": "top_k", + "SCHEDULER_TOP_K": "scheduler_top_k", + "MOS_SCHEDULER_TOP_K": "scheduler_top_k", + "SCHEDULER_TOP_N": "scheduler_top_n", + } + + kwargs = {"user_id": user_id} + for env_key, param_key in env_mapping.items(): + val = os.getenv(env_key) + if val is not None: + # Strip quotes if they exist (sometimes happens with .env) + if (val.startswith('"') and val.endswith('"')) or ( + val.startswith("'") and val.endswith("'") + ): + val = val[1:-1] + + # Handle boolean conversions + if val.lower() in ("true", "false"): + kwargs[param_key] = val.lower() == "true" + else: + # Try numeric conversions (int first, then float) + try: + if "." 
in val: + kwargs[param_key] = float(val) + else: + kwargs[param_key] = int(val) + except ValueError: + kwargs[param_key] = val + + # Logic handle for MOS_NEO4J_SHARED_DB vs use_multi_db + if "mos_shared_db" in kwargs: + kwargs["use_multi_db"] = not kwargs.pop("mos_shared_db") + + # Extract mandatory or special params + openai_api_key = kwargs.pop("openai_api_key", os.getenv("OPENAI_API_KEY")) + openai_api_base = kwargs.pop("openai_api_base", "https://api.openai.com/v1") + text_mem_type = kwargs.pop("text_mem_type", "tree_text") + config, cube = get_default( - openai_api_key=os.getenv("OPENAI_API_KEY"), - openai_api_base=os.getenv("OPENAI_API_BASE"), - text_mem_type=os.getenv("MOS_TEXT_MEM_TYPE"), - user_id=user_id, - neo4j_uri=os.getenv("NEO4J_URI"), - neo4j_user=os.getenv("NEO4J_USER"), - neo4j_password=os.getenv("NEO4J_PASSWORD"), + openai_api_key=openai_api_key, + openai_api_base=openai_api_base, + text_mem_type=text_mem_type, + **kwargs, ) return config, cube @@ -33,6 +107,7 @@ def __init__(self): self.mcp = FastMCP("MOS Memory System") config, cube = load_default_config() self.mos_core = MOS(config=config) + self.mos_core.register_mem_cube(cube) self._setup_tools() def _setup_tools(self): @@ -132,11 +207,14 @@ async def register_cube( """ try: if not os.path.exists(cube_name_or_path): - mos_config, cube_name_or_path = load_default_config(user_id=user_id) + _, cube = load_default_config(user_id=user_id) + cube_to_register = cube + else: + cube_to_register = cube_name_or_path self.mos_core.register_mem_cube( - cube_name_or_path, mem_cube_id=cube_id, user_id=user_id + cube_to_register, mem_cube_id=cube_id, user_id=user_id ) - return f"Cube registered successfully: {cube_id or cube_name_or_path}" + return f"Cube registered successfully: {cube_id or cube_to_register}" except Exception as e: return f"Error registering cube: {e!s}" @@ -489,14 +567,6 @@ def run(self, transport: str = "stdio", **kwargs): args = parser.parse_args() - # Set environment variables - os.environ["OPENAI_API_BASE"] = os.getenv("OPENAI_API_BASE") - os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY") - os.environ["MOS_TEXT_MEM_TYPE"] = "tree_text" # "tree_text" need set neo4j - os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI") - os.environ["NEO4J_USER"] = os.getenv("NEO4J_USER") - os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD") - # Create and run MCP server server = MOSMCPStdioServer() server.run(transport=args.transport, host=args.host, port=args.port) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index a0a4c6a50..debbb4e3c 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1347,6 +1347,15 @@ def _ensure_database_exists(self): with self.driver.session(database="system") as session: session.run(f"CREATE DATABASE `{self.db_name}` IF NOT EXISTS") except ClientError as e: + if "Unsupported administration command" in str( + e + ) or "Unsupported administration" in str(e): + logger.warning( + f"Could not create database '{self.db_name}' because this Neo4j instance " + "(likely Community Edition) does not support administrative commands. " + "Please ensure the database exists manually or use the default 'neo4j' database." 
+ ) + return if "ExistingDatabaseFound" in str(e): pass # Ignore, database already exists else: diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py index bf9f847d0..edb7875d4 100644 --- a/src/memos/mem_os/utils/default_config.py +++ b/src/memos/mem_os/utils/default_config.py @@ -181,15 +181,22 @@ def get_default_cube_config( # Configure text memory based on type if text_mem_type == "tree_text": # Tree text memory requires Neo4j configuration + # NOTE: Neo4j Community Edition does NOT support multiple databases. + # It only has one default database named 'neo4j'. + # If you are using Community Edition: + # 1. Set 'use_multi_db' to False (default) + # 2. Set 'db_name' to 'neo4j' (default) + # 3. Set 'auto_create' to False to avoid 'CREATE DATABASE' permission errors. db_name = f"memos{user_id.replace('-', '').replace('_', '')}" if not kwargs.get("use_multi_db", False): - db_name = kwargs.get("neo4j_db_name", "defaultdb") + db_name = kwargs.get("neo4j_db_name", "neo4j") + neo4j_config = { "uri": kwargs.get("neo4j_uri", "bolt://localhost:7687"), "user": kwargs.get("neo4j_user", "neo4j"), "db_name": db_name, "password": kwargs.get("neo4j_password", "12345678"), - "auto_create": True, + "auto_create": kwargs.get("neo4j_auto_create", True), "use_multi_db": kwargs.get("use_multi_db", False), "embedding_dimension": kwargs.get("embedding_dimension", 3072), } From 5cf0282b15dada36b15cee37808ae979e6788c38 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Thu, 25 Dec 2025 14:36:11 +0800 Subject: [PATCH 07/48] fix: add feedback change to preference (#771) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords * fix messages queue * add seach_by_keywords_LIKE * add doc filter * add retrieve query * add retrieve queies * patch info filter * add log and make embedding safety net * add log and make embedding safety net * deduplicate add objects * use _add_memories_parallel * delete Special characters * delete Special characters * delete Special characters * delete Special characters * add source_doc_id * add source_doc_id * add reranker in init com.. 
* fix circle import * add feedback judgement * add feedback judgement * add pref feedback * add pref feedback --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/handlers/component_init.py | 1 + src/memos/mem_feedback/feedback.py | 157 +++++++++++++----- src/memos/mem_feedback/simple_feedback.py | 3 + src/memos/mem_feedback/utils.py | 5 +- .../init_components_for_scheduler.py | 1 + src/memos/memories/textual/preference.py | 3 + .../memories/textual/simple_preference.py | 3 + 7 files changed, 135 insertions(+), 38 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index f968ea7b9..7af3afe74 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -308,6 +308,7 @@ def init_server() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, + pref_mem=pref_mem, ) # Initialize Scheduler diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 0b3fc3846..fad15a7cd 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -2,6 +2,7 @@ import difflib import json import re +import uuid from datetime import datetime from typing import TYPE_CHECKING, Any @@ -33,6 +34,7 @@ if TYPE_CHECKING: + from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_feedback_prompts import ( FEEDBACK_ANSWER_PROMPT, @@ -90,6 +92,7 @@ def __init__(self, config: MemFeedbackConfig): self.stopword_manager = StopwordManager self.searcher: Searcher = None self.reranker = None + self.pref_mem: SimplePreferenceTextMemory = None self.DB_IDX_READY = False @require_python_package( @@ -115,7 +118,7 @@ def _retry_db_operation(self, operation): return operation() except Exception as e: logger.error( - f"[1223 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True + f"[1224 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True ) raise @@ -129,7 +132,7 @@ def _batch_embed(self, texts: list[str], embed_bs: int = 5): results.extend(self._embed_once(batch)) except Exception as e: logger.error( - f"[1223 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" + f"[1224 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" ) results.extend([[0.0] * dim for _ in range(len(batch))]) return results @@ -145,7 +148,7 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i lambda: self.memory_manager.add(to_add_memories, user_name=user_name, use_batch=False) ) logger.info( - f"[1223 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." + f"[1224 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." 
) return { "record": { @@ -182,7 +185,7 @@ def _keyword_replace_judgement(self, feedback_content: str) -> dict | None: return judge_res else: logger.warning( - "[1223 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return {} @@ -207,7 +210,7 @@ def _feedback_judgement( return judge_res else: logger.warning( - "[1223 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return [] @@ -271,6 +274,14 @@ def _single_update_operation( """ Individual update operations """ + if "preference" in old_memory_item.metadata.__dict__: + logger.info( + f"[1224 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" + ) + return self._single_update_pref( + old_memory_item, new_memory_item, user_id, user_name, operation + ) + memory_type = old_memory_item.metadata.memory_type source_doc_id = ( old_memory_item.metadata.file_ids[0] @@ -281,6 +292,7 @@ def _single_update_operation( ) if operation and "text" in operation and operation["text"]: new_memory_item.memory = operation["text"] + new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] if memory_type == "WorkingMemory": fields = { @@ -317,6 +329,68 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } + def _single_update_pref( + self, + old_memory_item: TextualMemoryItem, + new_memory_item: TextualMemoryItem, + user_id: str, + user_name: str, + operation: dict, + ): + """update preference memory""" + + feedback_context = new_memory_item.memory + if operation and "text" in operation and operation["text"]: + new_memory_item.memory = operation["text"] + new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] + + to_add_memory = old_memory_item.model_copy(deep=True) + to_add_memory.metadata.key = new_memory_item.metadata.key + to_add_memory.metadata.tags = new_memory_item.metadata.tags + to_add_memory.memory = new_memory_item.memory + to_add_memory.metadata.preference = new_memory_item.memory + to_add_memory.metadata.embedding = new_memory_item.metadata.embedding + + to_add_memory.metadata.user_id = new_memory_item.metadata.user_id + to_add_memory.metadata.original_text = old_memory_item.memory + to_add_memory.metadata.covered_history = old_memory_item.id + + to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( + datetime.now().isoformat() + ) + to_add_memory.metadata.context_summary = ( + old_memory_item.metadata.context_summary + " \n" + feedback_context + ) + + # add new memory + to_add_memory.id = str(uuid.uuid4()) + added_ids = self._retry_db_operation(lambda: self.pref_mem.add([to_add_memory])) + # delete + deleted_id = old_memory_item.id + collection_name = old_memory_item.metadata.preference_type + self._retry_db_operation( + lambda: self.pref_mem.delete_with_collection_name(collection_name, [deleted_id]) + ) + # add archived + old_memory_item.metadata.status = "archived" + old_memory_item.metadata.original_text = "archived" + old_memory_item.metadata.embedding = [0.0] * 1024 + + archived_ids = self._retry_db_operation(lambda: self.pref_mem.add([old_memory_item])) + + logger.info( + f"[Memory Feedback UPDATE Pref] New Add:{added_ids!s} | Set archived:{archived_ids!s}" + ) + + return { + "id": to_add_memory.id, + "text": new_memory_item.memory, + "source_doc_id": "", + "archived_id": old_memory_item.id, + "origin_memory": 
old_memory_item.memory, + "type": "preference", + } + def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) @@ -334,11 +408,11 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> self.graph_store.delete_node(mid, user_name=user_name) logger.info( - f"[1223 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" + f"[1224 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" ) except Exception as e: logger.warning( - f"[1223 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" + f"[1224 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) def semantics_feedback( @@ -355,13 +429,12 @@ def semantics_feedback( lang = detect_lang("".join(memory_item.memory)) template = FEEDBACK_PROMPT_DICT["compare"][lang] if current_memories == []: - # retrieve feedback - feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name) - - # retrieve question + # retrieve last_user_index = max(i for i, d in enumerate(chat_history_list) if d["role"] == "user") last_qa = " ".join([item["content"] for item in chat_history_list[last_user_index:]]) supplementary_retrieved = self._retrieve(last_qa, info=info, user_name=user_name) + feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name) + ids = [] for item in feedback_retrieved + supplementary_retrieved: if item.id not in ids: @@ -385,9 +458,14 @@ def semantics_feedback( with ContextThreadPoolExecutor(max_workers=10) as executor: future_to_chunk_idx = {} for chunk in memory_chunks: - current_memories_str = "\n".join( - [f"{item.id}: {item.memory}" for item in chunk] - ) + chunk_list = [] + for item in chunk: + if "preference" in item.metadata.__dict__: + chunk_list.append(f"{item.id}: {item.metadata.preference}") + else: + chunk_list.append(f"{item.id}: {item.memory}") + current_memories_str = "\n".join(chunk_list) + prompt = template.format( now_time=now_time, current_memories=current_memories_str, @@ -408,7 +486,7 @@ def semantics_feedback( all_operations.extend(chunk_operations["operations"]) except Exception as e: logger.error( - f"[1223 Feedback Core: semantics_feedback] Operation failed: {e}" + f"[1224 Feedback Core: semantics_feedback] Operation failed: {e}" ) standard_operations = self.standard_operations(all_operations, current_memories) @@ -458,7 +536,7 @@ def semantics_feedback( update_results.append(result) except Exception as e: logger.error( - f"[1223 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", + f"[1224 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) if update_results: @@ -486,7 +564,7 @@ def _feedback_memory( ] if filterd_ids: logger.warning( - f"[1223 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[1224 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." 
) current_memories = [ @@ -518,7 +596,7 @@ def _feedback_memory( results[i] = node except Exception as e: logger.error( - f"[1223 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", + f"[1224 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", exc_info=True, ) mem_res = [r for r in results if r] @@ -542,13 +620,18 @@ def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys: record.append(info_v == mem_v) return all(record) - def _retrieve(self, query: str, info=None, top_k=100, user_name=None): + def _retrieve(self, query: str, info=None, top_k=20, user_name=None): """Retrieve memory items""" retrieved_mems = self.searcher.search( query, info=info, user_name=user_name, top_k=top_k, full_recall=True ) retrieved_mems = [item[0] for item in retrieved_mems if float(item[1]) > 0.01] - return retrieved_mems + + pref_info = {} + if "user_id" in info: + pref_info = {"user_id": info["user_id"]} + retrieved_prefs = self.pref_mem.search(query, top_k, pref_info) + return retrieved_mems + retrieved_prefs def _vec_query(self, new_memories_embedding: list[float], user_name=None): """Vector retrieval query""" @@ -577,7 +660,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): if not retrieved_ids: logger.info( - f"[1223 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." + f"[1224 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." ) filterd_ids = [ @@ -585,7 +668,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): ] if filterd_ids: logger.warning( - f"[1223 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[1224 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." 
) return [ TextualMemoryItem(**item) @@ -639,9 +722,9 @@ def filter_fault_update(self, operations: list[dict]): ): all_judge.extend(judge_res["operations_judgement"]) except Exception as e: - logger.error(f"[1223 Feedback Core: filter_fault_update] Judgement failed: {e}") + logger.error(f"[1224 Feedback Core: filter_fault_update] Judgement failed: {e}") - logger.info(f"[1223 Feedback Core: filter_fault_update] LLM judgement: {all_judge}") + logger.info(f"[1224 Feedback Core: filter_fault_update] LLM judgement: {all_judge}") id2op = {item["id"]: item for item in updated_operations} valid_updates = [] for judge in all_judge: @@ -652,7 +735,7 @@ def filter_fault_update(self, operations: list[dict]): valid_updates.append(valid_update) logger.info( - f"[1223 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" + f"[1224 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" ) return valid_updates + [item for item in operations if item["operation"] != "UPDATE"] @@ -680,11 +763,11 @@ def correct_item(data): and "text" in data and "old_memory" in data and data["operation"].lower() == "update" - ) + ), "Invalid operation item" if not should_keep_update(data["text"], data["old_memory"]): logger.warning( - f"[1223 Feedback Core: semantics_feedback] Due to the excessive proportion of changes, skip update: {data}" + f"[1224 Feedback Core: correct_item] Due to the excessive proportion of changes, skip update: {data}" ) return None @@ -704,14 +787,14 @@ def correct_item(data): return data except Exception: logger.error( - f"[1223 Feedback Core: standard_operations] Error processing operation item: {data}", + f"[1224 Feedback Core: standard_operations] Error processing operation item: {data}", exc_info=True, ) return None dehallu_res = [correct_item(item) for item in operations] dehalluded_operations = [item for item in dehallu_res if item] - logger.info(f"[1223 Feedback Core: dehalluded_operations] {dehalluded_operations}") + logger.info(f"[1224 Feedback Core: dehalluded_operations] {dehalluded_operations}") # c add objects add_texts = [] @@ -725,7 +808,7 @@ def correct_item(data): elif item["operation"].lower() == "update": llm_operations.append(item) logger.info( - f"[1223 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" + f"[1224 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" ) # Update takes precedence over add @@ -739,7 +822,7 @@ def correct_item(data): ] if filtered_items: logger.info( - f"[1223 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" + f"[1224 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" ) return update_items else: @@ -787,7 +870,7 @@ def _doc_filter(self, doc_scope: str, memories: list[TextualMemoryItem]): memid for inscope_file in inscope_docs for memid in filename2_memid[inscope_file] ] logger.info( - f"[1223 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" + f"[1224 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" ) filter_memories = [mem for mem in memories if mem.id in inscope_ids] return filter_memories @@ -841,7 +924,7 @@ def process_keyword_replace( retrieved_memories = self._doc_filter(doc_scope, retrieved_memories) logger.info( - f"[1223 Feedback Core: process_keyword_replace] Keywords recalled memory for 
user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." + f"[1224 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." ) if not retrieved_memories: @@ -926,7 +1009,7 @@ def check_validity(item): info.update({"user_id": user_id, "user_name": user_name, "session_id": session_id}) logger.info( - f"[1223 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" + f"[1224 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" ) # feedback keywords update kwp_judge = self._keyword_replace_judgement(feedback_content) @@ -959,7 +1042,7 @@ def check_validity(item): if not valid_feedback: logger.warning( - f"[1223 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." + f"[1224 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." ) return {"record": {"add": [], "update": []}} @@ -1007,13 +1090,13 @@ def check_validity(item): add_memories = mem_record["record"]["add"] update_memories = mem_record["record"]["update"] logger.info( - f"[1223 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." + f"[1224 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." ) return mem_record except Exception as e: logger.error( - f"[1223 Feedback Core: process_feedback_core] Error for user {user_name}: {e}" + f"[1224 Feedback Core: process_feedback_core] Error for user {user_name}: {e}" ) return {"record": {"add": [], "update": []}} diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index 429c2ea20..e32f939c7 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -4,6 +4,7 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.mem_feedback.feedback import MemFeedback from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -23,6 +24,7 @@ def __init__( mem_reader: SimpleStructMemReader, searcher: Searcher, reranker: BaseReranker, + pref_mem: SimplePreferenceTextMemory, ): self.llm = llm self.embedder = embedder @@ -31,5 +33,6 @@ def __init__( self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager + self.pref_mem = pref_mem self.reranker = reranker self.DB_IDX_READY = False diff --git a/src/memos/mem_feedback/utils.py b/src/memos/mem_feedback/utils.py index c32c12328..8cb7f97a3 100644 --- a/src/memos/mem_feedback/utils.py +++ b/src/memos/mem_feedback/utils.py @@ -48,8 +48,11 @@ def calculate_similarity(text1: str, text2: str) -> float: similarity = calculate_similarity(old_text, new_text) change_ratio = 1 - similarity + if change_ratio == float(0): + return False + if old_len < 200: - return change_ratio < 0.5 + return change_ratio < 0.7 else: return 
change_ratio < 0.2 diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index ba7b558fd..8fd60153d 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -418,6 +418,7 @@ def init_components() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, + pref_mem=pref_mem, ) # Return all components as a dictionary for easy access and extension return {"naive_mem_cube": naive_mem_cube, "feedback_server": feedback_server} diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index e1bc0e72b..9e521158d 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -87,6 +87,9 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ + if not isinstance(search_filter, dict): + search_filter = {} + search_filter.update({"status": "activated"}) logger.info(f"search_filter for preference memory: {search_filter}") return self.retriever.retrieve(query, top_k, info, search_filter) diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index 1f02132bb..ee37d638c 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -61,6 +61,9 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ + if not isinstance(search_filter, dict): + search_filter = {} + search_filter.update({"status": "activated"}) return self.retriever.retrieve(query, top_k, info, search_filter) def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: From b11c768436389402e83904ac3c239368bb8c496d Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 25 Dec 2025 14:41:37 +0800 Subject: [PATCH 08/48] fix: improve chat playground stability and chat handler initialization (#770) * fix playground bug, internet search judge * fix playground internet bug * modify delete mem * modify tool resp bug in multi cube * fix bug in playground chat handler and internet search * modify prompt * fix bug in playground * fix playground bug * fix bug * fix code * fix model bug in playground * modify plan b * llm param modify * add logger in playground * modify code * fix bug * modify code * modify code * fix bug * fix search bug in playground * fix bug * move scheduler to back * modify pref location * modify fast net search * add tags and new package * modify prompt fix bug * remove nltk due to image problem * prompt modify * modify bug remove redundant field * modify bug * fix playground bug * fix bug * boost internet topk * boost to 50 * fix bug cite * modify search * remote query add in playground * modify bug * modify pref bug * move add position * modify chat prompt * modify overthinking * add logger in playground chat * modify mem * remove must in prompt * add logger * add logger --------- Co-authored-by: yuan.wang Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> Co-authored-by: CaralHsi --- src/memos/api/handlers/chat_handler.py | 98 ++++++++++++------- src/memos/api/handlers/component_init.py | 6 +- src/memos/api/routers/server_router.py | 30 ++++-- .../textual/prefer_text_memory/extractor.py | 6 ++ src/memos/multi_mem_cube/single_cube.py | 5 + 5 files changed, 104 insertions(+),
41 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index bcc3669b6..3e9d1e5ec 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -99,15 +99,13 @@ def __init__( def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]: """ - Chat with MemOS for complete response (non-streaming). - - This implementation directly uses search/add handlers instead of mos_server. + Chat with MemOS for chat complete response (non-streaming). Args: chat_req: Chat complete request Returns: - Dictionary with response and references + Dictionary with chat complete response and reasoning Raises: HTTPException: If chat fails @@ -161,7 +159,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An {"role": "user", "content": chat_req.query}, ] - self.logger.info("Starting to generate complete response...") + self.logger.info("[Cloud Service] Starting to generate chat complete response...") # Step 3: Generate complete response from LLM if chat_req.model_name_or_path and chat_req.model_name_or_path not in self.chat_llms: @@ -172,11 +170,23 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) - self.logger.info(f"[Cloud Service Chat Complete Model]: {model}") + self.logger.info(f"[Cloud Service] Chat Complete Model: {model}") strat = time.time() response = self.chat_llms[model].generate(current_messages, model_name_or_path=model) end = time.time() - self.logger.info(f"[Cloud Service Chat Complete Time]: {end - strat} seconds") + self.logger.info(f"[Cloud Service] Chat Complete Time: {end - strat} seconds") + + if not response: + self.logger.error( + f"[Cloud Service] Chat Complete Failed, LLM response is {response}" + ) + raise HTTPException( + status_code=500, detail="Chat complete failed, LLM response is None" + ) + + self.logger.info( + f"[Cloud Service] Chat Complete LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Complete LLM Response: {response}" + ) # Step 4: start add after chat asynchronously if chat_req.add_message_on_answer: @@ -192,7 +202,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An async_mode="async", ) end = time.time() - self.logger.info(f"[Cloud Service Chat Add Time]: {end - start} seconds") + self.logger.info(f"[Cloud Service] Chat Add Time: {end - start} seconds") match = re.search(r"([\s\S]*?)", response) reasoning_text = match.group(1) if match else None @@ -208,14 +218,12 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An except ValueError as err: raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err except Exception as err: - self.logger.error(f"Failed to complete chat: {traceback.format_exc()}") + self.logger.error(f"[Cloud Service] Failed to chat complete: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: """ - Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. - - This implementation directly uses search_handler and add_handler. + Chat with MemOS via Server-Sent Events (SSE) stream for chat stream response. 
Args: chat_req: Chat stream request @@ -229,7 +237,7 @@ def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: try: def generate_chat_response() -> Generator[str, None, None]: - """Generate chat response as SSE stream.""" + """Generate chat stream response as SSE stream.""" try: # Resolve readable cube IDs (for search) readable_cube_ids = chat_req.readable_cube_ids or ( @@ -289,7 +297,7 @@ def generate_chat_response() -> Generator[str, None, None]: ] self.logger.info( - f"user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, " + f"[Cloud Service] chat stream user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, " f"current_system_prompt: {system_prompt}" ) @@ -304,14 +312,12 @@ def generate_chat_response() -> Generator[str, None, None]: ) model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) - self.logger.info(f"[Cloud Service Chat Stream Model]: {model}") + self.logger.info(f"[Cloud Service] Chat Stream Model: {model}") start = time.time() response_stream = self.chat_llms[model].generate_stream( current_messages, model_name_or_path=model ) - end = time.time() - self.logger.info(f"[Cloud Service Chat Stream Time]: {end - start} seconds") # Stream the response buffer = "" @@ -337,6 +343,13 @@ def generate_chat_response() -> Generator[str, None, None]: chunk_data = f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n" yield chunk_data + end = time.time() + self.logger.info(f"[Cloud Service] Chat Stream Time: {end - start} seconds") + + self.logger.info( + f"[Cloud Service] Chat Stream LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Stream LLM Response: {full_response}" + ) + current_messages.append({"role": "assistant", "content": full_response}) if chat_req.add_message_on_answer: # Resolve writable cube IDs (for add) @@ -354,10 +367,10 @@ def generate_chat_response() -> Generator[str, None, None]: ) end = time.time() self.logger.info( - f"[Cloud Service Chat Stream Add Time]: {end - start} seconds" + f"[Cloud Service] Chat Stream Add Time: {end - start} seconds" ) except Exception as e: - self.logger.error(f"Error in chat stream: {e}", exc_info=True) + self.logger.error(f"[Cloud Service] Error in chat stream: {e}", exc_info=True) error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" yield error_data @@ -377,14 +390,14 @@ def generate_chat_response() -> Generator[str, None, None]: except ValueError as err: raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err except Exception as err: - self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") + self.logger.error( + f"[Cloud Service] Failed to start chat stream: {traceback.format_exc()}" + ) raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> StreamingResponse: """ - Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. - - This implementation directly uses search_handler and add_handler. + Chat with MemOS via Server-Sent Events (SSE) stream for playground chat stream response. 
Args: chat_req: Chat stream request @@ -398,7 +411,7 @@ def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> Stre try: def generate_chat_response() -> Generator[str, None, None]: - """Generate chat response as SSE stream.""" + """Generate playground chat stream response as SSE stream.""" try: import time @@ -434,7 +447,9 @@ def generate_chat_response() -> Generator[str, None, None]: start_time = time.time() search_response = self.search_handler.handle_search_memories(search_req) end_time = time.time() - self.logger.info(f"first search time: {end_time - start_time}") + self.logger.info( + f"[PLAYGROUND CHAT] first search time: {end_time - start_time}" + ) yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n" @@ -481,7 +496,7 @@ def generate_chat_response() -> Generator[str, None, None]: conversation=chat_req.history, mode="fine", ) - self.logger.info(f"[PLAYGROUND chat parsed_goal]: {parsed_goal}") + self.logger.info(f"[PLAYGROUND CHAT] parsed_goal: {parsed_goal}") if chat_req.beginner_guide_step == "first": chat_req.internet_search = False @@ -512,12 +527,14 @@ def generate_chat_response() -> Generator[str, None, None]: search_tool_memory=False, ) - self.logger.info(f"[PLAYGROUND second search query]: {search_req.query}") + self.logger.info(f"[PLAYGROUND CHAT] second search query: {search_req.query}") start_time = time.time() search_response = self.search_handler.handle_search_memories(search_req) end_time = time.time() - self.logger.info(f"second search time: {end_time - start_time}") + self.logger.info( + f"[PLAYGROUND CHAT] second search time: {end_time - start_time}" + ) # for playground, add the query to memory without response self._start_add_to_memory( @@ -578,13 +595,15 @@ def generate_chat_response() -> Generator[str, None, None]: ] self.logger.info( - f"user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, " + f"[PLAYGROUND CHAT] user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, " f"current_system_prompt: {system_prompt}" ) # Step 3: Generate streaming response from LLM try: model = next(iter(self.chat_llms.keys())) + self.logger.info(f"[PLAYGROUND CHAT] Chat Playground Stream Model: {model}") + start = time.time() response_stream = self.chat_llms[model].generate_stream( current_messages, model_name_or_path=model ) @@ -629,10 +648,19 @@ def generate_chat_response() -> Generator[str, None, None]: chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" yield chunk_data + end = time.time() + self.logger.info( + f"[PLAYGROUND CHAT] Chat Playground Stream Time: {end - start} seconds" + ) + self.logger.info( + f"[PLAYGROUND CHAT] Chat Playground Stream LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Playground Stream LLM Response: {full_response}" + ) + except Exception as llm_error: # Log the error self.logger.error( - f"Error during LLM generation: {llm_error}", exc_info=True + f"[PLAYGROUND CHAT] Error during LLM generation: {llm_error}", + exc_info=True, ) # Send error message to client error_msg = f"模型生成错误: {llm_error!s}" @@ -654,7 +682,7 @@ def generate_chat_response() -> Generator[str, None, None]: # Get further suggestion current_messages.append({"role": "assistant", "content": full_response}) further_suggestion = self._get_further_suggestion(current_messages) - self.logger.info(f"further_suggestion: {further_suggestion}") + self.logger.info(f"[PLAYGROUND CHAT] further_suggestion: {further_suggestion}") yield f"data: {json.dumps({'type': 
'suggestion', 'data': further_suggestion})}\n\n" yield f"data: {json.dumps({'type': 'end'})}\n\n" @@ -685,7 +713,9 @@ def generate_chat_response() -> Generator[str, None, None]: ) except Exception as e: - self.logger.error(f"Error in chat stream: {e}", exc_info=True) + self.logger.error( + f"[PLAYGROUND CHAT] Error in playground chat stream: {e}", exc_info=True + ) error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" yield error_data @@ -705,7 +735,9 @@ def generate_chat_response() -> Generator[str, None, None]: except ValueError as err: raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err except Exception as err: - self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") + self.logger.error( + f"[PLAYGROUND CHAT] Failed to start playground chat stream: {traceback.format_exc()}" + ) raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err def _dedup_and_supplement_memories( diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 7af3afe74..56f8ac195 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -177,7 +177,11 @@ def init_server() -> dict[str, Any]: else None ) llm = LLMFactory.from_config(llm_config) - chat_llms = _init_chat_llms(chat_llm_config) + chat_llms = ( + _init_chat_llms(chat_llm_config) + if os.getenv("ENABLE_CHAT_API", "false") == "true" + else None + ) embedder = EmbedderFactory.from_config(embedder_config) mem_reader = MemReaderFactory.from_config(mem_reader_config) reranker = RerankerFactory.from_config(reranker_config) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index fcb70a64c..37ca361ea 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -15,7 +15,7 @@ import random as _random import socket -from fastapi import APIRouter, Query +from fastapi import APIRouter, HTTPException, Query from memos.api import handlers from memos.api.handlers.add_handler import AddHandler @@ -64,12 +64,16 @@ # Initialize all handlers with dependency injection search_handler = SearchHandler(dependencies) add_handler = AddHandler(dependencies) -chat_handler = ChatHandler( - dependencies, - components["chat_llms"], - search_handler, - add_handler, - online_bot=components.get("online_bot"), +chat_handler = ( + ChatHandler( + dependencies, + components["chat_llms"], + search_handler, + add_handler, + online_bot=components.get("online_bot"), + ) + if os.getenv("ENABLE_CHAT_API", "false") == "true" + else None ) feedback_handler = FeedbackHandler(dependencies) # Extract commonly used components for function-based handlers @@ -201,6 +205,10 @@ def chat_complete(chat_req: APIChatCompleteRequest): This endpoint uses the class-based ChatHandler. """ + if chat_handler is None: + raise HTTPException( + status_code=503, detail="Chat service is not available. Chat handler not initialized." + ) return chat_handler.handle_chat_complete(chat_req) @@ -212,6 +220,10 @@ def chat_stream(chat_req: ChatRequest): This endpoint uses the class-based ChatHandler which internally composes SearchHandler and AddHandler for a clean architecture. """ + if chat_handler is None: + raise HTTPException( + status_code=503, detail="Chat service is not available. Chat handler not initialized." 
+ ) return chat_handler.handle_chat_stream(chat_req) @@ -223,6 +235,10 @@ def chat_stream_playground(chat_req: ChatPlaygroundRequest): This endpoint uses the class-based ChatHandler which internally composes SearchHandler and AddHandler for a clean architecture. """ + if chat_handler is None: + raise HTTPException( + status_code=503, detail="Chat service is not available. Chat handler not initialized." + ) return chat_handler.handle_chat_stream_playground(chat_req) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 3404c6d4c..0c6e5339d 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -70,6 +70,9 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) if not response: + logger.error( + f"[prefer_extractor]: (Error) LLM response content is {response} when extracting explicit preference" + ) return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) @@ -95,6 +98,9 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A try: response = self.llm_provider.generate([{"role": "user", "content": prompt}]) if not response: + logger.error( + f"[prefer_extractor]: (Error) LLM response content is {response} when extracting implicit preference" + ) return None response = response.strip().replace("```json", "").replace("```", "").strip() result = json.loads(response) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index ab3d0ce03..a920f7b0e 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -57,6 +57,7 @@ class SingleCubeView(MemCubeView): feedback_server: Any | None = None deepsearch_agent: Any | None = None + @timed def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: """ This is basically your current handle_add_memories logic, @@ -103,6 +104,7 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: return all_memories + @timed def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: # Create UserContext object user_context = UserContext( @@ -150,6 +152,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: self.logger.info(f"Search {len(memories_result)} memories.") return memories_result + @timed def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]: target_session_id = feedback_req.session_id or "default_session" if feedback_req.async_mode == "async": @@ -554,6 +557,7 @@ def _schedule_memory_tasks( ) self.mem_scheduler.submit_messages(messages=[message_item_add]) + @timed def _process_pref_mem( self, add_req: APIADDRequest, @@ -732,6 +736,7 @@ def add_before_search( self.logger.error(f"[add_before_search] LLM execution error: {e}") return memory_list + @timed def _process_text_mem( self, add_req: APIADDRequest, From de0376cc0e7710dd3eeae499ac8d1c13b3daf9e5 Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Thu, 25 Dec 2025 17:42:12 +0800 Subject: [PATCH 09/48] Feat: add OpenAI log (#785) * feat: timer false * feat: add openai request body log * feat: add openai request body log * feat: add openai request body log * feat: update openapi.json * feat: add timer status error log --------- Co-authored-by: harvey_xiang Co-authored-by: 
CaralHsi --- src/memos/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/memos/utils.py b/src/memos/utils.py index b57967db0..4f2666efd 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -1,5 +1,6 @@ import functools import time +import traceback from memos.log import get_logger @@ -35,6 +36,7 @@ def decorator(fn): def wrapper(*args, **kwargs): start = time.perf_counter() exc_type = None + exc_message = None result = None success_flag = False @@ -44,6 +46,7 @@ def wrapper(*args, **kwargs): return result except Exception as e: exc_type = type(e) + exc_message = traceback.format_exc() success_flag = False if fallback is not None and callable(fallback): @@ -78,7 +81,9 @@ def wrapper(*args, **kwargs): status_info = f", status: {status}" if not success_flag and exc_type is not None: - status_info += f", error: {exc_type.__name__}" + status_info += ( + f", error_type: {exc_type.__name__}, error_message: {exc_message}" + ) msg = ( f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} " From 3e4b34269ecd13b245cbe919c96ebec22dcbbea0 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 25 Dec 2025 20:19:50 +0800 Subject: [PATCH 10/48] Feat/dedup playground display (#789) add dedup to playground tree display Co-authored-by: yuan.wang --- src/memos/api/handlers/memory_handler.py | 13 +++++++++++++ src/memos/api/routers/server_router.py | 2 ++ 2 files changed, 15 insertions(+) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index a33ee9254..5cfa98160 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -23,6 +23,10 @@ remove_embedding_recursive, sort_children_by_memory_type, ) +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + cosine_similarity_matrix, + find_best_unrelated_subgroup, +) if TYPE_CHECKING: @@ -37,6 +41,7 @@ def handle_get_all_memories( mem_cube_id: str, memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"], naive_mem_cube: Any, + embedder: Any, ) -> MemoryResponse: """ Main handler for getting all memories. 
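(A note on the dedup this handler gains below: the tree display embeds each memory, builds a cosine-similarity matrix, and keeps a greedy subgroup of mutually dissimilar items. The following is a minimal standalone sketch of that selection idea, assuming numpy; cosine_sim_matrix and greedy_unrelated_subset are illustrative stand-ins, not the project's cosine_similarity_matrix / find_best_unrelated_subgroup.

import numpy as np

def cosine_sim_matrix(embs):
    # Normalize rows; a single matmul then yields all pairwise cosine similarities.
    unit = embs / np.clip(np.linalg.norm(embs, axis=1, keepdims=True), 1e-12, None)
    return unit @ unit.T

def greedy_unrelated_subset(embs, bar: float = 0.9) -> list[int]:
    # Keep an item only if its similarity to every already-kept item is <= bar.
    sim = cosine_sim_matrix(np.asarray(embs, dtype=float))
    kept: list[int] = []
    for i in range(sim.shape[0]):
        if all(sim[i][j] <= bar for j in kept):
            kept.append(i)
    return kept

# Three near-duplicates collapse to one representative; the distinct item survives.
vecs = [[1.0, 0.0], [0.99, 0.05], [0.98, 0.10], [0.0, 1.0]]
print(greedy_unrelated_subset(vecs, bar=0.9))  # -> [0, 3]

The search-side dedup patch further below applies the same greedy pass but orders candidates by relevance score first and raises the bar to 0.92.)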
@@ -59,6 +64,14 @@ def handle_get_all_memories( # Get all text memories from the graph database memories = naive_mem_cube.text_mem.get_all(user_name=mem_cube_id) + mems = [mem.get("memory", "") for mem in memories.get("nodes", [])] + embeddings = embedder.embed(mems) + similarity_matrix = cosine_similarity_matrix(embeddings) + selected_indices, _ = find_best_unrelated_subgroup( + embeddings, similarity_matrix, bar=0.9 + ) + memories["nodes"] = [memories["nodes"][i] for i in selected_indices] + # Format and convert to tree structure memories_cleaned = remove_embedding_recursive(memories) custom_type_ratios = { diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 37ca361ea..e87e006dd 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -83,6 +83,7 @@ naive_mem_cube = components["naive_mem_cube"] redis_client = components["redis_client"] status_tracker = TaskStatusTracker(redis_client=redis_client) +embedder = components["embedder"] # ============================================================================= @@ -294,6 +295,7 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest): ), memory_type=memory_req.memory_type or "text_mem", naive_mem_cube=naive_mem_cube, + embedder=embedder, ) From 10342ef38a6d16639c3a9b7e3551173fd67c7360 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 25 Dec 2025 21:10:36 +0800 Subject: [PATCH 11/48] add get_user_names_by_memory_ids api (#790) Co-authored-by: yuan.wang --- src/memos/api/product_models.py | 13 ++++++++++++ src/memos/api/routers/server_router.py | 28 ++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index adcb68a96..3c7070ec9 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -1168,3 +1168,16 @@ class AllStatusResponse(BaseResponse[AllStatusResponseData]): """Response model for full scheduler status operations.""" message: str = "Scheduler status summary retrieved successfully" + + +# ─── Internal API Endpoints Models (for internal use) ─────────────────────────────────────────────────── + + +class GetUserNamesByMemoryIdsRequest(BaseRequest): + """Request model for getting user names by memory ids.""" + + memory_ids: list[str] = Field(..., description="Memory IDs") + + +class GetUserNamesByMemoryIdsResponse(BaseResponse[dict[str, list[str]]]): + """Response model for getting user names by memory ids.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index e87e006dd..07c42bbb2 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -36,6 +36,8 @@ GetMemoryPlaygroundRequest, GetMemoryRequest, GetMemoryResponse, + GetUserNamesByMemoryIdsRequest, + GetUserNamesByMemoryIdsResponse, MemoryResponse, SearchResponse, StatusResponse, @@ -43,6 +45,7 @@ SuggestionResponse, TaskQueueResponse, ) +from memos.graph_dbs.polardb import PolarDBGraphDB from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -84,6 +87,7 @@ redis_client = components["redis_client"] status_tracker = TaskStatusTracker(redis_client=redis_client) embedder = components["embedder"] +graph_db = components["graph_db"] # ============================================================================= @@ -329,3 +333,27 @@ def 
feedback_memories(feedback_req: APIFeedbackRequest): This endpoint uses the class-based FeedbackHandler for better code organization. """ return feedback_handler.handle_feedback_memories(feedback_req) + + +# ============================================================================= +# Other API Endpoints (for internal use) +# ============================================================================= + + +@router.get( + "/get_user_names_by_memory_ids", + summary="Get user names by memory ids", + response_model=GetUserNamesByMemoryIdsResponse, +) +def get_user_names_by_memory_ids(memory_ids: GetUserNamesByMemoryIdsRequest): + """Get user names by memory ids.""" + if not isinstance(graph_db, PolarDBGraphDB): + raise HTTPException( + status_code=400, + detail=( + "graph_db must be an instance of PolarDBGraphDB to use " + "get_user_names_by_memory_ids" + f"; current graph_db is: {graph_db.__class__.__name__}" + ), + ) + return graph_db.get_user_names_by_memory_ids(memory_ids=memory_ids) From fac1aa76422cc1455f30b260a7246d5d34a69202 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Fri, 26 Dec 2025 10:28:06 +0800 Subject: [PATCH 12/48] feat: add batch delete (#787) --- src/memos/graph_dbs/polardb.py | 219 +++++++++++++++++++++------------ 1 file changed, 137 insertions(+), 82 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 1d19dc98d..b29dd26ce 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4869,6 +4869,7 @@ def delete_node_by_prams( memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, + batch_size: int = 100, ) -> int: """ Delete nodes by memory_ids, file_ids, or filter. @@ -4898,31 +4899,6 @@ f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype" ) - # Build WHERE conditions separately for memory_ids and file_ids - where_conditions = [] - - # Handle memory_ids: query properties.id - if memory_ids and len(memory_ids) > 0: - memory_id_conditions = [] - for node_id in memory_ids: - memory_id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - if memory_id_conditions: - where_conditions.append(f"({' OR '.join(memory_id_conditions)})") - - # Check if any file_id is in the file_ids array field (OR relationship) - if file_ids and len(file_ids) > 0: - file_id_conditions = [] - for file_id in file_ids: - # Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype) - file_id_conditions.append( - f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" - ) - if file_id_conditions: - # Use OR to match any file_id in the array - where_conditions.append(f"({' OR '.join(file_id_conditions)})") - # Query nodes by filter if provided filter_ids = set() if filter: @@ -4943,77 +4919,156 @@ "[delete_node_by_prams] Filter parsed to None, skipping filter query" ) - # If filter returned IDs, add condition for them + # Combine all IDs that need to be deleted + all_memory_ids = set() + if memory_ids: + all_memory_ids.update(memory_ids) if filter_ids: - filter_id_conditions = [] - for node_id in filter_ids: - filter_id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - if
filter_id_conditions: - where_conditions.append(f"({' OR '.join(filter_id_conditions)})") + all_memory_ids.update(filter_ids) - # If no conditions (except user_name), return 0 - if not where_conditions: + # If no conditions to delete, return 0 + if not all_memory_ids and not file_ids: logger.warning( "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" ) return 0 - # Build WHERE clause - # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) - data_conditions = " OR ".join([f"({cond})" for cond in where_conditions]) + conn = None + total_deleted_count = 0 + try: + conn = self._get_connection() + with conn.cursor() as cursor: + # Process memory_ids and filter_ids in batches + if all_memory_ids: + memory_ids_list = list(all_memory_ids) + total_batches = (len(memory_ids_list) + batch_size - 1) // batch_size + logger.info( + f"[delete_node_by_prams] memoryids Processing {len(memory_ids_list)} memory_ids in {total_batches} batches (batch_size={batch_size})" + ) - # Build final WHERE clause - # If user_name_conditions exist, combine with data_conditions using AND - # Otherwise, use only data_conditions - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_clause = f"({user_name_where}) AND ({data_conditions})" - else: - where_clause = f"({data_conditions})" + for batch_idx in range(total_batches): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, len(memory_ids_list)) + batch_ids = memory_ids_list[batch_start:batch_end] - # Use SQL DELETE query for better performance - # First count matching nodes to get accurate count - count_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[delete_node_by_prams] count_query: {count_query}") + # Build conditions for this batch + batch_conditions = [] + for node_id in batch_ids: + batch_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" + ) + batch_where = f"({' OR '.join(batch_conditions)})" - # Then delete nodes - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_clause = f"({user_name_where}) AND ({batch_where})" + else: + where_clause = batch_where - logger.info( - f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" - ) - logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") + # Count before deletion + count_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info( + f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: count_query: {count_query}" + ) - conn = None - deleted_count = 0 - try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Count nodes before deletion - cursor.execute(count_query) - count_result = cursor.fetchone() - expected_count = count_result[0] if count_result else 0 + cursor.execute(count_query) + count_result = cursor.fetchone() + expected_count = count_result[0] if count_result else 0 - logger.info( - f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria" - ) + if expected_count == 0: + logger.info( + f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: No nodes found, skipping" + ) + continue 
+ + # Delete batch + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info( + f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: delete_query: {delete_query}" + ) + + logger.info( + f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Executing delete query for {len(batch_ids)} nodes" + ) + cursor.execute(delete_query) + batch_deleted = cursor.rowcount + total_deleted_count += batch_deleted + + logger.info( + f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Deleted {batch_deleted} nodes (batch size: {len(batch_ids)})" + ) + + # Process file_ids in batches + if file_ids: + total_file_batches = (len(file_ids) + batch_size - 1) // batch_size + logger.info( + f"[delete_node_by_prams] Processing {len(file_ids)} file_ids in {total_file_batches} batches (batch_size={batch_size})" + ) + + for batch_idx in range(total_file_batches): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, len(file_ids)) + batch_file_ids = file_ids[batch_start:batch_end] + + # Build conditions for this batch + batch_conditions = [] + for file_id in batch_file_ids: + batch_conditions.append( + f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" + ) + batch_where = f"({' OR '.join(batch_conditions)})" + + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_clause = f"({user_name_where}) AND ({batch_where})" + else: + where_clause = batch_where + + # Count before deletion + count_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + + logger.info( + f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: count_query: {count_query}" + ) + cursor.execute(count_query) + count_result = cursor.fetchone() + expected_count = count_result[0] if count_result else 0 + + if expected_count == 0: + logger.info( + f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: No nodes found, skipping" + ) + continue + + # Delete batch + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + cursor.execute(delete_query) + batch_deleted = cursor.rowcount + total_deleted_count += batch_deleted + + logger.info( + f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: delete_query: {delete_query}" + ) - # Delete nodes - cursor.execute(delete_query) - # Use rowcount to get actual deleted count - deleted_count = cursor.rowcount elapsed_time = time.time() - batch_start_time logger.info( - f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes" + f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes" ) except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) @@ -5021,8 +5076,8 @@ def delete_node_by_prams( finally: self._return_connection(conn) - logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes") - return deleted_count + logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") + return total_deleted_count @timed def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]: From 336a2becacc420a4a2ea8c338233e642e8884e14 Mon Sep 17 00:00:00 2001 From: Zehao Lin Date: 
Fri, 26 Dec 2025 14:43:00 +0800 Subject: [PATCH 13/48] feat: add dedup search param (#788) * Add dedup option to search pipeline * Fix dedup handling in simple search * feat: optimize memory search deduplication and fix parsing bugs - Tune similarity threshold to 0.92 for 'dedup=sim' to preserve subtle semantic nuances. - Implement recall expansion (5x Top-K) when deduplicating to ensure output diversity. - Remove aggressive filling logic to strictly enforce the similarity threshold. - Fix attribute error in MultiModalStructMemReader by correctly importing parse_json_result. - Replace fragile eval() with robust parse_json_result in TaskGoalParser to handle JSON booleans. --------- Co-authored-by: glin1993@outlook.com <> --- src/memos/api/handlers/formatters_handler.py | 5 +- src/memos/api/handlers/search_handler.py | 102 ++++++++++++++++++ src/memos/api/product_models.py | 9 ++ src/memos/mem_reader/multi_modal_struct.py | 3 +- .../mem_scheduler/optimized_scheduler.py | 11 +- src/memos/mem_scheduler/utils/api_utils.py | 5 +- src/memos/memories/textual/tree.py | 2 + .../retrieve/retrieve_utils.py | 1 - .../tree_text_memory/retrieve/searcher.py | 51 ++------- .../retrieve/task_goal_parser.py | 19 ++-- src/memos/multi_mem_cube/single_cube.py | 26 ++++- 11 files changed, 171 insertions(+), 63 deletions(-) diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 88875cacc..94988295b 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -29,7 +29,7 @@ def to_iter(running: Any) -> list[Any]: return list(running) if running else [] -def format_memory_item(memory_data: Any) -> dict[str, Any]: +def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]: """ Format a single memory item for API response. @@ -47,7 +47,8 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]: ref_id = f"[{memory_id.split('-')[0]}]" memory["ref_id"] = ref_id - memory["metadata"]["embedding"] = [] + if not include_embedding: + memory["metadata"]["embedding"] = [] memory["metadata"]["sources"] = [] memory["metadata"]["usage"] = [] memory["metadata"]["ref_id"] = ref_id diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index f7d6ee2c8..3774410dc 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -5,9 +5,14 @@ using dependency injection for better modularity and testability. 
""" +from typing import Any + from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + cosine_similarity_matrix, +) from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView @@ -50,9 +55,19 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse """ self.logger.info(f"[SearchHandler] Search Req is: {search_req}") + # Increase recall pool if deduplication is enabled to ensure diversity + original_top_k = search_req.top_k + if search_req.dedup == "sim": + search_req.top_k = original_top_k * 5 + cube_view = self._build_cube_view(search_req) results = cube_view.search_memories(search_req) + if search_req.dedup == "sim": + results = self._dedup_text_memories(results, original_top_k) + self._strip_embeddings(results) + # Restore original top_k for downstream logic or response metadata + search_req.top_k = original_top_k self.logger.info( f"[SearchHandler] Final search results: count={len(results)} results={results}" @@ -63,6 +78,93 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse data=results, ) + def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]: + buckets = results.get("text_mem", []) + if not buckets: + return results + + flat: list[tuple[int, dict[str, Any], float]] = [] + for bucket_idx, bucket in enumerate(buckets): + for mem in bucket.get("memories", []): + score = mem.get("metadata", {}).get("relativity", 0.0) + flat.append((bucket_idx, mem, score)) + + if len(flat) <= 1: + return results + + embeddings = self._extract_embeddings([mem for _, mem, _ in flat]) + if embeddings is None: + documents = [mem.get("memory", "") for _, mem, _ in flat] + embeddings = self.searcher.embedder.embed(documents) + + similarity_matrix = cosine_similarity_matrix(embeddings) + + indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))} + for flat_index, (bucket_idx, _, _) in enumerate(flat): + indices_by_bucket[bucket_idx].append(flat_index) + + selected_global: list[int] = [] + selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))} + + ordered_indices = sorted(range(len(flat)), key=lambda idx: flat[idx][2], reverse=True) + for idx in ordered_indices: + bucket_idx = flat[idx][0] + if len(selected_by_bucket[bucket_idx]) >= target_top_k: + continue + # Use 0.92 threshold strictly + if self._is_unrelated(idx, selected_global, similarity_matrix, 0.92): + selected_by_bucket[bucket_idx].append(idx) + selected_global.append(idx) + + # Removed the 'filling' logic that was pulling back similar items. + # Now it will only return items that truly pass the 0.92 threshold, + # up to target_top_k. 
+ + for bucket_idx, bucket in enumerate(buckets): + selected_indices = selected_by_bucket.get(bucket_idx, []) + bucket["memories"] = [flat[i][1] for i in selected_indices] + return results + + @staticmethod + def _is_unrelated( + index: int, + selected_indices: list[int], + similarity_matrix: list[list[float]], + similarity_threshold: float, + ) -> bool: + return all(similarity_matrix[index][j] <= similarity_threshold for j in selected_indices) + + @staticmethod + def _max_similarity( + index: int, selected_indices: list[int], similarity_matrix: list[list[float]] + ) -> float: + if not selected_indices: + return 0.0 + return max(similarity_matrix[index][j] for j in selected_indices) + + @staticmethod + def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None: + embeddings: list[list[float]] = [] + for mem in memories: + embedding = mem.get("metadata", {}).get("embedding") + if not embedding: + return None + embeddings.append(embedding) + return embeddings + + @staticmethod + def _strip_embeddings(results: dict[str, Any]) -> None: + for bucket in results.get("text_mem", []): + for mem in bucket.get("memories", []): + metadata = mem.get("metadata", {}) + if "embedding" in metadata: + metadata["embedding"] = [] + for bucket in results.get("tool_mem", []): + for mem in bucket.get("memories", []): + metadata = mem.get("metadata", {}) + if "embedding" in metadata: + metadata["embedding"] = [] + def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]: """ Normalize target cube ids from search_req. diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 3c7070ec9..120da8b55 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -319,6 +319,15 @@ class APISearchRequest(BaseRequest): description="Number of textual memories to retrieve (top-K). Default: 10.", ) + dedup: Literal["no", "sim"] | None = Field( + None, + description=( + "Optional dedup option for textual memories. " + "Use 'no' for no dedup, 'sim' for similarity dedup. " + "If None, default exact-text dedup is applied." 
+ ), + ) + pref_top_k: int = Field( 6, ge=0, diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 48be9b72c..2ed1af53e 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -10,6 +10,7 @@ from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang from memos.mem_reader.read_multi_modal.base import _derive_key from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader +from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH from memos.types import MessagesType @@ -377,7 +378,7 @@ def _get_llm_response( messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) - response_json = self.parse_json_result(response_text) + response_json = parse_json_result(response_text) except Exception as e: logger.error(f"[LLM] Exception during chat generation: {e}") response_json = { diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index c3f5891ae..7007f8418 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -186,10 +186,14 @@ def mix_search_memories( info=info, search_tool_memory=search_req.search_tool_memory, tool_mem_top_k=search_req.tool_mem_top_k, + dedup=search_req.dedup, ) memories = merged_memories[: search_req.top_k] - formatted_memories = [format_textual_memory_item(item) for item in memories] + formatted_memories = [ + format_textual_memory_item(item, include_embedding=search_req.dedup == "sim") + for item in memories + ] self.submit_memory_history_async_task( search_req=search_req, user_context=user_context, @@ -233,7 +237,10 @@ def update_search_memories_to_redis( mem_cube=self.mem_cube, mode=SearchMode.FAST, ) - formatted_memories = [format_textual_memory_item(data) for data in memories] + formatted_memories = [ + format_textual_memory_item(data, include_embedding=search_req.dedup == "sim") + for data in memories + ] else: memories = [ TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"] diff --git a/src/memos/mem_scheduler/utils/api_utils.py b/src/memos/mem_scheduler/utils/api_utils.py index c8d096517..3833b5926 100644 --- a/src/memos/mem_scheduler/utils/api_utils.py +++ b/src/memos/mem_scheduler/utils/api_utils.py @@ -6,14 +6,15 @@ from memos.memories.textual.tree import TextualMemoryItem -def format_textual_memory_item(memory_data: Any) -> dict[str, Any]: +def format_textual_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]: """Format a single memory item for API response.""" memory = memory_data.model_dump() memory_id = memory["id"] ref_id = f"[{memory_id.split('-')[0]}]" memory["ref_id"] = ref_id - memory["metadata"]["embedding"] = [] + if not include_embedding: + memory["metadata"]["embedding"] = [] memory["metadata"]["sources"] = [] memory["metadata"]["ref_id"] = ref_id memory["metadata"]["id"] = memory_id diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 22545496a..fb33a2d03 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -161,6 +161,7 @@ def search( user_name: str | None = None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + dedup: str | None = None, 
**kwargs, ) -> list[TextualMemoryItem]: """Search for memories based on a query. @@ -207,6 +208,7 @@ def search( user_name=user_name, search_tool_memory=search_tool_memory, tool_mem_top_k=tool_mem_top_k, + dedup=dedup, **kwargs, ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index d9398a22c..5a82883c8 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any - import numpy as np from memos.dependency import require_python_package diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index dc47dd4d7..f3d6ba037 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -119,9 +119,13 @@ def post_retrieve( info=None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + dedup: str | None = None, plugin=False, ): - deduped = self._deduplicate_results(retrieved_results) + if dedup == "no": + deduped = retrieved_results + else: + deduped = self._deduplicate_results(retrieved_results) final_results = self._sort_and_trim( deduped, top_k, plugin, search_tool_memory, tool_mem_top_k ) @@ -141,6 +145,7 @@ def search( user_name: str | None = None, search_tool_memory: bool = False, tool_mem_top_k: int = 6, + dedup: str | None = None, **kwargs, ) -> list[TextualMemoryItem]: """ @@ -202,6 +207,7 @@ def search( plugin=kwargs.get("plugin", False), search_tool_memory=search_tool_memory, tool_mem_top_k=tool_mem_top_k, + dedup=dedup, ) logger.info(f"[SEARCH] Done. 
Total {len(final_results)} results.") @@ -284,49 +290,6 @@ def _parse_task( return parsed_goal, query_embedding, context, query - @timed - def _retrieve_simple( - self, - query: str, - top_k: int, - search_filter: dict | None = None, - user_name: str | None = None, - **kwargs, - ): - """Retrieve from by keywords and embedding""" - query_words = [] - if self.tokenizer: - query_words = self.tokenizer.tokenize_mixed(query) - else: - query_words = query.strip().split() - query_words = [query, *query_words] - logger.info(f"[SIMPLESEARCH] Query words: {query_words}") - query_embeddings = self.embedder.embed(query_words) - - items = self.graph_retriever.retrieve_from_mixed( - top_k=top_k * 2, - memory_scope=None, - query_embedding=query_embeddings, - search_filter=search_filter, - user_name=user_name, - use_fast_graph=self.use_fast_graph, - ) - logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") - documents = [getattr(item, "memory", "") for item in items] - documents_embeddings = self.embedder.embed(documents) - similarity_matrix = cosine_similarity_matrix(documents_embeddings) - selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) - selected_items = [items[i] for i in selected_indices] - logger.info( - f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}" - ) - return self.reranker.rerank( - query=query, - query_embedding=query_embeddings[0], - graph_results=selected_items, - top_k=top_k, - ) - @timed def _retrieve_paths( self, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index e1ce859bf..f4d6c4847 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -5,7 +5,10 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + FastTokenizer, + parse_json_result, +) from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT @@ -111,8 +114,10 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: for attempt_times in range(attempts): try: context = kwargs.get("context", "") - response = response.replace("```", "").replace("json", "").strip() - response_json = eval(response) + response_json = parse_json_result(response) + if not response_json: + raise ValueError("Parsed JSON is empty") + return ParsedTaskGoal( memories=response_json.get("memories", []), keys=response_json.get("keys", []), @@ -123,6 +128,8 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: context=context, ) except Exception as e: - raise ValueError( - f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}" - ) from e + if attempt_times == attempts - 1: + raise ValueError( + f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts}" + ) from e + continue diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index a920f7b0e..6c3cc0cc7 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -264,7 +264,10 @@ def _deep_search( 
search_filter=search_filter, info=info, ) - formatted_memories = [format_memory_item(data) for data in enhanced_memories] + formatted_memories = [ + format_memory_item(data, include_embedding=search_req.dedup == "sim") + for data in enhanced_memories + ] return formatted_memories def _agentic_search( @@ -273,7 +276,10 @@ def _agentic_search( deepsearch_results = self.deepsearch_agent.run( search_req.query, user_id=user_context.mem_cube_id ) - formatted_memories = [format_memory_item(data) for data in deepsearch_results] + formatted_memories = [ + format_memory_item(data, include_embedding=search_req.dedup == "sim") + for data in deepsearch_results + ] return formatted_memories def _fine_search( @@ -328,6 +334,7 @@ def _fine_search( top_k=search_req.top_k, user_name=user_context.mem_cube_id, info=info, + dedup=search_req.dedup, ) # Enhance with query @@ -378,8 +385,13 @@ def _dedup_by_content(memories: list) -> list: unique_memories.append(mem) return unique_memories - deduped_memories = _dedup_by_content(enhanced_memories) - formatted_memories = [format_memory_item(data) for data in deduped_memories] + deduped_memories = ( + enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories) + ) + formatted_memories = [ + format_memory_item(data, include_embedding=search_req.dedup == "sim") + for data in deduped_memories + ] logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}") @@ -463,9 +475,13 @@ def _fast_search( plugin=plugin, search_tool_memory=search_req.search_tool_memory, tool_mem_top_k=search_req.tool_mem_top_k, + dedup=search_req.dedup, ) - formatted_memories = [format_memory_item(data) for data in search_results] + formatted_memories = [ + format_memory_item(data, include_embedding=search_req.dedup == "sim") + for data in search_results + ] return formatted_memories From 748ef3db035c8d905f006c5db3aa1648d2d99999 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Fri, 26 Dec 2025 16:11:51 +0800 Subject: [PATCH 14/48] Dev zdy 1226 page (#796) * feat: add export_graph page for polardb.py * feat: add export_graph page for neo4j.py * feat: add get_user_names_by_memory_ids * feat: add delete_node_by_prams --- src/memos/graph_dbs/neo4j.py | 140 ++++++++++++++++++++++++++++----- src/memos/graph_dbs/polardb.py | 41 +++++++--- 2 files changed, 151 insertions(+), 30 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index debbb4e3c..2b3859252 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1132,10 +1132,21 @@ def clear(self, user_name: str | None = None) -> None: logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}") raise - def export_graph(self, **kwargs) -> dict[str, Any]: + def export_graph( + self, + page: int | None = None, + page_size: int | None = None, + **kwargs, + ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. + Args: + page (int, optional): Page number (starts from 1). If None, exports all data without pagination. + page_size (int, optional): Number of items per page. If None, exports all data without pagination. + **kwargs: Additional keyword arguments, including: + - user_name (str, optional): User name for filtering in non-multi-db mode + Returns: { "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... 
], @@ -1143,6 +1154,18 @@ def export_graph(self, **kwargs) -> dict[str, Any]: } """ user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name + + # Determine if pagination is needed + use_pagination = page is not None and page_size is not None + + # Validate pagination parameters if pagination is enabled + if use_pagination: + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + skip = (page - 1) * page_size + with self.driver.session(database=self.db_name) as session: # Export nodes node_query = "MATCH (n:Memory)" @@ -1154,13 +1177,23 @@ def export_graph(self, **kwargs) -> dict[str, Any]: edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name" params["user_name"] = user_name - node_result = session.run(f"{node_query} RETURN n", params) + # Add ORDER BY and pagination for nodes + node_query += " RETURN n ORDER BY n.id" + if use_pagination: + node_query += f" SKIP {skip} LIMIT {page_size}" + + node_result = session.run(node_query, params) nodes = [self._parse_node(dict(record["n"])) for record in node_result] # Export edges - edge_result = session.run( - f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) AS type", params + # Add ORDER BY and pagination for edges + edge_query += ( + " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.id, b.id" ) + if use_pagination: + edge_query += f" SKIP {skip} LIMIT {page_size}" + + edge_result = session.run(edge_query, params) edges = [ {"source": record["source"], "target": record["target"], "type": record["type"]} for record in edge_result @@ -1646,7 +1679,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: def delete_node_by_prams( self, - writable_cube_ids: list[str], + writable_cube_ids: list[str] | None = None, memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, @@ -1655,7 +1688,8 @@ def delete_node_by_prams( Delete nodes by memory_ids, file_ids, or filter. Args: - writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter. + writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes. + If not provided, no user_name filter will be applied. memory_ids (list[str], optional): List of memory node IDs to delete. file_ids (list[str], optional): List of file node IDs to delete. filter (dict, optional): Filter dictionary to query matching nodes for deletion. 
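Taken together, the two Neo4j changes in this patch are additive: export_graph gains optional page/page_size arguments (both must be set for the SKIP/LIMIT clauses to apply; leaving either as None keeps the old export-everything behavior), and delete_node_by_prams no longer requires writable_cube_ids. A minimal usage sketch, assuming graph_store is an already-initialized Neo4j graph store instance:

    # Page through the graph 100 items at a time; note that nodes and
    # edges are paginated independently with the same SKIP/LIMIT window.
    page = 1
    while True:
        batch = graph_store.export_graph(
            page=page, page_size=100, user_name="cube_001"
        )
        if not batch["nodes"] and not batch["edges"]:
            break
        handle_batch(batch)  # hypothetical consumer
        page += 1

    # Full export is unchanged when the pagination arguments are omitted.
    everything = graph_store.export_graph(user_name="cube_001")

    # Deletion scoped to specific cubes, as before:
    graph_store.delete_node_by_prams(
        writable_cube_ids=["cube_001"], memory_ids=["mem-id-1"]
    )
    # New in this patch: omitting writable_cube_ids skips the user_name
    # filter entirely, so the ids alone decide what is removed.
    graph_store.delete_node_by_prams(memory_ids=["mem-id-1"])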
@@ -1670,20 +1704,18 @@ def delete_node_by_prams( f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" ) - # Validate writable_cube_ids - if not writable_cube_ids or len(writable_cube_ids) == 0: - raise ValueError("writable_cube_ids is required and cannot be empty") - # Build WHERE conditions separately for memory_ids and file_ids where_clauses = [] params = {} # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) + # Only add user_name filter if writable_cube_ids is provided user_name_conditions = [] - for idx, cube_id in enumerate(writable_cube_ids): - param_name = f"cube_id_{idx}" - user_name_conditions.append(f"n.user_name = ${param_name}") - params[param_name] = cube_id + if writable_cube_ids and len(writable_cube_ids) > 0: + for idx, cube_id in enumerate(writable_cube_ids): + param_name = f"cube_id_{idx}" + user_name_conditions.append(f"n.user_name = ${param_name}") + params[param_name] = cube_id # Handle memory_ids: query n.id if memory_ids and len(memory_ids) > 0: @@ -1711,7 +1743,7 @@ def delete_node_by_prams( filters=[], user_name=None, filter=filter, - knowledgebase_ids=writable_cube_ids, + knowledgebase_ids=writable_cube_ids if writable_cube_ids else None, ) # If filter returned IDs, add condition for them @@ -1730,9 +1762,14 @@ def delete_node_by_prams( # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) data_conditions = " OR ".join([f"({clause})" for clause in where_clauses]) - # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions) - user_name_where = " OR ".join(user_name_conditions) - ids_where = f"({user_name_where}) AND ({data_conditions})" + # Build final WHERE clause + # If user_name_conditions exist, combine with data_conditions using AND + # Otherwise, use only data_conditions + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + ids_where = f"({user_name_where}) AND ({data_conditions})" + else: + ids_where = f"({data_conditions})" logger.info( f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" @@ -1773,3 +1810,70 @@ def delete_node_by_prams( logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes") return deleted_count + + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]: + """Get user names by memory ids. + + Args: + memory_ids: List of memory node IDs to query. 
+ + Returns: + dict[str, list[str]]: Dictionary with one key: + - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing) + - 'exist_user_names': List of distinct user names (if all memory_ids exist) + """ + if not memory_ids: + return {"exist_user_names": []} + + logger.info(f"[get_user_names_by_memory_ids] Checking {len(memory_ids)} memory_ids") + + try: + with self.driver.session(database=self.db_name) as session: + # Query to check which memory_ids exist + check_query = """ + MATCH (n:Memory) + WHERE n.id IN $memory_ids + RETURN n.id AS id + """ + + check_result = session.run(check_query, memory_ids=memory_ids) + existing_ids = set() + for record in check_result: + node_id = record["id"] + existing_ids.add(node_id) + + # Check if any memory_ids are missing + no_exist_list = [mid for mid in memory_ids if mid not in existing_ids] + + # If any memory_ids are missing, return no_exist_memory_ids + if no_exist_list: + logger.info( + f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}" + ) + return {"no_exist_memory_ids": no_exist_list} + + # All memory_ids exist, query user_names + user_names_query = """ + MATCH (n:Memory) + WHERE n.id IN $memory_ids + RETURN DISTINCT n.user_name AS user_name + """ + logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}") + + user_names_result = session.run(user_names_query, memory_ids=memory_ids) + user_names = [] + for record in user_names_result: + user_name = record["user_name"] + if user_name: + user_names.append(user_name) + + logger.info( + f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names" + ) + + return {"exist_user_names": user_names} + except Exception as e: + logger.error( + f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index b29dd26ce..fcb7e0caa 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2505,16 +2505,16 @@ def export_graph( self, include_embedding: bool = False, user_name: str | None = None, - page: int = 1, - page_size: int = 10, + page: int | None = None, + page_size: int | None = None, ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. Args: include_embedding (bool): Whether to include the large embedding field. user_name (str, optional): User name for filtering in non-multi-db mode - page (int): Page number (starts from 1). Default is 1. - page_size (int): Number of items per page. Default is 1000. + page (int, optional): Page number (starts from 1). If None, exports all data without pagination. + page_size (int, optional): Number of items per page. If None, exports all data without pagination. 
Returns: { @@ -2527,23 +2527,35 @@ def export_graph( ) user_name = user_name if user_name else self._get_config_value("user_name") - # Validate pagination parameters - if page < 1: - page = 1 - if page_size < 1: - page_size = 10 + # Determine if pagination is needed + use_pagination = page is not None and page_size is not None + + # Validate pagination parameters if pagination is enabled + if use_pagination: + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + offset = (page - 1) * page_size + else: + offset = None conn = None try: conn = self._get_connection() # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + if include_embedding: node_query = f""" SELECT id, properties, embedding FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype ORDER BY id - LIMIT {page_size} OFFSET {(page - 1) * page_size} + {pagination_clause} """ else: node_query = f""" @@ -2551,7 +2563,7 @@ def export_graph( FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype ORDER BY id - LIMIT {page_size} OFFSET {(page - 1) * page_size} + {pagination_clause} """ logger.info(f"[export_graph nodes] Query: {node_query}") with conn.cursor() as cursor: @@ -2601,6 +2613,11 @@ def export_graph( conn = self._get_connection() # Export edges using cypher query # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery + # Build pagination clause if needed + edge_pagination_clause = "" + if use_pagination: + edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + edge_query = f""" SELECT source, target, edge FROM ( SELECT * FROM cypher('{self.db_name}_graph', $$ @@ -2610,7 +2627,7 @@ def export_graph( ORDER BY a.id, b.id $$) AS (source agtype, target agtype, edge agtype) ) AS edges - LIMIT {page_size} OFFSET {(page - 1) * page_size} + {edge_pagination_clause} """ logger.info(f"[export_graph edges] Query: {edge_query}") with conn.cursor() as cursor: From 21df1c76b0eeed2850dae82fe00fbdd783570ad0 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Mon, 29 Dec 2025 15:20:42 +0800 Subject: [PATCH 15/48] Patch: get_memory adds the page size parameter function and the filtering of user id (#801) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords * fix messages queue * add seach_by_keywords_LIKE * add doc filter * add retrieve query * add retrieve queies * patch info filter * add log and make embedding safety net * add log and make embedding safety net * deduplicate add objects * use 
_add_memories_parallel * delete Special characters * delete Special characters * delete Special characters * delete Special characters * add source_doc_id * add source_doc_id * add reranker in init com.. * fix circle import * add feedback judgement * add feedback judgement * add pref feedback * add pref feedback * patch: get_memory func filter user id and make page chunk --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/handlers/memory_handler.py | 7 ++++++- src/memos/api/product_models.py | 7 +++++++ src/memos/graph_dbs/nebular.py | 2 +- src/memos/graph_dbs/polardb.py | 1 + src/memos/memories/textual/tree.py | 12 ++++++++++-- 5 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 5cfa98160..2a99d912c 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -180,7 +180,12 @@ def handle_get_memories( get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube ) -> GetMemoryResponse: # TODO: Implement get memory with filter - memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"] + memories = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + )["nodes"] preferences: list[TextualMemoryItem] = [] if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: filter_params: dict[str, Any] = {} diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 120da8b55..25e0d809d 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -772,6 +772,13 @@ class GetMemoryRequest(BaseRequest): mem_cube_id: str = Field(..., description="Cube ID") user_id: str | None = Field(None, description="User ID") include_preference: bool = Field(True, description="Whether to handle preference memory") + page: int | None = Field( + None, + description="Page number (starts from 1). If None, exports all data without pagination.", + ) + page_size: int | None = Field( + None, description="Number of items per page. If None, exports all data without pagination." + ) class DeleteMemoryRequest(BaseRequest): diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 89b58f417..428d6d09e 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -1207,7 +1207,7 @@ def clear(self, user_name: str | None = None) -> None: @timed def export_graph( - self, include_embedding: bool = False, user_name: str | None = None + self, include_embedding: bool = False, user_name: str | None = None, **kwargs ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index fcb7e0caa..4799542bf 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2507,6 +2507,7 @@ def export_graph( user_name: str | None = None, page: int | None = None, page_size: int | None = None, + **kwargs, ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. 
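End to end, the new fields flow from GetMemoryRequest through handle_get_memories into the tree memory's get_all and finally the store's export_graph. A rough sketch of a paginated call, assuming naive_mem_cube is an initialized NaiveMemCube and skipping the route wiring:

    from memos.api.handlers.memory_handler import handle_get_memories
    from memos.api.product_models import GetMemoryRequest

    req = GetMemoryRequest(
        mem_cube_id="cube_001",
        user_id="user_123",  # optional: restrict nodes to one user
        page=1,              # optional: 1-based page number
        page_size=50,        # optional: items per page
    )
    # With page or page_size left as None, get_all() performs a full,
    # unpaginated export, so existing callers keep their old behavior.
    resp = handle_get_memories(req, naive_mem_cube)
    memories = resp.data["text_mem"][0]["memories"]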
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index fb33a2d03..764ceee67 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -321,12 +321,20 @@ def get_by_ids( ) -> list[TextualMemoryItem]: raise NotImplementedError - def get_all(self, user_name: str | None = None) -> dict: + def get_all( + self, + user_name: str, + user_id: str | None = None, + page: int | None = None, + page_size: int | None = None, + ) -> dict: """Get all memories. Returns: list[TextualMemoryItem]: List of all memories. """ - all_items = self.graph_store.export_graph(user_name=user_name) + all_items = self.graph_store.export_graph( + user_name=user_name, user_id=user_id, page=page, page_size=page_size + ) return all_items def delete(self, memory_ids: list[str], user_name: str | None = None) -> None: From 99dcf1dca6762d7384e9121bf1dcc5eaa1f42c53 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Mon, 29 Dec 2025 16:56:49 +0800 Subject: [PATCH 16/48] Dev zdy 1229 (#802) * feat: add get_user_names_by_memory_ids log && add batch * feat: fix delete_node_by_prams * feat: add export_graph user_id --- src/memos/graph_dbs/polardb.py | 342 +++++++++++++++------------------ 1 file changed, 155 insertions(+), 187 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 4799542bf..f88824493 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2505,6 +2505,7 @@ def export_graph( self, include_embedding: bool = False, user_name: str | None = None, + user_id: str | None = None, page: int | None = None, page_size: int | None = None, **kwargs, @@ -2514,6 +2515,7 @@ def export_graph( Args: include_embedding (bool): Whether to include the large embedding field. user_name (str, optional): User name for filtering in non-multi-db mode + user_id (str, optional): User ID for filtering page (int, optional): Page number (starts from 1). If None, exports all data without pagination. page_size (int, optional): Number of items per page. If None, exports all data without pagination. 
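The new user_id argument is ANDed with the existing user_name predicate rather than replacing it, and it falls back to the configured user_id when omitted. A rough sketch of the resulting call surface, with polardb_store standing in for an initialized PolarDB graph store:

    # Cube-level and user-level filters combine: only Memory nodes whose
    # properties match BOTH user_name and user_id survive the WHERE clause.
    result = polardb_store.export_graph(
        user_name="cube_001",
        user_id="user_123",
        page=1,
        page_size=100,
    )

    # Omitting user_id falls back to self._get_config_value("user_id"),
    # so single-user deployments can keep calling it the old way.
    result = polardb_store.export_graph(user_name="cube_001")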
@@ -2524,9 +2526,9 @@ def export_graph( } """ logger.info( - f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, page: {page}, page_size: {page_size}" + f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}" ) - user_name = user_name if user_name else self._get_config_value("user_name") + user_id = user_id if user_id else self._get_config_value("user_id") # Determine if pagination is needed use_pagination = page is not None and page_size is not None @@ -2550,11 +2552,26 @@ def export_graph( if use_pagination: pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + # Build WHERE conditions + where_conditions = [] + if user_name: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + if user_id: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" + ) + + where_clause = "" + if where_conditions: + where_clause = f"WHERE {' AND '.join(where_conditions)}" + if include_embedding: node_query = f""" SELECT id, properties, embedding FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype + {where_clause} ORDER BY id {pagination_clause} """ @@ -2562,7 +2579,7 @@ def export_graph( node_query = f""" SELECT id, properties FROM "{self.db_name}_graph"."Memory" - WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype + {where_clause} ORDER BY id {pagination_clause} """ @@ -2619,11 +2636,24 @@ def export_graph( if use_pagination: edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + # Build Cypher WHERE conditions for edges + cypher_where_conditions = [] + if user_name: + cypher_where_conditions.append(f"a.user_name = '{user_name}'") + cypher_where_conditions.append(f"b.user_name = '{user_name}'") + if user_id: + cypher_where_conditions.append(f"a.user_id = '{user_id}'") + cypher_where_conditions.append(f"b.user_id = '{user_id}'") + + cypher_where_clause = "" + if cypher_where_conditions: + cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" + edge_query = f""" SELECT source, target, edge FROM ( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (a:Memory)-[r]->(b:Memory) - WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}' + {cypher_where_clause} RETURN a.id AS source, b.id AS target, type(r) as edge ORDER BY a.id, b.id $$) AS (source agtype, target agtype, edge agtype) @@ -3399,7 +3429,7 @@ def add_nodes_batch( logger.warning("[add_nodes_batch] Empty nodes list, skipping") return - logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes") + logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})") # user_name comes from parameter; fallback to config if missing effective_user_name = user_name if user_name else self.config.user_name @@ -3528,92 +3558,89 @@ def add_nodes_batch( if graph_id: node["properties"]["graph_id"] = str(graph_id) - # Batch insert using VALUES with multiple rows - # Use psycopg2.extras.execute_values for efficient batch insert - from psycopg2.extras import execute_values - - if embedding_column and any(node["embedding_vector"] for node in nodes_group): - # Prepare data tuples for batch insert with embedding - data_tuples = [] - for node in nodes_group: - # Each tuple: (id, properties_json, embedding_json) - 
data_tuples.append( - ( - node["id"], - json.dumps(node["properties"]), - json.dumps(node["embedding_vector"]) - if node["embedding_vector"] - else None, + # Use PREPARE/EXECUTE for efficient batch insert + # Generate unique prepare statement name to avoid conflicts + prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" + + try: + if embedding_column and any( + node["embedding_vector"] for node in nodes_group + ): + # PREPARE statement for insert with embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), + $2::text::agtype, + $3::vector ) + """ + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" + ) + logger.info( + f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" ) - # Build the INSERT query template - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES %s - """ + cursor.execute(prepare_query) - # Build the VALUES template for execute_values - # Each row: (graph_id_function, agtype, vector) - # Note: properties column is agtype, not jsonb - template = f""" - ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s::text::agtype, - %s::vector - ) - """ - # Execute batch insert - execute_values( - cursor, - insert_query, - data_tuples, - template=template, - page_size=100, # Insert in batches of 100 - ) - else: - # Prepare data tuples for batch insert without embedding - data_tuples = [] - for node in nodes_group: - # Each tuple: (id, properties_json) - data_tuples.append( - ( - node["id"], - json.dumps(node["properties"]), + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) + embedding_json = ( + json.dumps(node["embedding_vector"]) + if node["embedding_vector"] + else None ) + + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s, %s)", + (node["id"], properties_json, embedding_json), + ) + else: + # PREPARE statement for insert without embedding + prepare_query = f""" + PREPARE {prepare_name} AS + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring), + $2::text::agtype + ) + """ + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" ) + logger.info( + f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" + ) + cursor.execute(prepare_query) - # Build the INSERT query template - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES %s - """ + # Execute prepared statement for each node + for node in nodes_group: + properties_json = json.dumps(node["properties"]) - # Build the VALUES template for execute_values - # Note: properties column is agtype, not jsonb - template = f""" - ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s::text::agtype + cursor.execute( + f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) + ) + finally: + # DEALLOCATE prepared statement (always execute, even on error) + try: + cursor.execute(f"DEALLOCATE {prepare_name}") + logger.info( + f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" + ) + except Exception as dealloc_error: + 
logger.warning( + f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" ) - """ - logger.info(f"[add_nodes_batch] Inserting insert_query:{insert_query}") - logger.info(f"[add_nodes_batch] Inserting data_tuples:{data_tuples}") - # Execute batch insert - execute_values( - cursor, - insert_query, - data_tuples, - template=template, - page_size=100, # Insert in batches of 100 - ) logger.info( f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" ) elapsed_time = time.time() - batch_start_time logger.info( - f"[add_nodes_batch] execute_values completed successfully in {elapsed_time:.2f}s" + f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" ) except Exception as e: @@ -4887,7 +4914,6 @@ def delete_node_by_prams( memory_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, - batch_size: int = 100, ) -> int: """ Delete nodes by memory_ids, file_ids, or filter. @@ -4956,133 +4982,74 @@ def delete_node_by_prams( try: conn = self._get_connection() with conn.cursor() as cursor: - # Process memory_ids and filter_ids in batches + # Process memory_ids and filter_ids (all at once, no batching) if all_memory_ids: memory_ids_list = list(all_memory_ids) - total_batches = (len(memory_ids_list) + batch_size - 1) // batch_size logger.info( - f"[delete_node_by_prams] memoryids Processing {len(memory_ids_list)} memory_ids in {total_batches} batches (batch_size={batch_size})" + f"[delete_node_by_prams] Processing {len(memory_ids_list)} memory_ids" ) - for batch_idx in range(total_batches): - batch_start = batch_idx * batch_size - batch_end = min(batch_start + batch_size, len(memory_ids_list)) - batch_ids = memory_ids_list[batch_start:batch_end] - - # Build conditions for this batch - batch_conditions = [] - for node_id in batch_ids: - batch_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" - ) - batch_where = f"({' OR '.join(batch_conditions)})" - - # Add user_name filter if provided - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_clause = f"({user_name_where}) AND ({batch_where})" - else: - where_clause = batch_where - - # Count before deletion - count_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info( - f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: count_query: {count_query}" + # Build conditions for all memory_ids + id_conditions = [] + for node_id in memory_ids_list: + id_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype" ) + id_where = f"({' OR '.join(id_conditions)})" - cursor.execute(count_query) - count_result = cursor.fetchone() - expected_count = count_result[0] if count_result else 0 - - if expected_count == 0: - logger.info( - f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: No nodes found, skipping" - ) - continue + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_clause = f"({user_name_where}) AND ({id_where})" + else: + where_clause = id_where - # Delete batch - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info( - f"[delete_node_by_prams] memoryids batch {batch_idx + 1}/{total_batches}: delete_query: {delete_query}" - ) + # Delete directly 
without counting + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_prams] memory_ids delete_query: {delete_query}") - logger.info( - f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Executing delete query for {len(batch_ids)} nodes" - ) - cursor.execute(delete_query) - batch_deleted = cursor.rowcount - total_deleted_count += batch_deleted + cursor.execute(delete_query) + deleted_count = cursor.rowcount + total_deleted_count += deleted_count - logger.info( - f"[delete_node_by_prams] memoryids Batch {batch_idx + 1}/{total_batches}: Deleted {batch_deleted} nodes (batch size: {len(batch_ids)})" - ) - - # Process file_ids in batches - if file_ids: - total_file_batches = (len(file_ids) + batch_size - 1) // batch_size logger.info( - f"[delete_node_by_prams] Processing {len(file_ids)} file_ids in {total_file_batches} batches (batch_size={batch_size})" + f"[delete_node_by_prams] Deleted {deleted_count} nodes by memory_ids" ) - for batch_idx in range(total_file_batches): - batch_start = batch_idx * batch_size - batch_end = min(batch_start + batch_size, len(file_ids)) - batch_file_ids = file_ids[batch_start:batch_end] - - # Build conditions for this batch - batch_conditions = [] - for file_id in batch_file_ids: - batch_conditions.append( - f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" - ) - batch_where = f"({' OR '.join(batch_conditions)})" - - # Add user_name filter if provided - if user_name_conditions: - user_name_where = " OR ".join(user_name_conditions) - where_clause = f"({user_name_where}) AND ({batch_where})" - else: - where_clause = batch_where - - # Count before deletion - count_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ + # Process file_ids (all at once, no batching) + if file_ids: + logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") - logger.info( - f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: count_query: {count_query}" + # Build conditions for all file_ids + file_id_conditions = [] + for file_id in file_ids: + file_id_conditions.append( + f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)" ) - cursor.execute(count_query) - count_result = cursor.fetchone() - expected_count = count_result[0] if count_result else 0 + file_id_where = f"({' OR '.join(file_id_conditions)})" - if expected_count == 0: - logger.info( - f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: No nodes found, skipping" - ) - continue + # Add user_name filter if provided + if user_name_conditions: + user_name_where = " OR ".join(user_name_conditions) + where_clause = f"({user_name_where}) AND ({file_id_where})" + else: + where_clause = file_id_where - # Delete batch - delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - cursor.execute(delete_query) - batch_deleted = cursor.rowcount - total_deleted_count += batch_deleted + # Delete directly without counting + delete_query = f""" + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE {where_clause} + """ + logger.info(f"[delete_node_by_prams] file_ids delete_query: {delete_query}") - logger.info( - f"[delete_node_by_prams] File batch {batch_idx + 1}/{total_file_batches}: delete_query: {delete_query}" - ) + cursor.execute(delete_query) + deleted_count = 
cursor.rowcount + total_deleted_count += deleted_count + + logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes by file_ids") elapsed_time = time.time() - batch_start_time logger.info( @@ -5109,6 +5076,7 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[ - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing) - 'exist_user_names': List of distinct user names (if all memory_ids exist) """ + logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") if not memory_ids: return {"exist_user_names": []} From 17afbe777a48628356110eadb985affe6bfa9f33 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Mon, 29 Dec 2025 17:38:19 +0800 Subject: [PATCH 17/48] feat: add get_user_names_by_memory_ids for polardb && neo4j (#803) --- src/memos/graph_dbs/neo4j.py | 64 ++++++++----------- src/memos/graph_dbs/polardb.py | 112 +++++++++++++++++++-------------- 2 files changed, 90 insertions(+), 86 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 2b3859252..d57e7c596 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1811,67 +1811,53 @@ def delete_node_by_prams( logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes") return deleted_count - def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]: + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: """Get user names by memory ids. Args: memory_ids: List of memory node IDs to query. Returns: - dict[str, list[str]]: Dictionary with one key: - - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing) - - 'exist_user_names': List of distinct user names (if all memory_ids exist) + dict[str, str | None]: Dictionary mapping memory_id to user_name. 
+ - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} """ if not memory_ids: - return {"exist_user_names": []} + return {} - logger.info(f"[get_user_names_by_memory_ids] Checking {len(memory_ids)} memory_ids") + logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") try: with self.driver.session(database=self.db_name) as session: - # Query to check which memory_ids exist - check_query = """ + # Query to get memory_id and user_name pairs + query = """ MATCH (n:Memory) WHERE n.id IN $memory_ids - RETURN n.id AS id + RETURN n.id AS memory_id, n.user_name AS user_name """ + logger.info(f"[get_user_names_by_memory_ids] query: {query}") - check_result = session.run(check_query, memory_ids=memory_ids) - existing_ids = set() - for record in check_result: - node_id = record["id"] - existing_ids.add(node_id) + result = session.run(query, memory_ids=memory_ids) + result_dict = {} - # Check if any memory_ids are missing - no_exist_list = [mid for mid in memory_ids if mid not in existing_ids] - - # If any memory_ids are missing, return no_exist_memory_ids - if no_exist_list: - logger.info( - f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}" - ) - return {"no_exist_memory_ids": no_exist_list} - - # All memory_ids exist, query user_names - user_names_query = """ - MATCH (n:Memory) - WHERE n.id IN $memory_ids - RETURN DISTINCT n.user_name AS user_name - """ - logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}") - - user_names_result = session.run(user_names_query, memory_ids=memory_ids) - user_names = [] - for record in user_names_result: + # Build result dictionary from query results + for record in result: + memory_id = record["memory_id"] user_name = record["user_name"] - if user_name: - user_names.append(user_name) + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in memory_ids: + if mid not in result_dict: + result_dict[mid] = None logger.info( - f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names" + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" ) - return {"exist_user_names": user_names} + return result_dict except Exception as e: logger.error( f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index f88824493..8eb3e4ece 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5065,86 +5065,104 @@ def delete_node_by_prams( return total_deleted_count @timed - def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]: + def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]: """Get user names by memory ids. Args: memory_ids: List of memory node IDs to query. Returns: - dict[str, list[str]]: Dictionary with one key: - - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing) - - 'exist_user_names': List of distinct user names (if all memory_ids exist) + dict[str, str | None]: Dictionary mapping memory_id to user_name. 
+ - Key: memory_id + - Value: user_name if exists, None if memory_id does not exist + Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None} """ logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}") if not memory_ids: - return {"exist_user_names": []} + return {} + + # Validate and normalize memory_ids + # Ensure all items are strings + normalized_memory_ids = [] + for mid in memory_ids: + if not isinstance(mid, str): + mid = str(mid) + # Remove any whitespace + mid = mid.strip() + if mid: + normalized_memory_ids.append(mid) + + if not normalized_memory_ids: + return {} + + # Escape special characters for JSON string format in agtype + def escape_memory_id(mid: str) -> str: + """Escape special characters in memory_id for JSON string format.""" + # Escape backslashes first, then double quotes + mid_str = mid.replace("\\", "\\\\") + mid_str = mid_str.replace('"', '\\"') + return mid_str # Build OR conditions for each memory_id id_conditions = [] - for mid in memory_ids: + for mid in normalized_memory_ids: + # Escape special characters + escaped_mid = escape_memory_id(mid) id_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{mid}\"'::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{escaped_mid}\"'::agtype" ) where_clause = f"({' OR '.join(id_conditions)})" - # Query to check which memory_ids exist - check_query = f""" - SELECT ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text + # Query to get memory_id and user_name pairs + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text AS memory_id, + ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text AS user_name FROM "{self.db_name}_graph"."Memory" WHERE {where_clause} """ - logger.info(f"[get_user_names_by_memory_ids] check_query: {check_query}") + logger.info(f"[get_user_names_by_memory_ids] query: {query}") conn = None + result_dict = {} try: conn = self._get_connection() with conn.cursor() as cursor: - # Check which memory_ids exist - cursor.execute(check_query) - check_results = cursor.fetchall() - existing_ids = set() - for row in check_results: - node_id = row[0] - # Remove quotes if present - if isinstance(node_id, str): - node_id = node_id.strip('"').strip("'") - existing_ids.add(node_id) + cursor.execute(query) + results = cursor.fetchall() - # Check if any memory_ids are missing - no_exist_list = [mid for mid in memory_ids if mid not in existing_ids] + # Build result dictionary from query results + for row in results: + memory_id_raw = row[0] + user_name_raw = row[1] - # If any memory_ids are missing, return no_exist_memory_ids - if no_exist_list: - logger.info( - f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}" - ) - return {"no_exist_memory_ids": no_exist_list} + # Remove quotes if present + if isinstance(memory_id_raw, str): + memory_id = memory_id_raw.strip('"').strip("'") + else: + memory_id = str(memory_id_raw).strip('"').strip("'") - # All memory_ids exist, query user_names - user_names_query = f""" - SELECT DISTINCT ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text - FROM "{self.db_name}_graph"."Memory" - WHERE {where_clause} - """ - logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}") + if isinstance(user_name_raw, str): + user_name = user_name_raw.strip('"').strip("'") + else: + user_name = ( + 
str(user_name_raw).strip('"').strip("'") if user_name_raw else None + ) - cursor.execute(user_names_query) - results = cursor.fetchall() - user_names = [] - for row in results: - user_name = row[0] - # Remove quotes if present - if isinstance(user_name, str): - user_name = user_name.strip('"').strip("'") - user_names.append(user_name) + result_dict[memory_id] = user_name if user_name else None + + # Set None for memory_ids that were not found + for mid in normalized_memory_ids: + if mid not in result_dict: + result_dict[mid] = None logger.info( - f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names" + f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, " + f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names" ) - return {"exist_user_names": user_names} + return result_dict except Exception as e: logger.error( f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True From 9c25b46154ce8345dbe3b609d6157843cd86445b Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Tue, 30 Dec 2025 10:08:47 +0800 Subject: [PATCH 18/48] feat: add export_graph total (#804) * feat: add export_graph total * feat: add export_graph total --- src/memos/graph_dbs/neo4j.py | 47 ++++++++++----- src/memos/graph_dbs/polardb.py | 101 ++++++++++++++++++++++----------- 2 files changed, 103 insertions(+), 45 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index d57e7c596..c2dc4a629 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1150,11 +1150,17 @@ def export_graph( Returns: { "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] + "edges": [ { "source": ..., "target": ..., "type": ... }, ... 
], + "total_nodes": int, # Total number of nodes matching the filter criteria + "total_edges": int, # Total number of edges matching the filter criteria } """ user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name + # Initialize total counts + total_nodes = 0 + total_edges = 0 + # Determine if pagination is needed use_pagination = page is not None and page_size is not None @@ -1167,28 +1173,38 @@ def export_graph( skip = (page - 1) * page_size with self.driver.session(database=self.db_name) as session: - # Export nodes - node_query = "MATCH (n:Memory)" - edge_query = "MATCH (a:Memory)-[r]->(b:Memory)" + # Build base queries + node_base_query = "MATCH (n:Memory)" + edge_base_query = "MATCH (a:Memory)-[r]->(b:Memory)" params = {} if not self.config.use_multi_db and (self.config.user_name or user_name): - node_query += " WHERE n.user_name = $user_name" - edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name" + node_base_query += " WHERE n.user_name = $user_name" + edge_base_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name" params["user_name"] = user_name - # Add ORDER BY and pagination for nodes - node_query += " RETURN n ORDER BY n.id" + # Get total count of nodes before pagination + count_node_query = node_base_query + " RETURN COUNT(n) AS count" + count_node_result = session.run(count_node_query, params) + total_nodes = count_node_result.single()["count"] + + # Export nodes with ORDER BY created_at DESC + node_query = node_base_query + " RETURN n ORDER BY n.created_at DESC, n.id DESC" if use_pagination: node_query += f" SKIP {skip} LIMIT {page_size}" node_result = session.run(node_query, params) nodes = [self._parse_node(dict(record["n"])) for record in node_result] - # Export edges - # Add ORDER BY and pagination for edges - edge_query += ( - " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.id, b.id" + # Get total count of edges before pagination + count_edge_query = edge_base_query + " RETURN COUNT(r) AS count" + count_edge_result = session.run(count_edge_query, params) + total_edges = count_edge_result.single()["count"] + + # Export edges with ORDER BY created_at DESC + edge_query = ( + edge_base_query + + " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.created_at DESC, b.created_at DESC, a.id DESC, b.id DESC" ) if use_pagination: edge_query += f" SKIP {skip} LIMIT {page_size}" @@ -1199,7 +1215,12 @@ def export_graph( for record in edge_result ] - return {"nodes": nodes, "edges": edges} + return { + "nodes": nodes, + "edges": edges, + "total_nodes": total_nodes, + "total_edges": total_edges, + } def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: """ diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 8eb3e4ece..8eabda6d8 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2522,7 +2522,9 @@ def export_graph( Returns: { "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] + "edges": [ { "source": ..., "target": ..., "type": ... }, ... 
], + "total_nodes": int, # Total number of nodes matching the filter criteria + "total_edges": int, # Total number of edges matching the filter criteria } """ logger.info( @@ -2530,6 +2532,10 @@ def export_graph( ) user_id = user_id if user_id else self._get_config_value("user_id") + # Initialize total counts + total_nodes = 0 + total_edges = 0 + # Determine if pagination is needed use_pagination = page is not None and page_size is not None @@ -2546,12 +2552,6 @@ def export_graph( conn = None try: conn = self._get_connection() - # Export nodes - # Build pagination clause if needed - pagination_clause = "" - if use_pagination: - pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - # Build WHERE conditions where_conditions = [] if user_name: @@ -2567,12 +2567,30 @@ def export_graph( if where_conditions: where_clause = f"WHERE {' AND '.join(where_conditions)}" + # Get total count of nodes before pagination + count_node_query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + {where_clause} + """ + logger.info(f"[export_graph nodes count] Query: {count_node_query}") + with conn.cursor() as cursor: + cursor.execute(count_node_query) + total_nodes = cursor.fetchone()[0] + + # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + if include_embedding: node_query = f""" SELECT id, properties, embedding FROM "{self.db_name}_graph"."Memory" {where_clause} - ORDER BY id + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC {pagination_clause} """ else: @@ -2580,7 +2598,8 @@ def export_graph( SELECT id, properties FROM "{self.db_name}_graph"."Memory" {where_clause} - ORDER BY id + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC {pagination_clause} """ logger.info(f"[export_graph nodes] Query: {node_query}") @@ -2591,9 +2610,11 @@ def export_graph( for row in node_results: if include_embedding: - properties_json, embedding_json = row + """row is (id, properties, embedding)""" + _, properties_json, embedding_json = row else: - properties_json = row + """row is (id, properties)""" + _, properties_json = row embedding_json = None # Parse properties from JSONB if it's a string @@ -2605,20 +2626,13 @@ def export_graph( else: properties = properties_json if properties_json else {} - # # Build node data - - """ - # node_data = { - # "id": properties.get("id", node_id), - # "memory": properties.get("memory", ""), - # "metadata": properties - # } - """ - - if include_embedding and embedding_json is not None: + # Remove embedding field if include_embedding is False + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: properties["embedding"] = embedding_json - nodes.append(self._parse_node(json.loads(properties[1]))) + nodes.append(self._parse_node(properties)) except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) @@ -2629,13 +2643,6 @@ def export_graph( conn = None try: conn = self._get_connection() - # Export edges using cypher query - # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery - # Build pagination clause if needed - edge_pagination_clause = "" - if use_pagination: - edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - # Build Cypher WHERE conditions for edges cypher_where_conditions = [] if user_name: @@ -2649,13 
+2656,38 @@ def export_graph( if cypher_where_conditions: cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" + # Get total count of edges before pagination + count_edge_query = f""" + SELECT COUNT(*) + FROM ( + SELECT * FROM cypher('{self.db_name}_graph', $$ + MATCH (a:Memory)-[r]->(b:Memory) + {cypher_where_clause} + RETURN a.id AS source, b.id AS target, type(r) as edge + $$) AS (source agtype, target agtype, edge agtype) + ) AS edges + """ + logger.info(f"[export_graph edges count] Query: {count_edge_query}") + with conn.cursor() as cursor: + cursor.execute(count_edge_query) + total_edges = cursor.fetchone()[0] + + # Export edges using cypher query + # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery + # Build pagination clause if needed + edge_pagination_clause = "" + if use_pagination: + edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + edge_query = f""" SELECT source, target, edge FROM ( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (a:Memory)-[r]->(b:Memory) {cypher_where_clause} RETURN a.id AS source, b.id AS target, type(r) as edge - ORDER BY a.id, b.id + ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, + COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, + a.id DESC, b.id DESC $$) AS (source agtype, target agtype, edge agtype) ) AS edges {edge_pagination_clause} @@ -2726,7 +2758,12 @@ def export_graph( finally: self._return_connection(conn) - return {"nodes": nodes, "edges": edges} + return { + "nodes": nodes, + "edges": edges, + "total_nodes": total_nodes, + "total_edges": total_edges, + } @timed def count_nodes(self, scope: str, user_name: str | None = None) -> int: From 2ee0754bf07e655965e5d4bc594473507f560737 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Tue, 30 Dec 2025 10:58:55 +0800 Subject: [PATCH 19/48] add: get_memory return edges and count of items (#805) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords * fix messages queue * add seach_by_keywords_LIKE * add doc filter * add retrieve query * add retrieve queies * patch info filter * add log and make embedding safety net * add log and make embedding safety net * deduplicate add objects * use _add_memories_parallel * delete Special characters * delete Special characters * delete Special characters * delete Special characters * add source_doc_id * add source_doc_id * add reranker in init com.. 
* fix circle import * add feedback judgement * add feedback judgement * add pref feedback * add pref feedback * patch: get_memory func filter user id and make page chunk * add total num * add total num --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/handlers/memory_handler.py | 25 +++++++++++++++++++++--- src/memos/memories/textual/tree.py | 4 ++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 2a99d912c..d05da19db 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -185,7 +185,12 @@ def handle_get_memories( user_id=get_mem_req.user_id, page=get_mem_req.page, page_size=get_mem_req.page_size, - )["nodes"] + ) + total_nodes = memories["total_nodes"] + total_edges = memories["total_edges"] + del memories["total_nodes"] + del memories["total_edges"] + preferences: list[TextualMemoryItem] = [] if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: filter_params: dict[str, Any] = {} @@ -195,11 +200,25 @@ def handle_get_memories( filter_params["mem_cube_id"] = get_mem_req.mem_cube_id preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) preferences = [format_memory_item(mem) for mem in preferences] + return GetMemoryResponse( message="Memories retrieved successfully", data={ - "text_mem": [{"cube_id": get_mem_req.mem_cube_id, "memories": memories}], - "pref_mem": [{"cube_id": get_mem_req.mem_cube_id, "memories": preferences}], + "text_mem": [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": memories, + "total_nodes": total_nodes, + "total_edges": total_edges, + } + ], + "pref_mem": [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": preferences, + "total_nodes": len(preferences), + } + ], }, ) diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 764ceee67..e576c0ea9 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -332,10 +332,10 @@ def get_all( Returns: list[TextualMemoryItem]: List of all memories. """ - all_items = self.graph_store.export_graph( + graph_output = self.graph_store.export_graph( user_name=user_name, user_id=user_id, page=page, page_size=page_size ) - return all_items + return graph_output def delete(self, memory_ids: list[str], user_name: str | None = None) -> None: """Hard delete: permanently remove nodes and their edges from the graph.""" From 63987d5f222d3b8b1902d53ed59c9ecd76fc7b22 Mon Sep 17 00:00:00 2001 From: Travis Tang Date: Tue, 30 Dec 2025 11:09:05 +0800 Subject: [PATCH 20/48] Scheduler: address some issues to run old scheduler example and kv cache example (#797) * fix bugs: try to fix bugs in _submit_web_logs * fix bugs: try to address bugs * fix bugs * refactor: modify examples * revise add operation and fix an unbelievable bug * address the bug issues * the doc file has a format problem which has been fixed in this commit * add a range of new feats for the add operation * address the incompatible issue of local scheduler * feat(scheduler): optimize redis queue consumer group management - Proactively ensure consumer groups exist in '_refresh_stream_keys' for newly discovered streams. - Remove redundant consumer group checks in '_read_new_messages_batch' to improve read performance. 
- Clean up 'seen_streams' cache when streams are deleted to ensure correct group recreation. - This change reduces unnecessary Redis calls during high-frequency polling. * fix(tests): resolve AttributeError in SimpleStructMemReader tests - Import 'parse_json_result' from 'memos.mem_reader.utils' instead of accessing it as an instance attribute. - Fixes 'AttributeError: 'SimpleStructMemReader' object has no attribute 'parse_json_result'' in 'test_parse_json_result_success' and 'test_parse_json_result_failure'. - Remove incorrect mock assignment of 'parse_json_result' in 'test_process_chat_data'. * fix(mem_reader): pass info dict to add_before_search for correct user_id usage - Update 'add_before_search' signature in 'SimpleStructMemReader' to accept 'info' dict. - Pass 'info' (containing 'user_id' and 'session_id') to 'self.searcher.search' instead of using empty strings. - Add 'test_add_before_search' to 'TestSimpleStructMemReader' to verify the fix and ensure 'searcher.search' receives the correct 'info'. - This ensures that memory searches are scoped to the correct user and session. * refactor add_before_search from mem_reader to SingleCubeView * address bugs * fix: fix the qsize bug of task queue, and accept change from hotfix/scheduler * fix: address some issues to run old scheduler example and kv cache example * fix: address the issue of Top-level import of unavailable module 'torch' * fix: resolve linting errors and make optional dependencies lazy loaded - Fix ambiguous characters and commented-out code in examples/mem_scheduler/quick_start_examples.py - Fix nested if statements in src/memos/mem_os/core.py - Move torch and transformers imports to method scope in src/memos/llms/hf.py to support optional dependencies - Update tests/llms/test_hf.py to patch transformers module directly * refactor: revise the rewrite prompt to make it better * refactor: update examples --- .../config/mem_scheduler/mem_cube_config.yaml | 21 ++ .../memos_config_w_scheduler.yaml | 12 +- .../mem_scheduler/quick_start_examples.py | 253 ++++++++++++++++++ src/memos/llms/hf.py | 31 ++- src/memos/mem_os/core.py | 6 +- src/memos/mem_os/main.py | 2 +- src/memos/mem_os/utils/format_utils.py | 82 ++++-- src/memos/mem_reader/simple_struct.py | 2 +- src/memos/mem_scheduler/base_scheduler.py | 42 ++- .../general_modules/scheduler_logger.py | 30 ++- src/memos/mem_scheduler/general_scheduler.py | 65 +++-- .../mem_scheduler/monitors/general_monitor.py | 22 +- .../task_schedule_modules/dispatcher.py | 3 - .../task_schedule_modules/local_queue.py | 111 +++++--- .../task_schedule_modules/redis_queue.py | 2 +- src/memos/memories/activation/kv.py | 32 ++- src/memos/templates/mem_reader_prompts.py | 35 ++- tests/llms/test_hf.py | 4 +- 18 files changed, 618 insertions(+), 137 deletions(-) create mode 100644 examples/data/config/mem_scheduler/mem_cube_config.yaml create mode 100644 examples/mem_scheduler/quick_start_examples.py diff --git a/examples/data/config/mem_scheduler/mem_cube_config.yaml b/examples/data/config/mem_scheduler/mem_cube_config.yaml new file mode 100644 index 000000000..398d8dbb3 --- /dev/null +++ b/examples/data/config/mem_scheduler/mem_cube_config.yaml @@ -0,0 +1,21 @@ +user_id: "user_test" +cube_id: "user_test/mem_cube_naive" +text_mem: + backend: "naive_text" + config: + extractor_llm: + backend: "huggingface_singleton" + config: + model_name_or_path: "Qwen/Qwen3-0.6B" + temperature: 0.1 + max_tokens: 1024 +act_mem: + backend: "kv_cache" + config: + memory_filename: "activation_memory.pickle" + 
extractor_llm: + backend: "huggingface_singleton" + config: + model_name_or_path: "Qwen/Qwen3-0.6B" + temperature: 0.8 + max_tokens: 1024 diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml index bd9910300..a5e91dc4e 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml @@ -10,16 +10,12 @@ mem_reader: backend: "simple_struct" config: llm: - backend: "openai" + backend: "huggingface_singleton" config: - model_name_or_path: "gpt-4o-mini" - temperature: 0.8 - max_tokens: 4096 - top_p: 0.9 - top_k: 50 + model_name_or_path: "Qwen/Qwen3-1.7B" + temperature: 0.1 remove_think_prefix: true - api_key: "sk-xxxxxx" - api_base: "https://api.openai.com/v1" + max_tokens: 4096 embedder: backend: "ollama" config: diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py new file mode 100644 index 000000000..c71869e76 --- /dev/null +++ b/examples/mem_scheduler/quick_start_examples.py @@ -0,0 +1,253 @@ +import json +import shutil +import sys +import uuid + +from pathlib import Path + +from transformers import DynamicCache + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.mem_os import MOSConfig +from memos.configs.memory import MemoryConfigFactory +from memos.mem_cube.general import GeneralMemCube +from memos.mem_os.main import MOS +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import ( + ANSWER_TASK_LABEL, + MEM_UPDATE_TASK_LABEL, + QUERY_TASK_LABEL, +) +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import parse_yaml +from memos.memories.activation.item import KVCacheItem +from memos.memories.factory import MemoryFactory + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + + +def get_cache_info(cache): + if not cache: + return None + + num_layers = 0 + total_size_bytes = 0 + + if hasattr(cache, "layers"): + num_layers = len(cache.layers) + for layer in cache.layers: + if hasattr(layer, "key_cache") and layer.key_cache is not None: + total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size() + if hasattr(layer, "value_cache") and layer.value_cache is not None: + total_size_bytes += layer.value_cache.nelement() * layer.value_cache.element_size() + + if hasattr(layer, "keys") and layer.keys is not None: + total_size_bytes += layer.keys.nelement() * layer.keys.element_size() + if hasattr(layer, "values") and layer.values is not None: + total_size_bytes += layer.values.nelement() * layer.values.element_size() + + elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"): + num_layers = len(cache.key_cache) + for k, v in zip(cache.key_cache, cache.value_cache, strict=False): + if k is not None: + total_size_bytes += k.nelement() * k.element_size() + if v is not None: + total_size_bytes += v.nelement() * v.element_size() + + return { + "num_layers": num_layers, + "size_bytes": total_size_bytes, + "size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB", + } + + +def serialize_item(obj): + if isinstance(obj, list): + return [serialize_item(x) for x in obj] + + if isinstance(obj, KVCacheItem): + return { + "id": obj.id, + "metadata": obj.metadata, + "records": 
obj.records.model_dump()
            if hasattr(obj.records, "model_dump")
            else obj.records,
            "memory": get_cache_info(obj.memory),
        }

    if isinstance(obj, DynamicCache):
        return get_cache_info(obj)

    return str(obj)


def kv_cache_only():
    # Create the config for KVCacheMemory (HuggingFace backend)
    config = MemoryConfigFactory(
        backend="kv_cache",
        config={
            "extractor_llm": {
                "backend": "huggingface",
                "config": {
                    "model_name_or_path": "Qwen/Qwen3-0.6B",
                    "max_tokens": 32,
                    "add_generation_prompt": True,
                    "remove_think_prefix": True,
                },
            },
        },
    )

    # Instantiate KVCacheMemory
    kv_mem = MemoryFactory.from_config(config)

    # Extract a KVCacheItem (DynamicCache)
    prompt = [
        {"role": "user", "content": "What is MemOS?"},
        {"role": "assistant", "content": "MemOS is a memory operating system for LLMs."},
    ]
    print("===== Extract KVCacheItem =====")
    cache_item = kv_mem.extract(prompt)
    print(json.dumps(serialize_item(cache_item), indent=2, default=str))

    # Add the cache to memory
    kv_mem.add([cache_item])
    print("All caches:")
    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))

    # Get by ID
    retrieved = kv_mem.get(cache_item.id)
    print("Retrieved:")
    print(json.dumps(serialize_item(retrieved), indent=2, default=str))

    # Merge caches
    item2 = kv_mem.extract([{"role": "user", "content": "Tell me a joke."}])
    kv_mem.add([item2])
    merged = kv_mem.get_cache([cache_item.id, item2.id])
    print("Merged cache:")
    print(json.dumps(serialize_item(merged), indent=2, default=str))

    # Delete one of them
    kv_mem.delete([cache_item.id])
    print("After delete:")
    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))

    # Dump and load caches
    kv_mem.dump("tmp/kv_mem")
    print("Dumped to tmp/kv_mem")
    kv_mem.delete_all()
    kv_mem.load("tmp/kv_mem")
    print("Loaded caches:")
    print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
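# For intuition about the sizes get_cache_info() reports: a standard attention
# KV cache stores one key and one value tensor of shape
# [batch, heads, seq_len, head_dim] per layer, so its footprint is roughly
#     2 * layers * batch * heads * seq_len * head_dim * element_size.
# The dimensions below are illustrative assumptions, not the actual
# Qwen3-0.6B configuration.
def estimate_kv_cache_bytes(
    layers: int = 28,
    batch: int = 1,
    heads: int = 8,
    seq_len: int = 256,
    head_dim: int = 128,
    bytes_per_element: int = 2,  # fp16/bf16
) -> int:
    per_tensor = batch * heads * seq_len * head_dim * bytes_per_element
    return 2 * layers * per_tensor  # x2: one key and one value tensor per layer


# estimate_kv_cache_bytes() / (1024 * 1024) evaluates to 28.0 MB under these assumptions.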
def run_scheduler_example():
    # Load the main MOS config with MemScheduler
    config = parse_yaml(
        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
    )
    mos_config = MOSConfig(**config)
    mos = MOS(mos_config)

    # Create a dynamic user ID
    user_id = str(uuid.uuid4())
    mos.create_user(user_id=user_id)

    # Create the MemCube config and dump it
    config = GeneralMemCubeConfig.from_yaml_file(
        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
    )
    mem_cube_id = "mem_cube_5"
    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"

    # Remove the old directory if it exists
    if Path(mem_cube_name_or_path).exists():
        shutil.rmtree(mem_cube_name_or_path)
        print(f"{mem_cube_name_or_path} is not empty, and has been removed.")

    # Dump a new MemCube
    mem_cube = GeneralMemCube(config)
    mem_cube.dump(mem_cube_name_or_path)

    # Register the MemCube for this user
    mos.register_mem_cube(
        mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
    )

    # Define custom scheduler handlers
    def custom_query_handler(messages: list[ScheduleMessageItem]):
        for msg in messages:
            print(f"\n[scheduler] User submitted a query: {msg.content}")
            # Trigger mem_update manually
            new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
            mos.mem_scheduler.submit_messages([new_msg])

    def custom_answer_handler(messages: list[ScheduleMessageItem]):
        for msg in messages:
            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
            kv_mem = mem_cube.act_mem
            for cache_item in kv_mem.get_all():
                print(
                    f"[scheduler] act memory: {get_cache_info(cache_item.memory)} ({cache_item.records})"
                )
            print(f"\n[scheduler] LLM returned an answer: {msg.content}")

    def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
        for msg in messages:
            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
            kv_mem = mem_cube.act_mem
            if mem_cube and mem_cube.text_mem:
                results = mem_cube.text_mem.search(msg.content, top_k=3)
                for mem in results:
                    print(f"\n[scheduler] searched memories: {mem.memory}")

                    cache_item = kv_mem.extract(mem.memory)
                    cache_item.records.text_memories = [mem.memory]
                    cache_item.records.timestamp = get_utc_now()
                    kv_mem.add([cache_item])

    # Register custom handlers
    mos.mem_scheduler.dispatcher.register_handlers(
        {
            QUERY_TASK_LABEL: custom_query_handler,
            ANSWER_TASK_LABEL: custom_answer_handler,
            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,
        }
    )

    # Add messages
    messages = [
        {"role": "user", "content": "I like playing football."},
        {"role": "assistant", "content": "I like playing football too."},
    ]
    mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)

    # Chat loop: show TreeTextMemory nodes + KVCache
    while True:
        user_input = input("👤 [You] ").strip()
        print()
        response = mos.chat(user_input, user_id=user_id)
        retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)

        print(f"🤖 [Assistant] {response}")

        # Show the node types stored in TreeTextMemory
        text_memories = retrieved_memories["text_mem"][0]["memories"]
        # Handle different memory structures (NaiveTextMemory returns list, TreeTextMemory returns dict with nodes)
        if isinstance(text_memories, dict) and "nodes" in text_memories:
            for node in text_memories["nodes"]:
                mem_type = node["metadata"].get("memory_type", "Unknown")
                print(f"[{mem_type}] {node['memory']}")
        elif isinstance(text_memories, list):
            for mem in text_memories:
                # Naive memory items might not have memory_type metadata, or it might be different
                print(f"[TextMemory] {mem.memory if hasattr(mem, 'memory') else mem}")


if __name__ == "__main__":
    kv_cache_only()

    run_scheduler_example()
diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index d46db7c9e..b5fc4ba13 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -2,13 +2,7 @@
 from typing import Any

 from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
     DynamicCache,
-    LogitsProcessorList,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
 )

 from memos.configs.llm import HFLLMConfig
@@ -30,6 +24,17 @@ def __init__(self, config: HFLLMConfig):
         """
         Initialize the HFLLM model and tokenizer, and set up logits processors for sampling.
         """
+        import torch
+
+        from transformers import (
+            AutoModelForCausalLM,
+            AutoTokenizer,
+            LogitsProcessorList,
+            TemperatureLogitsWarper,
+            TopKLogitsWarper,
+            TopPLogitsWarper,
+        )
+
         self.config = config

         # Default model if not specified
@@ -37,9 +42,14 @@ def __init__(self, config: HFLLMConfig):
             self.config.model_name_or_path = "Qwen/Qwen3-1.7B"

         # Initialize hf model
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
-        )
+        if torch.backends.mps.is_available():
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.config.model_name_or_path, torch_dtype="auto"
+            ).to("mps")
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
+            )
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.config.model_name_or_path, use_fast=True
         )
@@ -355,6 +365,7 @@ def build_kv_cache(self, messages) -> DynamicCache:
             DynamicCache: The constructed KV cache object.
""" import torch + import transformers # Accept multiple input types and convert to standard chat messages if isinstance(messages, str): @@ -391,7 +402,7 @@ def build_kv_cache(self, messages) -> DynamicCache: # Convert from legacy tuple format to DynamicCache if needed if isinstance(kv, tuple): - kv = DynamicCache.from_legacy_cache(kv) + kv = transformers.DynamicCache.from_legacy_cache(kv) # Handle compatibility between old and new transformers versions # In newer versions, DynamicCache uses 'layers' attribute diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 1a88fa831..e7f01ec3e 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -311,7 +311,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = past_key_values = None if self.config.enable_activation_memory: - if self.config.chat_model.backend != "huggingface": + if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]: logger.error( "Activation memory only used for huggingface backend. Skipping activation memory." ) @@ -498,7 +498,9 @@ def register_mem_cube( existing_cube = self.user_manager.get_cube(mem_cube_id) # check the embedder is it consistent with MOSConfig - if self.config.mem_reader.config.embedder != ( + if hasattr( + self.mem_cubes[mem_cube_id].text_mem.config, "embedder" + ) and self.config.mem_reader.config.embedder != ( cube_embedder := self.mem_cubes[mem_cube_id].text_mem.config.embedder ): logger.warning( diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 0114fc0da..0dc6ab209 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -310,7 +310,7 @@ def _generate_enhanced_response_with_context( # Handle activation memory if enabled (same as core method) past_key_values = None if self.config.enable_activation_memory: - if self.config.chat_model.backend != "huggingface": + if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]: logger.error( "Activation memory only used for huggingface backend. Skipping activation memory." 
) diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py index 5fdb59058..f6e33bb31 100644 --- a/src/memos/mem_os/utils/format_utils.py +++ b/src/memos/mem_os/utils/format_utils.py @@ -1087,38 +1087,64 @@ def convert_activation_memory_to_serializable( serializable_items = [] for item in act_mem_items: + key_layers = 0 + val_layers = 0 + device = "unknown" + dtype = "unknown" + key_shapes = [] + value_shapes = [] + + if item.memory: + if hasattr(item.memory, "layers"): + key_layers = len(item.memory.layers) + val_layers = len(item.memory.layers) + if key_layers > 0: + l0 = item.memory.layers[0] + k0 = getattr(l0, "key_cache", getattr(l0, "keys", None)) + if k0 is not None: + device = str(k0.device) + dtype = str(k0.dtype) + + for i, layer in enumerate(item.memory.layers): + k = getattr(layer, "key_cache", getattr(layer, "keys", None)) + v = getattr(layer, "value_cache", getattr(layer, "values", None)) + if k is not None: + key_shapes.append({"layer": i, "shape": list(k.shape)}) + if v is not None: + value_shapes.append({"layer": i, "shape": list(v.shape)}) + + elif hasattr(item.memory, "key_cache"): + key_layers = len(item.memory.key_cache) + val_layers = len(item.memory.value_cache) + if key_layers > 0 and item.memory.key_cache[0] is not None: + device = str(item.memory.key_cache[0].device) + dtype = str(item.memory.key_cache[0].dtype) + + for i, key_tensor in enumerate(item.memory.key_cache): + if key_tensor is not None: + key_shapes.append({"layer": i, "shape": list(key_tensor.shape)}) + + for i, val_tensor in enumerate(item.memory.value_cache): + if val_tensor is not None: + value_shapes.append({"layer": i, "shape": list(val_tensor.shape)}) + # Extract basic information that can be serialized serializable_item = { "id": item.id, "metadata": item.metadata, "memory_info": { "type": "DynamicCache", - "key_cache_layers": len(item.memory.key_cache) if item.memory else 0, - "value_cache_layers": len(item.memory.value_cache) if item.memory else 0, - "device": str(item.memory.key_cache[0].device) - if item.memory and item.memory.key_cache - else "unknown", - "dtype": str(item.memory.key_cache[0].dtype) - if item.memory and item.memory.key_cache - else "unknown", + "key_cache_layers": key_layers, + "value_cache_layers": val_layers, + "device": device, + "dtype": dtype, }, } # Add tensor shape information if available - if item.memory and item.memory.key_cache: - key_shapes = [] - value_shapes = [] - - for i, key_tensor in enumerate(item.memory.key_cache): - if key_tensor is not None: - key_shapes.append({"layer": i, "shape": list(key_tensor.shape)}) - - if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None: - value_shapes.append( - {"layer": i, "shape": list(item.memory.value_cache[i].shape)} - ) - + if key_shapes: serializable_item["memory_info"]["key_shapes"] = key_shapes + if value_shapes: serializable_item["memory_info"]["value_shapes"] = value_shapes serializable_items.append(serializable_item) @@ -1144,7 +1170,19 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[ total_parameters = 0 for item in act_mem_items: - if item.memory and item.memory.key_cache: + if not item.memory: + continue + + if hasattr(item.memory, "layers"): + total_layers += len(item.memory.layers) + for layer in item.memory.layers: + k = getattr(layer, "key_cache", getattr(layer, "keys", None)) + v = getattr(layer, "value_cache", getattr(layer, "values", None)) + if k is not None: + total_parameters += k.numel() + if v is not 
None: + total_parameters += v.numel() + elif hasattr(item.memory, "key_cache"): total_layers += len(item.memory.key_cache) # Calculate approximate parameter count diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 70472958e..61a7d2b6d 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -618,7 +618,7 @@ def _read_memory( messages=combined_messages, memory_list=original_memory_group, user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower() - == "true", + == "false", ) serialized_revised_memories = json.dumps( [one.memory for one in revised_memory_list], indent=2 diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 728203f5b..3f5c90b67 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -74,6 +74,7 @@ from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory +from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE @@ -198,13 +199,16 @@ def init_mem_cube( logger.error("mem_cube is None, cannot initialize", stack_info=True) self.mem_cube = mem_cube self.text_mem: TreeTextMemory = self.mem_cube.text_mem - self.reranker: HTTPBGEReranker = self.text_mem.reranker + self.reranker: HTTPBGEReranker = getattr(self.text_mem, "reranker", None) if searcher is None: - self.searcher: Searcher = self.text_mem.get_searcher( - manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", - moscube=False, - process_llm=self.process_llm, - ) + if hasattr(self.text_mem, "get_searcher"): + self.searcher: Searcher = self.text_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + process_llm=self.process_llm, + ) + else: + self.searcher = None else: self.searcher = searcher self.feedback_server = feedback_server @@ -540,6 +544,29 @@ def replace_working_memory( mem_cube=mem_cube, log_func_callback=self._submit_web_logs, ) + elif isinstance(text_mem_base, NaiveTextMemory): + # For NaiveTextMemory, we populate the monitors with the new candidates so activation memory can pick them up + logger.info( + f"NaiveTextMemory: Updating working memory monitors with {len(new_memory)} candidates." 
+ ) + + # Use query keywords if available, otherwise just basic monitoring + query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id] + query_db_manager.sync_with_orm() + query_keywords = query_db_manager.obj.get_keywords_collections() + + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=new_memory, + ) + + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + memories_with_new_order = new_memory else: logger.error("memory_base is not supported") memories_with_new_order = new_memory @@ -1008,6 +1035,9 @@ def _monitor_loop(self): try: q_sizes = self.memos_message_queue.qsize() + if not isinstance(q_sizes, dict): + continue + for stream_key, queue_length in q_sizes.items(): # Skip aggregate keys like 'total_size' if stream_key == "total_size": diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 57d78676f..fd83ec86f 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -55,7 +55,11 @@ def create_autofilled_log_item( "mem_cube is None — this should not happen in production!", stack_info=True ) text_mem_base: TreeTextMemory = mem_cube.text_mem - current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id) + + current_memory_sizes = {} + if hasattr(text_mem_base, "get_current_memory_size"): + current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id) + current_memory_sizes = { "long_term_memory_size": current_memory_sizes.get("LongTermMemory", 0), "user_memory_size": current_memory_sizes.get("UserMemory", 0), @@ -63,14 +67,32 @@ def create_autofilled_log_item( "transformed_act_memory_size": NOT_INITIALIZED, "parameter_memory_size": NOT_INITIALIZED, } + memory_capacities = { - "long_term_memory_capacity": text_mem_base.memory_manager.memory_size["LongTermMemory"], - "user_memory_capacity": text_mem_base.memory_manager.memory_size["UserMemory"], - "working_memory_capacity": text_mem_base.memory_manager.memory_size["WorkingMemory"], + "long_term_memory_capacity": 0, + "user_memory_capacity": 0, + "working_memory_capacity": 0, "transformed_act_memory_capacity": NOT_INITIALIZED, "parameter_memory_capacity": NOT_INITIALIZED, } + if hasattr(text_mem_base, "memory_manager") and hasattr( + text_mem_base.memory_manager, "memory_size" + ): + memory_capacities.update( + { + "long_term_memory_capacity": text_mem_base.memory_manager.memory_size.get( + "LongTermMemory", 0 + ), + "user_memory_capacity": text_mem_base.memory_manager.memory_size.get( + "UserMemory", 0 + ), + "working_memory_capacity": text_mem_base.memory_manager.memory_size.get( + "WorkingMemory", 0 + ), + } + ) + if hasattr(self, "monitor"): if ( user_id in self.monitor.activation_memory_monitors diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 86066f346..9b19e9ecb 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -34,6 +34,7 @@ is_cloud_env, ) from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.naive import NaiveTextMemory from memos.memories.textual.preference import PreferenceTextMemory from memos.memories.textual.tree import TreeTextMemory from memos.types import 
( @@ -846,7 +847,9 @@ def _process_memories_with_reader( memory_item = text_mem.get(mem_id, user_name=user_name) memory_items.append(memory_item) except Exception as e: - logger.warning(f"Failed to get memory {mem_id}: {e}") + logger.warning( + f"[_process_memories_with_reader] Failed to get memory {mem_id}: {e}" + ) continue if not memory_items: @@ -1364,22 +1367,31 @@ def process_session_turn( text_mem_base = mem_cube.text_mem if not isinstance(text_mem_base, TreeTextMemory): - logger.error( - f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} " - f"for mem_cube_id={mem_cube_id}, user_id={user_id}. " - f"text_mem_base value: {text_mem_base}", - exc_info=True, + if isinstance(text_mem_base, NaiveTextMemory): + logger.debug( + f"NaiveTextMemory used for mem_cube_id={mem_cube_id}, processing session turn with simple search." + ) + # Treat NaiveTextMemory similar to TreeTextMemory but with simpler logic + # We will perform retrieval to get "working memory" candidates for activation memory + # But we won't have a distinct "current working memory" + cur_working_memory = [] + else: + logger.warning( + f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} " + f"for mem_cube_id={mem_cube_id}, user_id={user_id}. " + f"text_mem_base value: {text_mem_base}" + ) + return [], [] + else: + cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory( + user_name=mem_cube_id ) - return + cur_working_memory = cur_working_memory[:top_k] logger.info( f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}" ) - cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory( - user_name=mem_cube_id - ) - cur_working_memory = cur_working_memory[:top_k] text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] intent_result = self.monitor.detect_intent( q_list=queries, text_working_memory=text_working_memory @@ -1419,15 +1431,28 @@ def process_session_turn( ) search_args = {} - results: list[TextualMemoryItem] = self.retriever.search( - query=item, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - top_k=k_per_evidence, - method=self.search_method, - search_args=search_args, - ) + if isinstance(text_mem_base, NaiveTextMemory): + # NaiveTextMemory doesn't support complex search args usually, but let's see + # self.retriever.search calls mem_cube.text_mem.search + # NaiveTextMemory.search takes query and top_k + # SchedulerRetriever.search handles method dispatch + # For NaiveTextMemory, we might need to bypass retriever or extend it + # But let's try calling naive memory directly if retriever fails or doesn't support it + try: + results = text_mem_base.search(query=item, top_k=k_per_evidence) + except Exception as e: + logger.warning(f"NaiveTextMemory search failed: {e}") + results = [] + else: + results: list[TextualMemoryItem] = self.retriever.search( + query=item, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + top_k=k_per_evidence, + method=self.search_method, + search_args=search_args, + ) logger.info( f"[process_session_turn] Search results for missing evidence '{item}': " diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index b097b1e2d..d75d6ee75 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -200,15 +200,19 @@ def update_working_memory_monitors( mem_cube_id: 
str, mem_cube: GeneralMemCube, ): - text_mem_base: TreeTextMemory = mem_cube.text_mem - assert isinstance(text_mem_base, TreeTextMemory) - self.working_mem_monitor_capacity = min( - DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, - ( - int(text_mem_base.memory_manager.memory_size["WorkingMemory"]) - + self.partial_retention_number - ), - ) + text_mem_base = mem_cube.text_mem + + if isinstance(text_mem_base, TreeTextMemory): + self.working_mem_monitor_capacity = min( + DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, + ( + int(text_mem_base.memory_manager.memory_size["WorkingMemory"]) + + self.partial_retention_number + ), + ) + else: + # Fallback for NaiveTextMemory and others + self.working_mem_monitor_capacity = DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT # register monitors self.register_memory_manager_if_not_exists( diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index 35df3db64..e2c1621d4 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -128,9 +128,6 @@ def status_tracker(self) -> TaskStatusTracker | None: if self._status_tracker is None: try: self._status_tracker = TaskStatusTracker(self.redis) - # Propagate to submodules when created lazily - if self.dispatcher: - self.dispatcher.status_tracker = self._status_tracker if self.memos_message_queue: self.memos_message_queue.set_status_tracker(self._status_tracker) except Exception as e: diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index eae70f8ef..791cedf41 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -4,9 +4,18 @@ the local memos_message_queue functionality in BaseScheduler. """ +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX +from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator +from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule @@ -16,26 +25,38 @@ class SchedulerLocalQueue(RedisSchedulerModule): def __init__( self, - maxsize: int, + maxsize: int = 0, + stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX, + orchestrator: SchedulerOrchestrator | None = None, + status_tracker: TaskStatusTracker | None = None, ): """ Initialize the SchedulerLocalQueue with a maximum queue size limit. + Arguments match SchedulerRedisQueue for compatibility. Args: - maxsize (int): Maximum number of messages allowed - in each individual queue. - If exceeded, subsequent puts will block - or raise an exception based on `block` parameter. + maxsize (int): Maximum number of messages allowed in each individual queue. + stream_key_prefix (str): Prefix for stream keys (simulated). + orchestrator: SchedulerOrchestrator instance (ignored). + status_tracker: TaskStatusTracker instance (ignored). 
""" super().__init__() - self.stream_key_prefix = "local_queue" + self.stream_key_prefix = stream_key_prefix or "local_queue" self.max_internal_message_queue_size = maxsize + # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem] self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {} + + self.orchestrator = orchestrator + self.status_tracker = status_tracker + + self._is_listening = False + self._message_handler: Callable[[ScheduleMessageItem], None] | None = None + logger.info( - f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}" + f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}" ) def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: @@ -86,7 +107,7 @@ def get( stream_key: str, block: bool = True, timeout: float | None = None, - batch_size: int | None = None, + batch_size: int | None = 1, ) -> list[ScheduleMessageItem]: if batch_size is not None and batch_size <= 0: logger.warning( @@ -99,18 +120,19 @@ def get( logger.error(f"Stream {stream_key} does not exist when trying to get messages.") return [] + # Ensure we always request a batch so we get a list back + effective_batch_size = batch_size if batch_size is not None else 1 + # Note: Assumes custom Queue implementation supports batch_size parameter res = self.queue_streams[stream_key].get( - block=block, timeout=timeout, batch_size=batch_size + block=block, timeout=timeout, batch_size=effective_batch_size ) logger.debug( f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" ) return res - def get_nowait( - self, stream_key: str, batch_size: int | None = None - ) -> list[ScheduleMessageItem]: + def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]: """ Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size). @@ -170,35 +192,13 @@ def qsize(self) -> dict: logger.debug(f"Current queue sizes: {sizes}") return sizes - def size(self) -> int: - """ - Get the current size of the queue (total message count). - Compatible with SchedulerRedisQueue. - """ - return self.unfinished_tasks - - def empty(self) -> bool: - """ - Check if the queue is empty. - Compatible with SchedulerRedisQueue. - """ - return self.size() == 0 - - def full(self) -> bool: - """ - Check if the queue is full. - Compatible with SchedulerRedisQueue. - """ - # Local queue limits are per-stream (max_internal_message_queue_size). - # It is considered full only if all streams are full. - if not self.queue_streams: - return False - - return all(queue.full() for queue in self.queue_streams.values()) - - def clear(self) -> None: - for queue in self.queue_streams.values(): - queue.clear() + def clear(self, stream_key: str | None = None) -> None: + if stream_key: + if stream_key in self.queue_streams: + self.queue_streams[stream_key].clear() + else: + for queue in self.queue_streams.values(): + queue.clear() @property def unfinished_tasks(self) -> int: @@ -216,3 +216,32 @@ def unfinished_tasks(self) -> int: total = sum(queue.qsize() for queue in self.queue_streams.values()) logger.debug(f"Total unfinished tasks across all queues: {total}") return total + + def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: + """ + Return list of active stream keys. 
+ """ + prefix = stream_key_prefix or self.stream_key_prefix + return [k for k in self.queue_streams if k.startswith(prefix)] + + def size(self) -> int: + """ + Total size of all queues. + """ + return sum(q.qsize() for q in self.queue_streams.values()) + + def empty(self) -> bool: + """ + Check if all queues are empty. + """ + return self.size() == 0 + + def full(self) -> bool: + """ + Check if any queue is full (approximate). + """ + if self.max_internal_message_queue_size <= 0: + return False + return any( + q.qsize() >= self.max_internal_message_queue_size for q in self.queue_streams.values() + ) diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 2f4318003..941c52164 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -787,7 +787,7 @@ def qsize(self) -> dict: Total number of messages across all matching streams. """ if not self._redis_conn: - return 0 + return {} total_size = 0 try: diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py index 98d611dbf..1981b958f 100644 --- a/src/memos/memories/activation/kv.py +++ b/src/memos/memories/activation/kv.py @@ -2,9 +2,7 @@ import pickle from datetime import datetime -from importlib.metadata import version -from packaging.version import Version from transformers import DynamicCache from memos.configs.memory import KVCacheMemoryConfig @@ -211,10 +209,24 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: return caches[0] merged = DynamicCache() - num_layers = len(caches[0].key_cache) - if Version(version("transformers")) >= Version("4.54.0"): - merged.append_new_layers(num_layers - 1) + # Check for new structure (layers) + if hasattr(caches[0], "layers"): + num_layers = len(caches[0].layers) + + # Ensure merged has layers attribute and populate it + if not hasattr(merged, "layers"): + merged.layers = [] + + if num_layers > 0: + # Get the class of the layer from the first cache + # We assume all caches use the same layer class + layer_cls = type(caches[0].layers[0]) + + # Populate merged.layers + while len(merged.layers) < num_layers: + merged.layers.append(layer_cls()) + for layer in range(num_layers): # gather all K and V for this layer keys = [c.layers[layer].keys for c in caches] @@ -223,7 +235,10 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: merged.layers[layer].keys = torch.cat(keys, dim=-2) merged.layers[layer].values = torch.cat(vals, dim=-2) - else: + # Check for old structure (key_cache) + elif hasattr(caches[0], "key_cache"): + num_layers = len(caches[0].key_cache) + for layer in range(num_layers): # gather all K and V for this layer keys = [c.key_cache[layer] for c in caches] @@ -232,6 +247,11 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache: merged.key_cache.append(torch.cat(keys, dim=-2)) merged.value_cache.append(torch.cat(vals, dim=-2)) + else: + raise AttributeError( + "DynamicCache object has neither 'layers' nor 'key_cache' attributes" + ) + return merged diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 40971c77e..26795a2b1 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -622,7 +622,7 @@ 专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。""" -SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """ You are a 
strict, language-preserving memory validator and rewriter. Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. @@ -655,6 +655,39 @@ Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. """ +SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """ +You are a strict, language-preserving memory validator and rewriter. + +Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content. + +Rules: +1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching. +2. **Strict Factual Grounding**: Include only what is explicitly stated by the user in messages marked as [user]. Remove or flag anything not directly present in the user’s utterances—no assumptions, interpretations, predictions, generalizations, or content originating solely from [assistant]. +3. **Source Attribution Requirement**: + - Every memory must be clearly traceable to its source: + - If a fact appears **only in [assistant] messages** and **is not affirmed by [user]**, label it as “[assistant] memory”. + - If [assistant] states something and [user] explicitly contradicts or denies it, label it as “[assistant] memory, but [user] [brief quote or summary of denial]”. + - If a fact is stated by [user] —whether or not [assistant] also mentions it— it is attributed to “[user]” and may be retained without qualification. +4. **Timestamp Exception**: Memories may include timestamps (e.g., "On December 19, 2026") derived from conversation metadata. If such a date likely reflects the conversation time (even if not in the `messages` list), do NOT treat it as hallucinated—but still attribute it to “[user]” only if the user mentioned or confirmed the date. + +Inputs: +messages: +{messages_inline} + +memories: +{memories_inline} + +Output Format: +- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices. +- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }} +- The "reason" must be brief and precise, e.g.: + - "contains unsupported inference from [assistant]" + - "[assistant] memory, but [user] said 'I don't have a dog'" + - "fully grounded in [user]" + +Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. +""" + SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """ You are a strict, language-preserving memory validator and rewriter. 
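The rewritten prompt pins an exact JSON shape, so callers can machine-check the model output before trusting it. A minimal validation sketch (illustrative only — the parsing MemOS actually uses lives in its own utilities, such as parse_json_result, and may behave differently):

    import json

    def validate_rewrite_output(raw: str) -> dict:
        # Expected shape: {"0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ...}
        data = json.loads(raw)
        if not isinstance(data, dict):
            raise ValueError("top-level JSON must be an object")
        for key, entry in data.items():
            if not key.isdigit():
                raise ValueError(f"key {key!r} does not match an input memory index")
            if not isinstance(entry.get("need_rewrite"), bool):
                raise ValueError(f"entry {key}: need_rewrite must be a boolean")
            if not isinstance(entry.get("rewritten"), str) or not isinstance(entry.get("reason"), str):
                raise ValueError(f"entry {key}: rewritten and reason must be strings")
        return data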
diff --git a/tests/llms/test_hf.py b/tests/llms/test_hf.py index 595995ad1..375bf2247 100644 --- a/tests/llms/test_hf.py +++ b/tests/llms/test_hf.py @@ -11,8 +11,8 @@ from memos.llms.hf import HFLLM -@patch("memos.llms.hf.AutoModelForCausalLM", MagicMock()) -@patch("memos.llms.hf.AutoTokenizer", MagicMock()) +@patch("transformers.AutoModelForCausalLM", MagicMock()) +@patch("transformers.AutoTokenizer", MagicMock()) class TestHFLLM(unittest.TestCase): def setUp(self): self.mock_inputs = MagicMock() From 01172f3de0aa93bf38339358960bcd191fa94443 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:09:35 +0800 Subject: [PATCH 21/48] add: milvus return data pagination (#806) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords * fix messages queue * add seach_by_keywords_LIKE * add doc filter * add retrieve query * add retrieve queies * patch info filter * add log and make embedding safety net * add log and make embedding safety net * deduplicate add objects * use _add_memories_parallel * delete Special characters * delete Special characters * delete Special characters * delete Special characters * add source_doc_id * add source_doc_id * add reranker in init com.. 
* fix circle import * add feedback judgement * add feedback judgement * add pref feedback * add pref feedback * patch: get_memory func filter user id and make page chunk * add total num * add total num * add milvus pagination --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/handlers/memory_handler.py | 19 +++++- src/memos/memories/textual/preference.py | 38 ++++++++--- src/memos/vec_dbs/milvus.py | 86 +++++++++++++++++------- 3 files changed, 103 insertions(+), 40 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index d05da19db..941b59106 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -192,14 +192,26 @@ def handle_get_memories( del memories["total_edges"] preferences: list[TextualMemoryItem] = [] + total_explicit_nodes, total_implicit_nodes = 0, 0 if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: filter_params: dict[str, Any] = {} if get_mem_req.user_id is not None: filter_params["user_id"] = get_mem_req.user_id if get_mem_req.mem_cube_id is not None: filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) - preferences = [format_memory_item(mem) for mem in preferences] + preferences = naive_mem_cube.pref_mem.get_memory_by_filter( + filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size + ) + + for key, value_list in preferences.items(): + if key in ["explicit_preference", "implicit_preference"]: + formatted_list = [format_memory_item(item) for item in value_list] + preferences[key] = formatted_list + + total_explicit_nodes = preferences["total_explicit_nodes"] + total_implicit_nodes = preferences["total_implicit_nodes"] + del preferences["total_explicit_nodes"] + del preferences["total_implicit_nodes"] return GetMemoryResponse( message="Memories retrieved successfully", @@ -216,7 +228,8 @@ def handle_get_memories( { "cube_id": get_mem_req.mem_cube_id, "memories": preferences, - "total_nodes": len(preferences), + "total_explicit_nodes": total_explicit_nodes, + "total_implicit_nodes": total_implicit_nodes, } ], }, diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 9e521158d..75d7d2a4c 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -261,7 +261,9 @@ def get_all(self) -> list[TextualMemoryItem]: ] return all_memories - def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[TextualMemoryItem]: + def get_memory_by_filter( + self, filter: dict[str, Any] | None = None, **kwargs + ) -> list[TextualMemoryItem]: """Get memories by filter. Args: filter (dict[str, Any]): Filter criteria. @@ -269,18 +271,32 @@ def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[Tex list[TextualMemoryItem]: List of memories that match the filter. 
""" collection_list = self.vector_db.config.collection_name - all_db_items = [] + + memories = {} + total_explicit_nodes = 0 + total_implicit_nodes = 0 for collection_name in collection_list: - db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter) - all_db_items.extend(db_items) - memories = [ - TextualMemoryItem( - id=memo.id, - memory=memo.memory, - metadata=PreferenceTextualMemoryMetadata(**memo.payload), + memories[collection_name] = [] + db_items, total_count = self.vector_db.get_by_filter( + collection_name=collection_name, filter=filter, count_total=True, **kwargs ) - for memo in all_db_items - ] + db_items_memory = [ + TextualMemoryItem( + id=memo.id, + memory=memo.memory, + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in db_items + ] + memories[collection_name].extend(db_items_memory) + + if collection_name == "explicit_preference": + total_explicit_nodes = total_count + if collection_name == "implicit_preference": + total_implicit_nodes = total_count + memories["total_explicit_nodes"] = total_explicit_nodes + memories["total_implicit_nodes"] = total_implicit_nodes + return memories def delete(self, memory_ids: list[str]) -> None: diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index ecbca5815..b0753b31d 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -493,7 +493,14 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt return items def get_by_filter( - self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 + self, + collection_name: str, + filter: dict[str, Any], + scroll_limit: int = 100, + page: int | None = None, + page_size: int | None = None, + count_total=False, + **kwargs, ) -> list[MilvusVecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. 
@@ -506,47 +513,74 @@ def get_by_filter( List of items including vectors and payload that match the filter """ expr = self._dict_to_expr(filter) if filter else "" - all_items = [] + if count_total: + total_count = 0 + count_iterator = self.client.query_iterator( + collection_name=collection_name, + filter=expr, + batch_size=scroll_limit, + output_fields=["id"], + ) + try: + while True: + batch = count_iterator.next() + if not batch: + break + total_count += len(batch) + finally: + count_iterator.close() + + result = [] + skipped = 0 + needed = page_size - # Use query_iterator for efficient pagination iterator = self.client.query_iterator( collection_name=collection_name, filter=expr, batch_size=scroll_limit, - output_fields=["*"], # Include all fields including payload + output_fields=["*"], ) - # Iterate through all batches try: - while True: - batch_results = iterator.next() - - if not batch_results: + while needed > 0: + batch = iterator.next() + if not batch: break - # Convert batch results to MilvusVecDBItem objects - for entity in batch_results: - # Extract the actual payload from Milvus entity + for entity in batch: + skipped += 1 + + if skipped <= (page - 1) * page_size: + continue + payload = entity.get("payload", {}) - all_items.append( - MilvusVecDBItem( - id=entity["id"], - memory=entity.get("memory"), - original_text=entity.get("original_text"), - vector=entity.get("vector"), - payload=payload, - ) + item = MilvusVecDBItem( + id=entity["id"], + memory=entity.get("memory"), + original_text=entity.get("original_text"), + vector=entity.get("vector"), + payload=payload, ) + result.append(item) + needed -= 1 + + if needed <= 0: + if count_total: + return result, total_count + return result + except Exception as e: - logger.warning( - f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far." - ) + logger.warning(f"Error during iteration: {e}") finally: - # Close the iterator iterator.close() - logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") - return all_items + logger.info( + f"Milvus retrieve by filter completed - " + f"page {page}, page_size {page_size}, got {len(result)} items." 
+ ) + if count_total: + return result, total_count + return result def get_all(self, collection_name: str, scroll_limit=100) -> list[MilvusVecDBItem]: """Retrieve all items in the vector database.""" From acb5799bff0d731711bfda5454e755df9eb4ced5 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:47:09 +0800 Subject: [PATCH 22/48] feat: update source return and chunk settings (#808) * feat: update source return and chunk settings * feat: update code format --- docker/Dockerfile | 2 +- docker/requirements-full.txt | 2 +- docker/requirements.txt | 2 +- src/memos/chunkers/markdown_chunker.py | 2 +- src/memos/mem_reader/read_multi_modal/base.py | 2 +- .../mem_reader/read_multi_modal/file_content_parser.py | 1 + src/memos/mem_reader/read_multi_modal/utils.py | 2 +- src/memos/reranker/strategies/concat_docsource.py | 7 ++++++- 8 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 13fb477d9..76be1709d 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -32,4 +32,4 @@ ENV PYTHONPATH=/app/src EXPOSE 8000 # Start the docker -CMD ["uvicorn", "memos.api.server_api:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] \ No newline at end of file +CMD ["uvicorn", "memos.api.server_api:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] diff --git a/docker/requirements-full.txt b/docker/requirements-full.txt index 538f5e578..57c26067f 100644 --- a/docker/requirements-full.txt +++ b/docker/requirements-full.txt @@ -183,4 +183,4 @@ psycopg2-binary==2.9.9 py-key-value-aio==0.2.8 py-key-value-shared==0.2.8 PyJWT==2.10.1 -pytest==9.0.2 \ No newline at end of file +pytest==9.0.2 diff --git a/docker/requirements.txt b/docker/requirements.txt index 738a53920..aa01fa626 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -125,4 +125,4 @@ urllib3==2.5.0 uvicorn==0.38.0 uvloop==0.22.1 watchfiles==1.1.1 -websockets==15.0.1 \ No newline at end of file +websockets==15.0.1 diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py index de375a4dc..b7771ac35 100644 --- a/src/memos/chunkers/markdown_chunker.py +++ b/src/memos/chunkers/markdown_chunker.py @@ -57,6 +57,6 @@ def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: except Exception as e: logger.warning(f"warning chunking document: {e}") chunks.append(doc.page_content) - + logger.info(f"Generated chunks: {chunks[:5]}") logger.debug(f"Generated {len(chunks)} chunks from input text") return chunks diff --git a/src/memos/mem_reader/read_multi_modal/base.py b/src/memos/mem_reader/read_multi_modal/base.py index 7664f4d7f..1a756c5d0 100644 --- a/src/memos/mem_reader/read_multi_modal/base.py +++ b/src/memos/mem_reader/read_multi_modal/base.py @@ -258,7 +258,7 @@ def _split_text(self, text: str, is_markdown: bool = False) -> list[str]: if not text or not text.strip(): return [] - splitter = get_text_splitter() + splitter = get_text_splitter(is_markdown=is_markdown) if not splitter: # If text splitter is not available, return text as single chunk return [text] if text.strip() else [] diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 8fa0f2454..fbc704d0b 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -94,6 +94,7 @@ def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None, boo 
response = requests.get(url_str, timeout=30)
             response.raise_for_status()
+            response.encoding = "utf-8"

             if not filename:
                 filename = os.path.basename(parsed_url.path) or "downloaded_file"
diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py
index cba8ddeda..d3d97b4e6 100644
--- a/src/memos/mem_reader/read_multi_modal/utils.py
+++ b/src/memos/mem_reader/read_multi_modal/utils.py
@@ -107,7 +107,7 @@ def _cheap_close(t: str) -> str:
     "config": {},
 }

-DEFAULT_CHUNK_SIZE = int(os.getenv("FILE_PARSER_CHUNK_SIZE", "1000"))
+DEFAULT_CHUNK_SIZE = int(os.getenv("FILE_PARSER_CHUNK_SIZE", "1280"))
 DEFAULT_CHUNK_OVERLAP = int(os.getenv("FILE_PARSER_CHUNK_OVERLAP", "200"))
diff --git a/src/memos/reranker/strategies/concat_docsource.py b/src/memos/reranker/strategies/concat_docsource.py
index 0fb471218..d90452995 100644
--- a/src/memos/reranker/strategies/concat_docsource.py
+++ b/src/memos/reranker/strategies/concat_docsource.py
@@ -54,6 +54,7 @@ def prepare_documents(
         original_items = {}
         tracker = DialogueRankingTracker()
         documents = []
+        documents_set = set()
         for item in graph_results:
             memory = getattr(item, "memory", None)
             if isinstance(memory, str):
@@ -66,7 +67,11 @@ def prepare_documents(
                     if source.type == "file":
                         chunk_text += source.content
                 if chunk_text:
-                    documents.append(f"{memory}\n\n[Sources]:\n{chunk_text}")
+                    if chunk_text in documents_set:
+                        continue
+                    else:
+                        documents_set.add(chunk_text)
+                        documents.append(f"{memory}\n\n[Sources]:\n{chunk_text}")
                 else:
                     documents.append(memory)
         return tracker, original_items, documents

From 9dba3323114c37c599f23fe669857b23d7b6cb91 Mon Sep 17 00:00:00 2001
From: Travis Tang
Date: Tue, 30 Dec 2025 16:45:27 +0800
Subject: [PATCH 23/48] Scheduler: update examples (#807)

* fix bugs: try to fix bugs in _submit_web_logs

* fix bugs: try to address bugs

* fix bugs

* refactor: modify examples

* revise add operation and fix an unbelievable bug

* address the bug issues

* the doc file has a format problem which has been fixed in this commit

* add a range of new feats for the add operation

* address the incompatible issue of local scheduler

* feat(scheduler): optimize redis queue consumer group management

- Proactively ensure consumer groups exist in '_refresh_stream_keys' for newly discovered streams.
- Remove redundant consumer group checks in '_read_new_messages_batch' to improve read performance.
- Clean up 'seen_streams' cache when streams are deleted to ensure correct group recreation.
- This change reduces unnecessary Redis calls during high-frequency polling.

* fix(tests): resolve AttributeError in SimpleStructMemReader tests

- Import 'parse_json_result' from 'memos.mem_reader.utils' instead of accessing it as an instance attribute.
- Fixes 'AttributeError: 'SimpleStructMemReader' object has no attribute 'parse_json_result'' in 'test_parse_json_result_success' and 'test_parse_json_result_failure'.
- Remove incorrect mock assignment of 'parse_json_result' in 'test_process_chat_data'.

* fix(mem_reader): pass info dict to add_before_search for correct user_id usage

- Update 'add_before_search' signature in 'SimpleStructMemReader' to accept 'info' dict.
- Pass 'info' (containing 'user_id' and 'session_id') to 'self.searcher.search' instead of using empty strings.
- Add 'test_add_before_search' to 'TestSimpleStructMemReader' to verify the fix and ensure 'searcher.search' receives the correct 'info'.
- This ensures that memory searches are scoped to the correct user and session.

* refactor add_before_search from mem_reader to SingleCubeView

* address bugs

* fix: fix the qsize bug of task queue, and accept change from hotfix/scheduler

* fix: address some issues to run old scheduler example and kv cache example

* fix: address the issue of Top-level import of unavailable module 'torch'

* fix: resolve linting errors and make optional dependencies lazy loaded

- Fix ambiguous characters and commented-out code in examples/mem_scheduler/quick_start_examples.py
- Fix nested if statements in src/memos/mem_os/core.py
- Move torch and transformers imports to method scope in src/memos/llms/hf.py to support optional dependencies
- Update tests/llms/test_hf.py to patch transformers module directly

* refactor: revise the rewrite prompt to make it better

* refactor: update examples

* refactor: update examples for scheduler

---
 .../mem_scheduler/quick_start_examples.py | 139 ++++++++++++------
 1 file changed, 98 insertions(+), 41 deletions(-)
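The "lazy loaded" optional dependencies above follow a standard pattern: heavy imports move from module top level into the scope that actually needs them, so importing the module succeeds even when the extras are not installed. A minimal sketch of the idea (class and variable names are illustrative placeholders, not the MemOS originals):

    class LazyHFBackend:
        def __init__(self, model_name: str):
            # Deferred imports: torch/transformers are resolved only when an
            # instance is constructed, never at module import time.
            import torch
            from transformers import AutoModelForCausalLM

            if torch.backends.mps.is_available():
                self.model = AutoModelForCausalLM.from_pretrained(model_name).to("mps")
            else:
                self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

Defining (and importing a module containing) this class costs nothing; an ImportError surfaces only if a caller instantiates it without the optional packages installed.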
---
 .../mem_scheduler/quick_start_examples.py     | 139 ++++++++++++------
 1 file changed, 98 insertions(+), 41 deletions(-)

diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py
index c71869e76..fbfef4d76 100644
--- a/examples/mem_scheduler/quick_start_examples.py
+++ b/examples/mem_scheduler/quick_start_examples.py
@@ -145,106 +145,163 @@ def kv_cache_only():


 def run_scheduler_example():
-    # Load the main MOS config with MemScheduler
-    config = parse_yaml(
-        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
-    )
+    # Load the main MOS (Memory-Oriented System) config file with MemScheduler
+    config = parse_yaml("./examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml")
+    # Pass the parsed config dict to the MOSConfig constructor to build the config object
     mos_config = MOSConfig(**config)
+    # Initialize the MOS system instance with the config object
     mos = MOS(mos_config)

-    # Create a dynamic user ID
+    # Generate a unique dynamic user ID (using UUID4)
     user_id = str(uuid.uuid4())
+    # Create an account for this user in the MOS system
     mos.create_user(user_id=user_id)

-    # Create the MemCube config and export it
+    # Load the general MemCube config from a YAML file
     config = GeneralMemCubeConfig.from_yaml_file(
-        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+        "./examples/data/config/mem_scheduler/mem_cube_config.yaml"
     )
+    # Define a unique identifier for the MemCube
     mem_cube_id = "mem_cube_5"
-    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+    # Define the local storage path for the MemCube (the path embeds the user ID and MemCube ID)
+    mem_cube_name_or_path = f"./outputs/mem_scheduler/{user_id}/{mem_cube_id}"

-    # Remove the old directory if it exists
+    # If the path already exists, remove the old directory first
     if Path(mem_cube_name_or_path).exists():
         shutil.rmtree(mem_cube_name_or_path)
-        print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+        print(f"{mem_cube_name_or_path} is not empty and has been removed.")

-    # Export a new MemCube
+    # Create a new MemCube instance from the loaded config
     mem_cube = GeneralMemCube(config)
+    # Serialize the MemCube instance and save it to the given path
     mem_cube.dump(mem_cube_name_or_path)

-    # Register the MemCube for this user
+    # Register this MemCube for the current user in the MOS system
     mos.register_mem_cube(
         mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
     )

-    # Define custom scheduler handlers
+    # Define a helper that reports memory usage for a cache (e.g., a KV cache)
+    def get_cache_info(cache):
+        # Return None directly if the cache is empty
+        if not cache:
+            return None
+
+        num_layers = 0  # number of cached layers
+        total_size_bytes = 0  # total size in bytes
+
+        # Case 1: the cache exposes a "layers" attribute (e.g., the HuggingFace cache format)
+        if hasattr(cache, "layers"):
+            num_layers = len(cache.layers)
+            for layer in cache.layers:
+                # Count the memory used by key_cache (if present)
+                if hasattr(layer, "key_cache") and layer.key_cache is not None:
+                    total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size()
+                # Count the memory used by value_cache (if present)
+                if hasattr(layer, "value_cache") and layer.value_cache is not None:
+                    total_size_bytes += (
+                        layer.value_cache.nelement() * layer.value_cache.element_size()
+                    )
+
+                # Stay compatible with other possible cache attribute names (e.g., keys/values)
+                if hasattr(layer, "keys") and layer.keys is not None:
+                    total_size_bytes += layer.keys.nelement() * layer.keys.element_size()
+                if hasattr(layer, "values") and layer.values is not None:
+                    total_size_bytes += layer.values.nelement() * layer.values.element_size()
+
+        # Case 2: the cache holds key_cache and value_cache lists directly (e.g., some custom formats)
+        elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
+            num_layers = len(cache.key_cache)
+            for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
+                if k is not None:
+                    total_size_bytes += k.nelement() * k.element_size()
+                if v is not None:
+                    total_size_bytes += v.nelement() * v.element_size()
+
+        # Return structured cache info: layer count, byte count, and a readable MB figure
+        return {
+            "num_layers": num_layers,
+            "size_bytes": total_size_bytes,
+            "size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
+        }
+
+    # Define a custom query handler
     def custom_query_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
-            print(f"\n[scheduler] user entered a query: {msg.content}")
-            # Trigger mem_update manually
+            # Print the user's input
+            print(f"\n[scheduler] the user entered a query: {msg.content}")
+            # Manually build a new message tagged MEM_UPDATE to trigger a memory update
             new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
+            # Submit the message to the scheduler
             mos.mem_scheduler.submit_messages([new_msg])

+    # Define a custom answer handler
     def custom_answer_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
-            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
-            kv_mem = mem_cube.act_mem
-            for cache_item in kv_mem.get_all():
-                print(
-                    f"[scheduler] act memory: {get_cache_info(cache_item.memory)} ({cache_item.records})"
-                )
-            print(f"\n[scheduler] LLM answered: {msg.content}")
+            # Print the LLM's reply
+            print(f"\n[scheduler] the LLM replied with an answer: {msg.content}")

+    # Define a custom mem_update handler
     def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
         for msg in messages:
             mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
             kv_mem = mem_cube.act_mem
+            # If this MemCube has text memory configured (TreeTextMemory / NaiveTextMemory)
             if mem_cube and mem_cube.text_mem:
+                # Search text memory for items related to the current content (top_k=3)
                 results = mem_cube.text_mem.search(msg.content, top_k=3)
                 for mem in results:
-                    print(f"\n[scheduler] searched memories: {mem.memory}")
-
+                    print(f"\n[scheduler] retrieved memory: {mem.memory}")
+                    print("\n[scheduler] converting to activation memory......")
+                    # Extract the matching KV cache item from text memory
                     cache_item = kv_mem.extract(mem.memory)
+                    # Attach metadata
                     cache_item.records.text_memories = [mem.memory]
                     cache_item.records.timestamp = get_utc_now()
+                    # Add the cache item to activation memory
                     kv_mem.add([cache_item])
+                    print("\n[scheduler] done!")

-    # Register custom handlers
+    # Register the three custom handlers with the scheduler's dispatcher, one per task label
     mos.mem_scheduler.dispatcher.register_handlers(
         {
-            QUERY_TASK_LABEL: custom_query_handler,
-            ANSWER_TASK_LABEL: custom_answer_handler,
-            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,
+            QUERY_TASK_LABEL: custom_query_handler,  # query tasks
+            ANSWER_TASK_LABEL: custom_answer_handler,  # answer tasks
+            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,  # memory-update tasks
         }
     )

-    # Add messages
+    # Seed the system with two test messages (a user/assistant exchange)
     messages = [
         {"role": "user", "content": "I like playing football."},
         {"role": "assistant", "content": "I like playing football too."},
     ]
     mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)

-    # Chat loop: show TreeTextMemory nodes + KVCache
+    # Enter the chat loop: show the TreeTextMemory node structure plus the KV cache state
     while True:
+        # Read user input and strip surrounding whitespace
         user_input = input("👤 [You] ").strip()
         print()
+        # Ask the MOS system for a chat response
         response = mos.chat(user_input, user_id=user_id)
+        # Fetch all memories in this user's current MemCube
         retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)

+        # Print the assistant's reply
         print(f"🤖 [Assistant] {response}")

-        # Show the node types stored in TreeTextMemory
-        text_memories = retrieved_memories["text_mem"][0]["memories"]
-        # Handle different memory structures (NaiveTextMemory returns list, TreeTextMemory returns dict with nodes)
-        if isinstance(text_memories, dict) and "nodes" in text_memories:
-            for node in text_memories["nodes"]:
-                mem_type = node["metadata"].get("memory_type", "Unknown")
-                print(f"[{mem_type}] {node['memory']}")
-        elif isinstance(text_memories, list):
-            for mem in text_memories:
-                # Naive memory items might not have memory_type metadata, or it might be different
-                print(f"[TextMemory] {mem.memory if hasattr(mem, 'memory') else mem}")
+        # Get the text-memory part - TreeTextMemory
+        memories = retrieved_memories["text_mem"][0]["memories"]
+        for mem in memories:
+            print(f"[TextMemory] {mem.memory}")
+
+        # Get the matching MemCube and its activation memory (KV cache)
+        mem_cube = mos.mem_scheduler.mem_cube
+        kv_mem = mem_cube.act_mem
+        # Iterate over all activation-memory items, printing cache info and records
+        for cache_item in kv_mem.get_all():
+            print(f"[ActivationMemory] {get_cache_info(cache_item.memory)} (records: {cache_item.records})")


 if __name__ == "__main__":
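A side note on the size arithmetic used by get_cache_info in the example
above: each tensor contributes nelement() * element_size() bytes. A minimal
standalone illustration, assuming torch is installed (the tensor shape here is
arbitrary):

    import torch

    # A mock fp16 key-cache slice: (batch, heads, seq_len, head_dim).
    k = torch.zeros(1, 8, 128, 64, dtype=torch.float16)
    size_bytes = k.nelement() * k.element_size()  # 65536 elements * 2 bytes each
    print(f"{size_bytes / (1024 * 1024):.2f} MB")  # prints "0.12 MB", matching the size_mb format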
From c88f17aa64a8a698d74e2f490856dbfd8cd10f4a Mon Sep 17 00:00:00 2001
From: Hustzdy <67457465+wustzdy@users.noreply.github.com>
Date: Tue, 30 Dec 2025 17:03:21 +0800
Subject: [PATCH 24/48] feat: add delete_node_by_prams filter (#810)

---
 src/memos/graph_dbs/polardb.py | 120 ++++++++++++---------------------
 1 file changed, 42 insertions(+), 78 deletions(-)

diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 8eabda6d8..12f2c2ca9 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -4960,7 +4960,8 @@ def delete_node_by_prams(
             If not provided, no user_name filter will be applied.
         memory_ids (list[str], optional): List of memory node IDs to delete.
         file_ids (list[str], optional): List of file node IDs to delete.
-        filter (dict, optional): Filter dictionary to query matching nodes for deletion.
+        filter (dict, optional): Filter dictionary for metadata filtering.
+            Filter conditions are used directly in the DELETE WHERE clause without pre-querying.

     Returns:
         int: Number of nodes deleted.
@@ -4980,35 +4981,14 @@ def delete_node_by_prams(
                 f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
             )

-        # Query nodes by filter if provided
-        filter_ids = set()
+        # Build filter conditions using the common method (no pre-query; used directly in the WHERE clause)
+        filter_conditions = []
         if filter:
-            # Parse filter to validate and transform field names (e.g., add "info." prefix if needed)
-            parsed_filter = self.parse_filter(filter)
-            if parsed_filter:
-                # Use get_by_metadata with empty filters list and parsed filter
-                filter_ids = set(
-                    self.get_by_metadata(
-                        filters=[],
-                        user_name=None,
-                        filter=parsed_filter,
-                        knowledgebase_ids=writable_cube_ids,
-                    )
-                )
-            else:
-                logger.warning(
-                    "[delete_node_by_prams] Filter parsed to None, skipping filter query"
-                )
-
-        # Combine all IDs that need to be deleted
-        all_memory_ids = set()
-        if memory_ids:
-            all_memory_ids.update(memory_ids)
-        if filter_ids:
-            all_memory_ids.update(filter_ids)
+            filter_conditions = self._build_filter_conditions_sql(filter)
+            logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}")

         # If no conditions to delete, return 0
-        if not all_memory_ids and not file_ids:
+        if not memory_ids and not file_ids and not filter_conditions:
             logger.warning(
                 "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
             )
@@ -5019,74 +4999,58 @@ def delete_node_by_prams(
         try:
             conn = self._get_connection()
             with conn.cursor() as cursor:
-                # Process memory_ids and filter_ids (all at once, no batching)
-                if all_memory_ids:
-                    memory_ids_list = list(all_memory_ids)
-                    logger.info(
-                        f"[delete_node_by_prams] Processing {len(memory_ids_list)} memory_ids"
-                    )
+                # Build WHERE conditions list
+                where_conditions = []

-                    # Build conditions for all memory_ids
+                # Add memory_ids conditions
+                if memory_ids:
+                    logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids")
                     id_conditions = []
-                    for node_id in memory_ids_list:
+                    for node_id in memory_ids:
                         id_conditions.append(
                             f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
                         )
-                    id_where = f"({' OR '.join(id_conditions)})"
-
-                    # Add user_name filter if provided
-                    if user_name_conditions:
-                        user_name_where = " OR ".join(user_name_conditions)
-                        where_clause = f"({user_name_where}) AND ({id_where})"
-                    else:
-                        where_clause = id_where
-
-                    # Delete directly without counting
-                    delete_query = f"""
-                    DELETE FROM "{self.db_name}_graph"."Memory"
-                    WHERE {where_clause}
-                    """
-                    logger.info(f"[delete_node_by_prams] memory_ids delete_query: {delete_query}")
-
-                    cursor.execute(delete_query)
-                    deleted_count = cursor.rowcount
-                    total_deleted_count += deleted_count
+                    where_conditions.append(f"({' OR '.join(id_conditions)})")

-                    logger.info(
-                        f"[delete_node_by_prams] Deleted {deleted_count} nodes by memory_ids"
-                    )
-
-                # Process file_ids (all at once, no batching)
+                # Add file_ids conditions
                 if file_ids:
                     logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids")
-
-                    # Build conditions for all file_ids
                     file_id_conditions = []
                     for file_id in file_ids:
                         file_id_conditions.append(
                             f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
                         )
-                    file_id_where = f"({' OR '.join(file_id_conditions)})"
+                    where_conditions.append(f"({' OR '.join(file_id_conditions)})")

-                    # Add user_name filter if provided
-                    if user_name_conditions:
-                        user_name_where = " OR ".join(user_name_conditions)
-                        where_clause = f"({user_name_where}) AND ({file_id_where})"
-                    else:
-                        where_clause = file_id_where
+                # Add filter conditions
+                if filter_conditions:
+                    logger.info("[delete_node_by_prams] Processing filter conditions")
+                    where_conditions.extend(filter_conditions)

-                    # Delete directly without counting
-                    delete_query = f"""
-                    DELETE FROM "{self.db_name}_graph"."Memory"
-                    WHERE {where_clause}
-                    """
-                    logger.info(f"[delete_node_by_prams] file_ids delete_query: {delete_query}")
-
-                    cursor.execute(delete_query)
-                    deleted_count = cursor.rowcount
-                    total_deleted_count += deleted_count
+                # Add user_name filter if provided
+                if user_name_conditions:
+                    user_name_where = " OR ".join(user_name_conditions)
+                    where_conditions.append(f"({user_name_where})")
+
+                # Build final WHERE clause
+                if not where_conditions:
+                    logger.warning("[delete_node_by_prams] No WHERE conditions to delete")
+                    return 0

-                    logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes by file_ids")
+                where_clause = " AND ".join(where_conditions)
+
+                # Delete directly without counting
+                delete_query = f"""
+                DELETE FROM "{self.db_name}_graph"."Memory"
+                WHERE {where_clause}
+                """
+                logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
+
+                cursor.execute(delete_query)
+                deleted_count = cursor.rowcount
+                total_deleted_count = deleted_count

-                logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes by file_ids")
+                logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes")

                 elapsed_time = time.time() - batch_start_time
                 logger.info(
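The rewritten delete_node_by_prams above collects optional condition groups
and joins them into a single DELETE statement instead of pre-querying matching
IDs. A minimal pure-Python sketch of that clause assembly, using simplified
SQL fragments in place of the agtype accessor calls from the patch:

    def build_where_clause(
        id_conditions: list[str],
        file_conditions: list[str],
        filter_conditions: list[str],
        user_conditions: list[str],
    ) -> str | None:
        # Each group is OR-ed internally, then all present groups are AND-ed,
        # mirroring the control flow of the new implementation.
        groups = []
        if id_conditions:
            groups.append(f"({' OR '.join(id_conditions)})")
        if file_conditions:
            groups.append(f"({' OR '.join(file_conditions)})")
        groups.extend(filter_conditions)
        if user_conditions:
            groups.append(f"({' OR '.join(user_conditions)})")
        return " AND ".join(groups) if groups else None

    clause = build_where_clause(["id = 'a'", "id = 'b'"], [], [], ["user_name = 'u1'"])
    assert clause == "(id = 'a' OR id = 'b') AND (user_name = 'u1')"

Returning None for the no-conditions case mirrors the early "return 0" guard
in the patch, which prevents an accidental unfiltered DELETE.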
From 791c2ee91c2c5d65b8879952af439ca863f707b9 Mon Sep 17 00:00:00 2001
From: Dubberman <48425266+whipser030@users.noreply.github.com>
Date: Tue, 30 Dec 2025 18:48:13 +0800
Subject: [PATCH 25/48] fix: merge all preference (#811)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update reader and search strategy

* set strategy reader and search config

* fix install problem

* fix

* fix test

* turn off graph recall

* turn off graph recall

* turn off graph recall

* fix Searcher input bug

* fix Searcher

* fix Search

* fix bug

* adjust strategy reader

* adjust strategy reader

* adjust search config input

* reformat code

* re pr

* format repair

* fix time issue

* develop feedback process

* feedback handler configuration

* upgrade feedback usage

* add threshold

* update prompt

* update prompt

* fix handler

* add feedback scheduler

* add handler change node update

* add handler change node update

* add handler change node update

* add handler change node update

* fix interface input

* add chunk and ratio filter

* update stopwords

* fix messages queue

* add seach_by_keywords_LIKE

* add doc filter

* add retrieve query

* add retrieve queries

* patch info filter

* add log and make embedding safety net

* add log and make embedding safety net

* deduplicate add objects

* use _add_memories_parallel

* delete Special characters

* delete Special characters

* delete Special characters

* delete Special characters

* add source_doc_id

* add source_doc_id

* add reranker in init com..
* fix circle import * add feedback judgement * add feedback judgement * add pref feedback * add pref feedback * patch: get_memory func filter user id and make page chunk * add total num * add total num * add milvus pagination * fix merge implicit explicit pref * fix merge implicit explicit pref * fix merge implicit explicit pref --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/handlers/memory_handler.py | 22 ++---- src/memos/memories/textual/preference.py | 42 +++++++----- src/memos/vec_dbs/milvus.py | 86 +++++++----------------- 3 files changed, 57 insertions(+), 93 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 941b59106..a744e16e2 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -192,26 +192,19 @@ def handle_get_memories( del memories["total_edges"] preferences: list[TextualMemoryItem] = [] - total_explicit_nodes, total_implicit_nodes = 0, 0 + total_pref = 0 + if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: filter_params: dict[str, Any] = {} if get_mem_req.user_id is not None: filter_params["user_id"] = get_mem_req.user_id if get_mem_req.mem_cube_id is not None: filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - preferences = naive_mem_cube.pref_mem.get_memory_by_filter( + + preferences, total_pref = naive_mem_cube.pref_mem.get_memory_by_filter( filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size ) - - for key, value_list in preferences.items(): - if key in ["explicit_preference", "implicit_preference"]: - formatted_list = [format_memory_item(item) for item in value_list] - preferences[key] = formatted_list - - total_explicit_nodes = preferences["total_explicit_nodes"] - total_implicit_nodes = preferences["total_implicit_nodes"] - del preferences["total_explicit_nodes"] - del preferences["total_implicit_nodes"] + format_preferences = [format_memory_item(item) for item in preferences] return GetMemoryResponse( message="Memories retrieved successfully", @@ -227,9 +220,8 @@ def handle_get_memories( "pref_mem": [ { "cube_id": get_mem_req.mem_cube_id, - "memories": preferences, - "total_explicit_nodes": total_explicit_nodes, - "total_implicit_nodes": total_implicit_nodes, + "memories": format_preferences, + "total_nodes": total_pref, } ], }, diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 75d7d2a4c..cb4f00735 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -1,6 +1,7 @@ import json import os +from datetime import datetime from typing import Any from memos.configs.memory import PreferenceTextMemoryConfig @@ -262,8 +263,11 @@ def get_all(self) -> list[TextualMemoryItem]: return all_memories def get_memory_by_filter( - self, filter: dict[str, Any] | None = None, **kwargs - ) -> list[TextualMemoryItem]: + self, + filter: dict[str, Any] | None = None, + page: int | None = None, + page_size: int | None = None, + ): """Get memories by filter. Args: filter (dict[str, Any]): Filter criteria. 
@@ -272,14 +276,9 @@ def get_memory_by_filter( """ collection_list = self.vector_db.config.collection_name - memories = {} - total_explicit_nodes = 0 - total_implicit_nodes = 0 + memories = [] for collection_name in collection_list: - memories[collection_name] = [] - db_items, total_count = self.vector_db.get_by_filter( - collection_name=collection_name, filter=filter, count_total=True, **kwargs - ) + db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter) db_items_memory = [ TextualMemoryItem( id=memo.id, @@ -288,16 +287,23 @@ def get_memory_by_filter( ) for memo in db_items ] - memories[collection_name].extend(db_items_memory) + memories.extend(db_items_memory) - if collection_name == "explicit_preference": - total_explicit_nodes = total_count - if collection_name == "implicit_preference": - total_implicit_nodes = total_count - memories["total_explicit_nodes"] = total_explicit_nodes - memories["total_implicit_nodes"] = total_implicit_nodes - - return memories + # sort + sorted_memories = sorted( + memories, + key=lambda item: datetime.fromisoformat(item.metadata.created_at), + reverse=True, + ) + if page and page_size: + if page < 1: + page = 1 + if page_size < 1: + page_size = 10 + pick_memories = sorted_memories[(page - 1) * page_size : page * page_size] + return pick_memories, len(sorted_memories) + + return sorted_memories, len(sorted_memories) def delete(self, memory_ids: list[str]) -> None: """Delete memories. diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index b0753b31d..ecbca5815 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -493,14 +493,7 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt return items def get_by_filter( - self, - collection_name: str, - filter: dict[str, Any], - scroll_limit: int = 100, - page: int | None = None, - page_size: int | None = None, - count_total=False, - **kwargs, + self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 ) -> list[MilvusVecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. 
@@ -513,74 +506,47 @@ def get_by_filter( List of items including vectors and payload that match the filter """ expr = self._dict_to_expr(filter) if filter else "" - if count_total: - total_count = 0 - count_iterator = self.client.query_iterator( - collection_name=collection_name, - filter=expr, - batch_size=scroll_limit, - output_fields=["id"], - ) - try: - while True: - batch = count_iterator.next() - if not batch: - break - total_count += len(batch) - finally: - count_iterator.close() - - result = [] - skipped = 0 - needed = page_size + all_items = [] + # Use query_iterator for efficient pagination iterator = self.client.query_iterator( collection_name=collection_name, filter=expr, batch_size=scroll_limit, - output_fields=["*"], + output_fields=["*"], # Include all fields including payload ) + # Iterate through all batches try: - while needed > 0: - batch = iterator.next() - if not batch: - break - - for entity in batch: - skipped += 1 + while True: + batch_results = iterator.next() - if skipped <= (page - 1) * page_size: - continue + if not batch_results: + break + # Convert batch results to MilvusVecDBItem objects + for entity in batch_results: + # Extract the actual payload from Milvus entity payload = entity.get("payload", {}) - item = MilvusVecDBItem( - id=entity["id"], - memory=entity.get("memory"), - original_text=entity.get("original_text"), - vector=entity.get("vector"), - payload=payload, + all_items.append( + MilvusVecDBItem( + id=entity["id"], + memory=entity.get("memory"), + original_text=entity.get("original_text"), + vector=entity.get("vector"), + payload=payload, + ) ) - result.append(item) - needed -= 1 - - if needed <= 0: - if count_total: - return result, total_count - return result - except Exception as e: - logger.warning(f"Error during iteration: {e}") + logger.warning( + f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far." + ) finally: + # Close the iterator iterator.close() - logger.info( - f"Milvus retrieve by filter completed - " - f"page {page}, page_size {page_size}, got {len(result)} items." - ) - if count_total: - return result, total_count - return result + logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") + return all_items def get_all(self, collection_name: str, scroll_limit=100) -> list[MilvusVecDBItem]: """Retrieve all items in the vector database.""" From c152f447c3a88aa777391a85ee3242c2e8c7a137 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Wed, 31 Dec 2025 14:30:23 +0800 Subject: [PATCH 26/48] feat: update _build_filter_conditions_sql in conditions && build_cypher_filter_condition filter (#812) * feat: add export_graph filter * feat: update _build_filter_conditions_sql in conditions * feat: build_cypher_filter_condition filter --- src/memos/graph_dbs/polardb.py | 249 +++++++++++++++++++++++++-------- 1 file changed, 190 insertions(+), 59 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 12f2c2ca9..c9f3ee5ba 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -2508,6 +2508,7 @@ def export_graph( user_id: str | None = None, page: int | None = None, page_size: int | None = None, + filter: dict | None = None, **kwargs, ) -> dict[str, Any]: """ @@ -2518,6 +2519,13 @@ def export_graph( user_id (str, optional): User ID for filtering page (int, optional): Page number (starts from 1). If None, exports all data without pagination. 
page_size (int, optional): Number of items per page. If None, exports all data without pagination. + filter (dict, optional): Filter dictionary for metadata filtering. Supports "and", "or" logic and operators: + - "=": equality + - "in": value in list + - "contains": array contains value + - "gt", "lt", "gte", "lte": comparison operators + - "like": fuzzy matching + Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]} Returns: { @@ -2528,7 +2536,7 @@ def export_graph( } """ logger.info( - f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}" + f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}" ) user_id = user_id if user_id else self._get_config_value("user_id") @@ -2563,6 +2571,12 @@ def export_graph( f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" ) + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[export_graph] filter_conditions: {filter_conditions}") + if filter_conditions: + where_conditions.extend(filter_conditions) + where_clause = "" if where_conditions: where_clause = f"WHERE {' AND '.join(where_conditions)}" @@ -2652,6 +2666,22 @@ def export_graph( cypher_where_conditions.append(f"a.user_id = '{user_id}'") cypher_where_conditions.append(f"b.user_id = '{user_id}'") + # Build filter conditions for edges (apply to both source and target nodes) + filter_where_clause = self._build_filter_conditions_cypher(filter) + logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}") + if filter_where_clause: + # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists + # Remove the leading " AND " and replace n. with a. for source node and b. for target node + filter_clause = filter_where_clause.strip() + if filter_clause.startswith("AND "): + filter_clause = filter_clause[4:].strip() + # Replace n. with a. for source node and create a copy for target node + source_filter = filter_clause.replace("n.", "a.") + target_filter = filter_clause.replace("n.", "b.") + # Combine source and target filters with AND + combined_filter = f"({source_filter}) AND ({target_filter})" + cypher_where_conditions.append(combined_filter) + cypher_where_clause = "" if cypher_where_conditions: cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" @@ -4416,70 +4446,133 @@ def build_cypher_filter_condition(condition_dict: dict) -> str: elif op == "in": # Handle in operator (for checking if field value is in a list) # Supports array format: {"field": {"in": ["value1", "value2"]}} - # Generates: n.field IN ['value1', 'value2'] or (n.field = 'value1' OR n.field = 'value2') + # For array fields (like file_ids, tags, sources), uses CONTAINS logic + # For scalar fields, uses equality or IN clause if not isinstance(op_value, list): raise ValueError( f"in operator only supports array format. " f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}" ) + # Check if key is an array field + is_array_field = key in ("file_ids", "tags", "sources") + # Check if key starts with "info." prefix if key.startswith("info."): info_field = key[5:] # Remove "info." 
prefix - # Build OR conditions for nested properties (Apache AGE compatibility) + # Check if info field is an array field + is_info_array = info_field in ("tags", "sources", "file_ids") + if len(op_value) == 0: # Empty list means no match condition_parts.append("false") elif len(op_value) == 1: - # Single value, use equality + # Single value item = op_value[0] - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - condition_parts.append( - f"n.info.{info_field} = '{escaped_value}'" - ) + if is_info_array: + # For array fields, use CONTAINS (value IN array_field) + if isinstance(item, str): + escaped_value = escape_cypher_string(item) + condition_parts.append( + f"'{escaped_value}' IN n.info.{info_field}" + ) + else: + condition_parts.append( + f"{item} IN n.info.{info_field}" + ) else: - condition_parts.append(f"n.info.{info_field} = {item}") - else: - # Multiple values, use OR conditions instead of IN (Apache AGE compatibility) - or_conditions = [] - for item in op_value: + # For scalar fields, use equality if isinstance(item, str): escaped_value = escape_cypher_string(item) - or_conditions.append( + condition_parts.append( f"n.info.{info_field} = '{escaped_value}'" ) else: - or_conditions.append( + condition_parts.append( f"n.info.{info_field} = {item}" ) + else: + # Multiple values, use OR conditions + or_conditions = [] + for item in op_value: + if is_info_array: + # For array fields, use CONTAINS (value IN array_field) + if isinstance(item, str): + escaped_value = escape_cypher_string(item) + or_conditions.append( + f"'{escaped_value}' IN n.info.{info_field}" + ) + else: + or_conditions.append( + f"{item} IN n.info.{info_field}" + ) + else: + # For scalar fields, use equality + if isinstance(item, str): + escaped_value = escape_cypher_string(item) + or_conditions.append( + f"n.info.{info_field} = '{escaped_value}'" + ) + else: + or_conditions.append( + f"n.info.{info_field} = {item}" + ) if or_conditions: condition_parts.append( f"({' OR '.join(or_conditions)})" ) else: # Direct property access - # Build array for IN clause or OR conditions if len(op_value) == 0: # Empty list means no match condition_parts.append("false") elif len(op_value) == 1: - # Single value, use equality + # Single value item = op_value[0] - if isinstance(item, str): - escaped_value = escape_cypher_string(item) - condition_parts.append(f"n.{key} = '{escaped_value}'") + if is_array_field: + # For array fields, use CONTAINS (value IN array_field) + if isinstance(item, str): + escaped_value = escape_cypher_string(item) + condition_parts.append( + f"'{escaped_value}' IN n.{key}" + ) + else: + condition_parts.append(f"{item} IN n.{key}") else: - condition_parts.append(f"n.{key} = {item}") + # For scalar fields, use equality + if isinstance(item, str): + escaped_value = escape_cypher_string(item) + condition_parts.append( + f"n.{key} = '{escaped_value}'" + ) + else: + condition_parts.append(f"n.{key} = {item}") else: - # Multiple values, use IN clause - escaped_items = [ - f"'{escape_cypher_string(str(item))}'" - if isinstance(item, str) - else str(item) - for item in op_value - ] - array_str = "[" + ", ".join(escaped_items) + "]" - condition_parts.append(f"n.{key} IN {array_str}") + # Multiple values + if is_array_field: + # For array fields, use OR conditions with CONTAINS + or_conditions = [] + for item in op_value: + if isinstance(item, str): + escaped_value = escape_cypher_string(item) + or_conditions.append( + f"'{escaped_value}' IN n.{key}" + ) + else: + or_conditions.append(f"{item} IN 
n.{key}") + if or_conditions: + condition_parts.append( + f"({' OR '.join(or_conditions)})" + ) + else: + # For scalar fields, use IN clause + escaped_items = [ + f"'{escape_cypher_string(str(item))}'" + if isinstance(item, str) + else str(item) + for item in op_value + ] + array_str = "[" + ", ".join(escaped_items) + "]" + condition_parts.append(f"n.{key} IN {array_str}") elif op == "like": # Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%') # Check if key starts with "info." prefix @@ -4710,78 +4803,116 @@ def build_filter_condition(condition_dict: dict) -> str: elif op == "in": # Handle in operator (for checking if field value is in a list) # Supports array format: {"field": {"in": ["value1", "value2"]}} + # For array fields (like file_ids, tags, sources), uses @> operator (contains) + # For scalar fields, uses = operator (equality) if not isinstance(op_value, list): raise ValueError( f"in operator only supports array format. " f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}" ) + # Check if key is an array field + is_array_field = key in ("file_ids", "tags", "sources") + # Check if key starts with "info." prefix if key.startswith("info."): info_field = key[5:] # Remove "info." prefix - # Build OR conditions for nested properties + # Check if info field is an array field + is_info_array = info_field in ("tags", "sources", "file_ids") + if len(op_value) == 0: # Empty list means no match condition_parts.append("false") elif len(op_value) == 1: - # Single value, use equality + # Single value item = op_value[0] - if isinstance(item, str): - escaped_value = escape_sql_string(item) + if is_info_array: + # For array fields, use @> operator (contains) + escaped_value = escape_sql_string(str(item)) condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype" ) else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype" - ) - else: - # Multiple values, use OR conditions - or_conditions = [] - for item in op_value: + # For scalar fields, use equality if isinstance(item, str): escaped_value = escape_sql_string(item) - or_conditions.append( + condition_parts.append( f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" ) else: - or_conditions.append( + condition_parts.append( f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype" ) + else: + # Multiple values, use OR conditions + or_conditions = [] + for item in op_value: + if is_info_array: + # For array fields, use @> operator (contains) to check if array contains the value + escaped_value = escape_sql_string(str(item)) + or_conditions.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype" + ) + else: + # For scalar fields, use equality + if isinstance(item, str): + escaped_value = escape_sql_string(item) + or_conditions.append( + 
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" + ) + else: + or_conditions.append( + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype" + ) if or_conditions: condition_parts.append( f"({' OR '.join(or_conditions)})" ) else: # Direct property access - # Build OR conditions if len(op_value) == 0: # Empty list means no match condition_parts.append("false") elif len(op_value) == 1: - # Single value, use equality + # Single value item = op_value[0] - if isinstance(item, str): - escaped_value = escape_sql_string(item) + if is_array_field: + # For array fields, use @> operator (contains) + escaped_value = escape_sql_string(str(item)) condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype" ) else: - condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype" - ) - else: - # Multiple values, use OR conditions - or_conditions = [] - for item in op_value: + # For scalar fields, use equality if isinstance(item, str): escaped_value = escape_sql_string(item) - or_conditions.append( + condition_parts.append( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" ) else: - or_conditions.append( + condition_parts.append( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype" ) + else: + # Multiple values, use OR conditions + or_conditions = [] + for item in op_value: + if is_array_field: + # For array fields, use @> operator (contains) to check if array contains the value + escaped_value = escape_sql_string(str(item)) + or_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype" + ) + else: + # For scalar fields, use equality + if isinstance(item, str): + escaped_value = escape_sql_string(item) + or_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" + ) + else: + or_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype" + ) if or_conditions: condition_parts.append( f"({' OR '.join(or_conditions)})" From c0b7228ebbe0a82ca9afae361b16f2a9807943e6 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Wed, 31 Dec 2025 16:23:06 +0800 Subject: [PATCH 27/48] feat: _build_filter_conditions_sql filter (#813) --- src/memos/graph_dbs/polardb.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index c9f3ee5ba..b0a8bc4be 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4686,8 +4686,10 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} '\"{escaped_value}\"'::agtype" ) else: + # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype + value_json = json.dumps(op_value) condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC 
ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} {op_value}::agtype" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} ag_catalog.agtype_in('{value_json}')" ) else: # Direct property access (e.g., "created_at" is directly in properties, not in properties.info) @@ -4697,8 +4699,10 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} '\"{escaped_value}\"'::agtype" ) else: + # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype + value_json = json.dumps(op_value) condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} {op_value}::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} ag_catalog.agtype_in('{value_json}')" ) elif op == "=": # Handle equality operator @@ -4739,8 +4743,10 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '[{op_value}]'::agtype" ) else: + # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype + value_json = json.dumps(op_value) condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {op_value}::agtype" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = ag_catalog.agtype_in('{value_json}')" ) else: # Direct property access @@ -4767,8 +4773,10 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '{json_array}'::agtype" ) else: + # For non-string list values, convert to JSON string and then to agtype + value_json = json.dumps(op_value) condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {op_value}::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')" ) else: if key in ("tags", "sources"): @@ -4776,8 +4784,10 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '[{op_value}]'::agtype" ) else: + # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype + value_json = json.dumps(op_value) condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {op_value}::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')" ) elif op == "contains": # Handle contains operator @@ -4962,8 +4972,10 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype" ) else: + # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype + value_json = json.dumps(value) condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{value}\"'::agtype" + 
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = ag_catalog.agtype_in('{value_json}')" ) else: # Direct property access (simple equality) @@ -4973,8 +4985,10 @@ def build_filter_condition(condition_dict: dict) -> str: f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype" ) else: + # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype + value_json = json.dumps(value) condition_parts.append( - f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')" ) return " AND ".join(condition_parts) From 7993c3afb1b97a402acb6f4fc2c409b02bf69f95 Mon Sep 17 00:00:00 2001 From: zhixiangxue Date: Wed, 31 Dec 2025 17:00:09 +0800 Subject: [PATCH 28/48] fix: update deprecated APIs for chonkie v1.4.0 and qdrant-client v1.16.0 (#705) * fix: update deprecated APIs and dependency versions * feat: add chonkie API version compatibility * chore: update poetry.lock for chonkie version compatibility --------- Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- poetry.lock | 54 +++++++++++++------------- pyproject.toml | 2 +- src/memos/chunkers/sentence_chunker.py | 25 +++++++++--- src/memos/mem_os/core.py | 4 +- src/memos/vec_dbs/qdrant.py | 6 +-- 5 files changed, 53 insertions(+), 38 deletions(-) diff --git a/poetry.lock b/poetry.lock index 187b6c4aa..2a5ed9080 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -1098,7 +1098,7 @@ description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.9" groups = ["main", "eval"] -markers = "(platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\") and python_version < \"3.14\"" +markers = "python_version < \"3.14\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")" files = [ {file = "greenlet-3.2.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:1afd685acd5597349ee6d7a88a8bec83ce13c106ac78c196ee9dde7c04fe87be"}, {file = "greenlet-3.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:761917cac215c61e9dc7324b2606107b3b292a8349bdebb31503ab4de3f559ac"}, @@ -2641,7 +2641,7 @@ files = [ {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668"}, {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cuda-cupti-cu12" @@ -2657,7 +2657,7 @@ files = [ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73"}, {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cuda-nvrtc-cu12" @@ -2671,7 +2671,7 @@ files = [ {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53"}, {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cuda-runtime-cu12" @@ -2687,7 +2687,7 @@ files = [ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8"}, {file = 
"nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cudnn-cu12" @@ -2701,7 +2701,7 @@ files = [ {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2"}, {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-cublas-cu12 = "*" @@ -2720,7 +2720,7 @@ files = [ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca"}, {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-nvjitlink-cu12 = "*" @@ -2736,7 +2736,7 @@ files = [ {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159"}, {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-curand-cu12" @@ -2752,7 +2752,7 @@ files = [ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e"}, {file = "nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-cusolver-cu12" @@ 
-2768,7 +2768,7 @@ files = [ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e"}, {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-cublas-cu12 = "*" @@ -2789,7 +2789,7 @@ files = [ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f"}, {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] nvidia-nvjitlink-cu12 = "*" @@ -2806,7 +2806,7 @@ files = [ {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46"}, {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-nccl-cu12" @@ -2819,7 +2819,7 @@ files = [ {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522"}, {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-nvjitlink-cu12" @@ -2833,7 +2833,7 @@ files = [ {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"}, {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system 
== \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "nvidia-nvtx-cu12" @@ -2849,7 +2849,7 @@ files = [ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1"}, {file = "nvidia_nvtx_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:2fb11a4af04a5e6c84073e6404d26588a34afd35379f0855a99797897efa75c0"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "ollama" @@ -4003,14 +4003,14 @@ files = [ [[package]] name = "qdrant-client" -version = "1.14.3" +version = "1.16.2" description = "Client library for the Qdrant vector search engine" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main", "eval"] files = [ - {file = "qdrant_client-1.14.3-py3-none-any.whl", hash = "sha256:66faaeae00f9b5326946851fe4ca4ddb1ad226490712e2f05142266f68dfc04d"}, - {file = "qdrant_client-1.14.3.tar.gz", hash = "sha256:bb899e3e065b79c04f5e47053d59176150c0a5dabc09d7f476c8ce8e52f4d281"}, + {file = "qdrant_client-1.16.2-py3-none-any.whl", hash = "sha256:442c7ef32ae0f005e88b5d3c0783c63d4912b97ae756eb5e052523be682f17d3"}, + {file = "qdrant_client-1.16.2.tar.gz", hash = "sha256:ca4ef5f9be7b5eadeec89a085d96d5c723585a391eb8b2be8192919ab63185f0"}, ] markers = {main = "extra == \"all\""} @@ -4018,11 +4018,13 @@ markers = {main = "extra == \"all\""} grpcio = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ - {version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""}, + {version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""}, + {version = ">=1.21", markers = "python_version == \"3.11\""}, {version = ">=1.26", markers = "python_version == \"3.12\""}, - {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, + {version = ">=2.1.0", markers = "python_version == \"3.13\""}, + {version = ">=2.3.0", markers = "python_version >= \"3.14\""}, ] -portalocker = ">=2.7.0,<3.0.0" +portalocker = ">=2.7.0,<4.0" protobuf = ">=3.20.0" pydantic = ">=1.10.8,<2.0.dev0 || >2.2.0" urllib3 = ">=1.26.14,<3" @@ -5441,7 +5443,7 @@ files = [ {file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"}, {file = "triton-3.3.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6139aeb04a146b0b8e0fbbd89ad1e65861c57cfed881f21d62d3cb94a36bab7"}, ] -markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [package.dependencies] setuptools = ">=40.8.0" @@ -5650,7 +5652,7 @@ description = "Fast implementation of asyncio event loop on top of libuv" optional = false python-versions = ">=3.8.0" groups = ["main"] -markers = "platform_python_implementation 
!= \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"cygwin\"" +markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"" files = [ {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"}, {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d"}, @@ -6242,4 +6244,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "dab8e54c6f4c51597adbd0fa34be7a8adb3b3a9c733508f3cc2b93c0ed434ec1" +content-hash = "22bfcac5ed0be1e3aea294e3da96ff1a4bd9d7b62865ad827e1508f5ade6b708" diff --git a/pyproject.toml b/pyproject.toml index 3c2eecf18..f869f7642 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,7 @@ all = [ # We kindof don't want users to install them. "torch (>=2.7.1,<3.0.0)", "sentence-transformers (>=4.1.0,<5.0.0)", - "qdrant-client (>=1.14.2,<2.0.0)", + "qdrant-client (>=1.16.0,<2.0.0)", "volcengine-python-sdk (>=4.0.4,<5.0.0)", "nltk (>=3.9.1,<4.0.0)", "rake-nltk (>=1.0.6,<1.1.0)", diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index 080962482..4757301c7 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -20,12 +20,25 @@ def __init__(self, config: SentenceChunkerConfig): from chonkie import SentenceChunker as ChonkieSentenceChunker self.config = config - self.chunker = ChonkieSentenceChunker( - tokenizer_or_token_counter=config.tokenizer_or_token_counter, - chunk_size=config.chunk_size, - chunk_overlap=config.chunk_overlap, - min_sentences_per_chunk=config.min_sentences_per_chunk, - ) + + # Try new API first (v1.4.0+) + try: + self.chunker = ChonkieSentenceChunker( + tokenizer=config.tokenizer_or_token_counter, + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + min_sentences_per_chunk=config.min_sentences_per_chunk, + ) + except (TypeError, AttributeError) as e: + # Fallback to old API ( list[str] | list[Chunk]: diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index e7f01ec3e..22cd0e9cb 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -2,7 +2,7 @@ import os import time -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from threading import Lock from typing import Any, Literal @@ -192,7 +192,7 @@ def _register_chat_history( self.chat_history_manager[user_id] = ChatHistory( user_id=user_id if user_id is not None else self.user_id, session_id=session_id if session_id is not None else self.session_id, - created_at=datetime.utcnow(), + created_at=datetime.now(timezone.utc), total_messages=0, chat_history=[], ) diff --git a/src/memos/vec_dbs/qdrant.py b/src/memos/vec_dbs/qdrant.py index 633cd3580..d0853c4af 100644 --- a/src/memos/vec_dbs/qdrant.py +++ b/src/memos/vec_dbs/qdrant.py @@ -138,14 +138,14 @@ def search( List of search results with distance scores and payloads. 
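
        Example (illustrative sketch; mirrors the call shape exercised in the unit tests):
            >>> results = vec_db.search([0.1, 0.2, 0.3], top_k=1)
            >>> isinstance(results[0], VecDBItem)
            True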
""" qdrant_filter = self._dict_to_filter(filter) if filter else None - response = self.client.search( + response = self.client.query_points( collection_name=self.config.collection_name, - query_vector=query_vector, + query=query_vector, limit=top_k, query_filter=qdrant_filter, with_vectors=True, with_payload=True, - ) + ).points logger.info(f"Qdrant search completed with {len(response)} results.") return [ VecDBItem( From 03b79a225d89051897fa688ba926de78921603db Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Wed, 31 Dec 2025 17:27:29 +0800 Subject: [PATCH 29/48] feat: update code format (#814) * fix: code * feat:change qdrant test * feat: code format --- poetry.lock | 6 +++--- src/memos/chunkers/sentence_chunker.py | 4 ++-- tests/vec_dbs/test_qdrant.py | 26 +++++++++++++++++++------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/poetry.lock b/poetry.lock index 2a5ed9080..fb818e665 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -3163,8 +3163,8 @@ markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pre [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" @@ -4018,8 +4018,8 @@ markers = {main = "extra == \"all\""} grpcio = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ - {version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""}, {version = ">=1.21", markers = "python_version == \"3.11\""}, + {version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""}, {version = ">=1.26", markers = "python_version == \"3.12\""}, {version = ">=2.1.0", markers = "python_version == \"3.13\""}, {version = ">=2.3.0", markers = "python_version >= \"3.14\""}, diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index 4757301c7..f39dfb8e2 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -20,7 +20,7 @@ def __init__(self, config: SentenceChunkerConfig): from chonkie import SentenceChunker as ChonkieSentenceChunker self.config = config - + # Try new API first (v1.4.0+) try: self.chunker = ChonkieSentenceChunker( @@ -38,7 +38,7 @@ def __init__(self, config: SentenceChunkerConfig): chunk_overlap=config.chunk_overlap, min_sentences_per_chunk=config.min_sentences_per_chunk, ) - + logger.info(f"Initialized SentenceChunker with config: {config}") def chunk(self, text: str) -> list[str] | list[Chunk]: diff --git a/tests/vec_dbs/test_qdrant.py b/tests/vec_dbs/test_qdrant.py index f4bd276c3..67f76d463 100644 --- a/tests/vec_dbs/test_qdrant.py +++ b/tests/vec_dbs/test_qdrant.py @@ -70,13 +70,25 @@ def test_add_and_get_by_id(vec_db): def test_search(vec_db): id = str(uuid.uuid4()) - vec_db.client.search.return_value = [ - type( - "obj", - (object,), - {"id": id, "vector": [0.1, 0.2, 0.3], "payload": {"tag": "search"}, "score": 0.9}, - ) - ] + mock_response = type( + "QueryResponse", + (object,), + { + "points": [ + type( + "obj", + (object,), + { + "id": id, + "vector": [0.1, 0.2, 0.3], + "payload": {"tag": "search"}, + 
"score": 0.9, + }, + ) + ] + }, + )() + vec_db.client.query_points.return_value = mock_response results = vec_db.search([0.1, 0.2, 0.3], top_k=1) assert len(results) == 1 assert isinstance(results[0], VecDBItem) From 8819cc5e611d01c821bbe427b8b00eb94450bcb2 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Sun, 4 Jan 2026 14:05:07 +0800 Subject: [PATCH 30/48] Feat/optimize cloud service api (#816) * add get_user_names_by_memory_ids api * modify delete api --------- Co-authored-by: yuan.wang --- src/memos/api/handlers/memory_handler.py | 13 ++++--------- src/memos/api/product_models.py | 2 +- src/memos/api/routers/server_router.py | 2 +- src/memos/memories/textual/preference.py | 9 +++++++++ src/memos/memories/textual/tree.py | 9 ++++++++- src/memos/vec_dbs/milvus.py | 8 ++++++++ 6 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index a744e16e2..ef829d757 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -246,8 +246,7 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: try: if delete_mem_req.memory_ids is not None: - for cube_id in delete_mem_req.writable_cube_ids: - naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids, user_name=cube_id) + naive_mem_cube.text_mem.delete_by_memory_ids(delete_mem_req.memory_ids) if naive_mem_cube.pref_mem is not None: naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) elif delete_mem_req.file_ids is not None: @@ -255,13 +254,9 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids ) elif delete_mem_req.filter is not None: - # TODO: Implement deletion by filter - # Need to find memories matching filter and delete them - logger.warning("Deletion by filter not implemented yet") - return DeleteMemoryResponse( - message="Deletion by filter not implemented yet", - data={"status": "failure"}, - ) + naive_mem_cube.text_mem.delete_by_filter(filter=delete_mem_req.filter) + if naive_mem_cube.pref_mem is not None: + naive_mem_cube.pref_mem.delete_by_filter(filter=delete_mem_req.filter) except Exception as e: logger.error(f"Failed to delete memories: {e}", exc_info=True) return DeleteMemoryResponse( diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 25e0d809d..c52d9e8d2 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -784,7 +784,7 @@ class GetMemoryRequest(BaseRequest): class DeleteMemoryRequest(BaseRequest): """Request model for deleting memories.""" - writable_cube_ids: list[str] = Field(..., description="Writable cube IDs") + writable_cube_ids: list[str] = Field(None, description="Writable cube IDs") memory_ids: list[str] | None = Field(None, description="Memory IDs") file_ids: list[str] | None = Field(None, description="File IDs") filter: dict[str, Any] | None = Field(None, description="Filter for the memory") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 07c42bbb2..c3b05e823 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -340,7 +340,7 @@ def feedback_memories(feedback_req: APIFeedbackRequest): # ============================================================================= -@router.get( +@router.post( "/get_user_names_by_memory_ids", summary="Get 
user names by memory ids", response_model=GetUserNamesByMemoryIdsResponse, diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index cb4f00735..a34315918 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -314,6 +314,15 @@ def delete(self, memory_ids: list[str]) -> None: for collection_name in collection_list: self.vector_db.delete(collection_name, memory_ids) + def delete_by_filter(self, filter: dict[str, Any]) -> None: + """Delete memories by filter. + Args: + filter (dict[str, Any]): Filter criteria. + """ + collection_list = self.vector_db.config.collection_name + for collection_name in collection_list: + self.vector_db.delete_by_filter(collection_name=collection_name, filter=filter) + def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: """Delete memories by their IDs and collection name. Args: diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index e576c0ea9..c486e6cf6 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -347,6 +347,13 @@ def delete(self, memory_ids: list[str], user_name: str | None = None) -> None: except Exception as e: logger.warning(f"TreeTextMemory.delete_hard: failed to delete {mid}: {e}") + def delete_by_memory_ids(self, memory_ids: list[str]) -> None: + """Delete memories by memory_ids.""" + try: + self.graph_store.delete_node_by_prams(memory_ids=memory_ids) + except Exception as e: + logger.error(f"An error occurred while deleting memories by memory_ids: {e}") + def delete_all(self) -> None: """Delete all memories and their relationships from the graph store.""" try: @@ -358,7 +365,7 @@ def delete_all(self) -> None: def delete_by_filter( self, - writable_cube_ids: list[str], + writable_cube_ids: list[str] | None = None, file_ids: list[str] | None = None, filter: dict | None = None, ) -> None: diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index ecbca5815..5dacf0499 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -646,3 +646,11 @@ def delete(self, collection_name: str, ids: list[str]) -> None: collection_name=collection_name, ids=ids, ) + + def delete_by_filter(self, collection_name: str, filter: dict[str, Any]) -> None: + """Delete items from the vector database by filter.""" + expr = self._dict_to_expr(filter) if filter else "" + self.client.delete( + collection_name=collection_name, + filter=expr, + ) From 5349674cbbe5564ea277a7cfde630576ecaf761c Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Sun, 4 Jan 2026 14:50:05 +0800 Subject: [PATCH 31/48] fix: [PrefEval Evaluation] propagate --lib and --version arguments in search and response modes (#780) Co-authored-by: CaralHsi --- evaluation/scripts/run_prefeval_eval.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh index 6f5f3b7b0..f65873cb9 100755 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -108,7 +108,9 @@ python $LIB_SCRIPT search \ --input $IDS_FILE \ --output $SEARCH_FILE \ --top-k $TOP_K \ - --max-workers $WORKERS + --max-workers $WORKERS \ + --lib $LIB \ + --version $VERSION if [ $? -ne 0 ]; then echo "Error: $LIB_SCRIPT 'search' mode failed." @@ -121,7 +123,9 @@ echo "Running $LIB_SCRIPT in 'response' mode..." 
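# NOTE: as in the search stage above, --lib and --version are forwarded here
# as well, so the response stage resolves the same library/version as search.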
python $LIB_SCRIPT response \ --input $SEARCH_FILE \ --output $RESPONSE_FILE \ - --max-workers $WORKERS + --max-workers $WORKERS \ + --lib $LIB \ + --version $VERSION if [ $? -ne 0 ]; then echo "Error: $LIB_SCRIPT 'response' mode failed." From b3c9e845a920e9ec9a793d7cdd9febda26f9108a Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Sun, 4 Jan 2026 14:52:25 +0800 Subject: [PATCH 32/48] fix: fix context error and empty embedding (#817) fix: fix context Co-authored-by: CaralHsi --- .../memories/textual/tree_text_memory/retrieve/searcher.py | 2 ++ src/memos/multi_mem_cube/composite_cube.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index f3d6ba037..7e28c174b 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -674,6 +674,8 @@ def _retrieve_simple( ) logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") documents = [getattr(item, "memory", "") for item in items] + if not documents: + return [] documents_embeddings = self.embedder.embed(documents) similarity_matrix = cosine_similarity_matrix(documents_embeddings) selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) diff --git a/src/memos/multi_mem_cube/composite_cube.py b/src/memos/multi_mem_cube/composite_cube.py index 420856407..c1017bfae 100644 --- a/src/memos/multi_mem_cube/composite_cube.py +++ b/src/memos/multi_mem_cube/composite_cube.py @@ -1,9 +1,10 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import as_completed from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from memos.context.context import ContextThreadPoolExecutor from memos.multi_mem_cube.views import MemCubeView @@ -52,7 +53,7 @@ def _search_single_cube(view: SingleCubeView) -> dict[str, Any]: return view.search_memories(search_req) # parallel search for each cube - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: future_to_view = { executor.submit(_search_single_cube, view): view for view in self.cube_views } From 38f9e2fd43577fcfa0c6866dd7df88f4ab60cd7c Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Sun, 4 Jan 2026 15:48:08 +0800 Subject: [PATCH 33/48] Feat/optimize cloud service api (#818) * add get_user_names_by_memory_ids api * modify delete api * modify bug --------- Co-authored-by: yuan.wang --- src/memos/api/product_models.py | 2 +- src/memos/api/routers/server_router.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index c52d9e8d2..f0a4e333b 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -1195,5 +1195,5 @@ class GetUserNamesByMemoryIdsRequest(BaseRequest): memory_ids: list[str] = Field(..., description="Memory IDs") -class GetUserNamesByMemoryIdsResponse(BaseResponse[dict[str, list[str]]]): +class GetUserNamesByMemoryIdsResponse(BaseResponse[dict[str, str | None]]): """Response model for getting user names by memory ids.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index c3b05e823..7c0f3ea8f 100644 --- a/src/memos/api/routers/server_router.py +++ 
b/src/memos/api/routers/server_router.py @@ -345,7 +345,7 @@ def feedback_memories(feedback_req: APIFeedbackRequest): summary="Get user names by memory ids", response_model=GetUserNamesByMemoryIdsResponse, ) -def get_user_names_by_memory_ids(memory_ids: GetUserNamesByMemoryIdsRequest): +def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest): """Get user names by memory ids.""" if not isinstance(graph_db, PolarDBGraphDB): raise HTTPException( @@ -356,4 +356,9 @@ def get_user_names_by_memory_ids(memory_ids: GetUserNamesByMemoryIdsRequest): f"current graph_db is: {graph_db.__class__.__name__}" ), ) - return graph_db.get_user_names_by_memory_ids(memory_ids=memory_ids) + result = graph_db.get_user_names_by_memory_ids(memory_ids=request.memory_ids) + return GetUserNamesByMemoryIdsResponse( + code=200, + message="Successfully", + data=result, + ) From 7142baa2ebf713256a3197c0b2433ecc9d30259e Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Sun, 4 Jan 2026 20:57:23 +0800 Subject: [PATCH 34/48] Feat/optimize cloud service api (#820) * add get_user_names_by_memory_ids api * modify delete api * modify bug * add extract limit in implicit memory * close internet search in chat api, modify implicit pref prompt --------- Co-authored-by: yuan.wang --- src/memos/api/handlers/chat_handler.py | 14 ++++++++++++++ src/memos/templates/prefer_complete_prompt.py | 2 ++ 2 files changed, 16 insertions(+) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 3e9d1e5ec..caeba0ca1 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -138,6 +138,13 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An if text_mem_results and text_mem_results[0].get("memories"): memories_list = text_mem_results[0]["memories"] + # Drop internet memories forced + memories_list = [ + mem + for mem in memories_list + if mem.get("metadata", {}).get("memory_type") != "OuterMemory" + ] + # Filter memories by threshold filtered_memories = self._filter_memories_by_threshold( memories_list, chat_req.threshold or 0.5 @@ -277,6 +284,13 @@ def generate_chat_response() -> Generator[str, None, None]: if text_mem_results and text_mem_results[0].get("memories"): memories_list = text_mem_results[0]["memories"] + # Drop internet memories forced + memories_list = [ + mem + for mem in memories_list + if mem.get("metadata", {}).get("memory_type") != "OuterMemory" + ] + # Filter memories by threshold filtered_memories = self._filter_memories_by_threshold(memories_list) diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 3315e061e..a67f0c12c 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -77,6 +77,7 @@ * **Contextual signals**: What do the user's choices, comparisons, exclusions, or scenario selections reveal about their deeper preferences? - Do not treat explicitly stated preferences as implicit preferences; this prompt is only for inferring preferences that are not directly mentioned. - Go beyond surface-level facts to understand the user's hidden possibilities and underlying logic. +- For Assistant's responses or suggestions, they can only be extracted as the user's implicit preferences if there is evidence in subsequent conversation that the user implicitly accepted them (e.g., adoption, agreement, acting on the suggestion, etc.). 
Assistant suggestions alone do not constitute user preferences. Requirements: 1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. @@ -117,6 +118,7 @@ * **情境信号**:用户的选择、比较、排除或场景选择揭示了什么样的深层偏好? - 不要将明确陈述的偏好视为隐式偏好;此提示仅用于推断未直接提及的偏好。 - 超越表面事实,理解用户的隐藏可能性和背后的逻辑。 +- 对于Assistant的回答内容或建议,只有在后续对话中用户表现出隐含接受(如采纳、认同、按建议行动等)的情况下,才能将相关内容提取为用户的隐式偏好。单纯的Assistant建议本身不构成用户偏好。 要求: 1. 仅在对话中有充分证据时进行推断;避免无根据或牵强的猜测。 From 2a91bd68b2dc5ec4bd5955a8513f59dbef363027 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Mon, 5 Jan 2026 10:41:47 +0800 Subject: [PATCH 35/48] add exist_user_name for neo4j.py (#821) * feat: add exist_user_name * feat: add exist_user_name * feat: add exist_user_name for neo4j.py * feat: fix delete_node_by_prams by neo4j.py --- src/memos/graph_dbs/neo4j.py | 181 +++++++++++++++++++-------------- src/memos/graph_dbs/polardb.py | 49 +++++++++ 2 files changed, 154 insertions(+), 76 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index c2dc4a629..64aedc8f4 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -1707,124 +1707,116 @@ def delete_node_by_prams( ) -> int: """ Delete nodes by memory_ids, file_ids, or filter. + Supports three scenarios: + 1. Delete by memory_ids (standalone) + 2. Delete by writable_cube_ids + file_ids (combined) + 3. Delete by filter (standalone, no writable_cube_ids needed) Args: writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes. - If not provided, no user_name filter will be applied. + Only used with file_ids scenario. If not provided, no user_name filter will be applied. memory_ids (list[str], optional): List of memory node IDs to delete. - file_ids (list[str], optional): List of file node IDs to delete. - filter (dict, optional): Filter dictionary to query matching nodes for deletion. + file_ids (list[str], optional): List of file node IDs to delete. Must be used with writable_cube_ids. + filter (dict, optional): Filter dictionary for metadata filtering. + Filter conditions are directly used in DELETE WHERE clause without pre-querying. + Does not require writable_cube_ids. Returns: int: Number of nodes deleted. 
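
        Example (illustrative; assumes a configured graph store instance):
            >>> graph_db.delete_node_by_prams(memory_ids=["mem-1", "mem-2"])
            2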
""" + batch_start_time = time.time() logger.info( f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" ) - print( - f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}" - ) - - # Build WHERE conditions separately for memory_ids and file_ids - where_clauses = [] - params = {} # Build user_name condition from writable_cube_ids (OR relationship - match any cube_id) - # Only add user_name filter if writable_cube_ids is provided + # Only add user_name filter if writable_cube_ids is provided (for file_ids scenario) user_name_conditions = [] + params = {} if writable_cube_ids and len(writable_cube_ids) > 0: for idx, cube_id in enumerate(writable_cube_ids): param_name = f"cube_id_{idx}" user_name_conditions.append(f"n.user_name = ${param_name}") params[param_name] = cube_id - # Handle memory_ids: query n.id - if memory_ids and len(memory_ids) > 0: + # Build filter conditions using common method (no query, direct use in WHERE clause) + filter_conditions = [] + filter_params = {} + if filter: + filter_conditions, filter_params = self._build_filter_conditions_cypher( + filter, param_counter_start=0, node_alias="n" + ) + logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}") + params.update(filter_params) + + # If no conditions to delete, return 0 + if not memory_ids and not file_ids and not filter_conditions: + logger.warning( + "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" + ) + return 0 + + # Build WHERE conditions list + where_clauses = [] + + # Scenario 1: memory_ids (standalone) + if memory_ids: + logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids") where_clauses.append("n.id IN $memory_ids") params["memory_ids"] = memory_ids - # Handle file_ids: query n.file_ids field - # All file_ids must be present in the array field (AND relationship) - if file_ids and len(file_ids) > 0: - file_id_and_conditions = [] + # Scenario 2: file_ids + writable_cube_ids (combined) + if file_ids: + logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids") + file_id_conditions = [] for idx, file_id in enumerate(file_ids): param_name = f"file_id_{idx}" params[param_name] = file_id # Check if this file_id is in the file_ids array field - file_id_and_conditions.append(f"${param_name} IN n.file_ids") - if file_id_and_conditions: - # Use AND to require all file_ids to be present - where_clauses.append(f"({' OR '.join(file_id_and_conditions)})") - - # Query nodes by filter if provided - filter_ids = [] - if filter: - # Use get_by_metadata with empty filters list and filter - filter_ids = self.get_by_metadata( - filters=[], - user_name=None, - filter=filter, - knowledgebase_ids=writable_cube_ids if writable_cube_ids else None, - ) + file_id_conditions.append(f"${param_name} IN n.file_ids") + if file_id_conditions: + where_clauses.append(f"({' OR '.join(file_id_conditions)})") - # If filter returned IDs, add condition for them - if filter_ids: - where_clauses.append("n.id IN $filter_ids") - params["filter_ids"] = filter_ids + # Scenario 3: filter (standalone, no writable_cube_ids needed) + if filter_conditions: + logger.info("[delete_node_by_prams] Processing filter conditions") + # Combine filter conditions with AND + filter_where = " AND ".join(filter_conditions) + where_clauses.append(f"({filter_where})") - # If no conditions (except user_name), return 0 + # 
Build final WHERE clause if not where_clauses: - logger.warning( - "[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)" - ) + logger.warning("[delete_node_by_prams] No WHERE conditions to delete") return 0 - # Build WHERE clause - # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match) - data_conditions = " OR ".join([f"({clause})" for clause in where_clauses]) + # Combine all conditions with AND + data_conditions = " AND ".join([f"({clause})" for clause in where_clauses]) - # Build final WHERE clause - # If user_name_conditions exist, combine with data_conditions using AND - # Otherwise, use only data_conditions + # Add user_name filter if provided (for file_ids scenario) if user_name_conditions: user_name_where = " OR ".join(user_name_conditions) - ids_where = f"({user_name_where}) AND ({data_conditions})" + final_where = f"({user_name_where}) AND ({data_conditions})" else: - ids_where = f"({data_conditions})" - - logger.info( - f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" - ) - print( - f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}" - ) + final_where = data_conditions - # First count matching nodes to get accurate count - count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count" - logger.info(f"[delete_node_by_prams] count_query: {count_query}") - print(f"[delete_node_by_prams] count_query: {count_query}") - - # Then delete nodes - delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n" + # Delete directly without pre-counting + delete_query = f"MATCH (n:Memory) WHERE {final_where} DETACH DELETE n" logger.info(f"[delete_node_by_prams] delete_query: {delete_query}") - print(f"[delete_node_by_prams] delete_query: {delete_query}") - print(f"[delete_node_by_prams] params: {params}") deleted_count = 0 try: with self.driver.session(database=self.db_name) as session: - # Count nodes before deletion - count_result = session.run(count_query, **params) - count_record = count_result.single() - expected_count = 0 - if count_record: - expected_count = count_record["node_count"] or 0 - - # Delete nodes - session.run(delete_query, **params) - # Use the count from before deletion as the actual deleted count - deleted_count = expected_count - + # Execute delete query + result = session.run(delete_query, **params) + # Consume the result to ensure deletion completes and get the summary + summary = result.consume() + # Get the count from the result summary + deleted_count = summary.counters.nodes_deleted if summary.counters else 0 + + elapsed_time = time.time() - batch_start_time + logger.info( + f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {deleted_count} nodes" + ) except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) raise @@ -1884,3 +1876,40 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True ) raise + + def exist_user_name(self, user_name: str) -> dict[str, bool]: + """Check if user name exists in the graph. + + Args: + user_name: User name to check. + + Returns: + dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence. 
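+
+        Example (illustrative):
+            >>> graph_db.exist_user_name("cube_123")
+            {'cube_123': True}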
+ """ + logger.info(f"[exist_user_name] Querying user_name {user_name}") + if not user_name: + return {user_name: False} + + try: + with self.driver.session(database=self.db_name) as session: + # Query to check if user_name exists + query = """ + MATCH (n:Memory) + WHERE n.user_name = $user_name + RETURN COUNT(n) AS count + """ + logger.info(f"[exist_user_name] query: {query}") + + result = session.run(query, user_name=user_name) + count = result.single()["count"] + result_dict = {user_name: count > 0} + + logger.info( + f"[exist_user_name] user_name {user_name} exists: {result_dict[user_name]}" + ) + return result_dict + except Exception as e: + logger.error( + f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True + ) + raise diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index b0a8bc4be..d1c2716c8 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -5316,3 +5316,52 @@ def escape_memory_id(mid: str) -> str: raise finally: self._return_connection(conn) + + def exist_user_name(self, user_name: str) -> dict[str, bool]: + """Check if user name exists in the graph. + + Args: + user_name: User name to check. + + Returns: + dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence. + """ + logger.info(f"[exist_user_name] Querying user_name {user_name}") + if not user_name: + return {user_name: False} + + # Escape special characters for JSON string format in agtype + def escape_user_name(un: str) -> str: + """Escape special characters in user_name for JSON string format.""" + # Escape backslashes first, then double quotes + un_str = un.replace("\\", "\\\\") + un_str = un_str.replace('"', '\\"') + return un_str + + # Escape special characters + escaped_un = escape_user_name(user_name) + + # Query to check if user_name exists + query = f""" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" + WHERE ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{escaped_un}\"'::agtype + """ + logger.info(f"[exist_user_name] query: {query}") + result_dict = {} + conn = None + try: + conn = self._get_connection() + with conn.cursor() as cursor: + cursor.execute(query) + count = cursor.fetchone()[0] + result = count > 0 + result_dict[user_name] = result + return result_dict + except Exception as e: + logger.error( + f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True + ) + raise + finally: + self._return_connection(conn) From c5b6f15ce34993ec3b132705c03d7f32cf0391c4 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Mon, 5 Jan 2026 14:46:55 +0800 Subject: [PATCH 36/48] Feat/optimize cloud service api (#822) * add get_user_names_by_memory_ids api * modify delete api * modify bug * add extract limit in implicit memory * close internet search in chat api, modify implicit pref prompt * modify bug * add a new internal method for check cube id exist --------- Co-authored-by: yuan.wang --- src/memos/api/handlers/chat_handler.py | 8 ++++---- src/memos/api/product_models.py | 10 ++++++++++ src/memos/api/routers/server_router.py | 16 ++++++++++++++++ .../textual/tree_text_memory/retrieve/utils.py | 2 +- 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index caeba0ca1..812cf2793 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -482,7 +482,7 @@ def 
generate_chat_response() -> Generator[str, None, None]: # get preference string pref_string = search_response.data.get("pref_string", "") - yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + yield f"data: {json.dumps({'type': 'reference', 'data': reference}, ensure_ascii=False)}\n\n" # Prepare preference markdown string if chat_req.include_preference: @@ -586,7 +586,7 @@ def generate_chat_response() -> Generator[str, None, None]: internet_reference = self._get_internet_reference( search_response.data.get("text_mem")[0]["memories"] ) - yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + yield f"data: {json.dumps({'type': 'reference', 'data': reference}, ensure_ascii=False)}\n\n" # Step 2: Build system prompt with memories lang = detect_lang(chat_req.query) @@ -684,7 +684,7 @@ def generate_chat_response() -> Generator[str, None, None]: if chat_req.internet_search or parsed_goal.internet_search: # Yield internet reference after text response - yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n" + yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference}, ensure_ascii=False)}\n\n" # Calculate timing time_end = time.time() @@ -697,7 +697,7 @@ def generate_chat_response() -> Generator[str, None, None]: current_messages.append({"role": "assistant", "content": full_response}) further_suggestion = self._get_further_suggestion(current_messages) self.logger.info(f"[PLAYGROUND CHAT] further_suggestion: {further_suggestion}") - yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n" + yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'end'})}\n\n" diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index f0a4e333b..ee7a45c2d 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -1197,3 +1197,13 @@ class GetUserNamesByMemoryIdsRequest(BaseRequest): class GetUserNamesByMemoryIdsResponse(BaseResponse[dict[str, str | None]]): """Response model for getting user names by memory ids.""" + + +class ExistMemCubeIdRequest(BaseRequest): + """Request model for checking if mem cube id exists.""" + + mem_cube_id: str = Field(..., description="Mem cube ID") + + +class ExistMemCubeIdResponse(BaseResponse[dict[str, bool]]): + """Response model for checking if mem cube id exists.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 7c0f3ea8f..c60e84253 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -33,6 +33,8 @@ ChatRequest, DeleteMemoryRequest, DeleteMemoryResponse, + ExistMemCubeIdRequest, + ExistMemCubeIdResponse, GetMemoryPlaygroundRequest, GetMemoryRequest, GetMemoryResponse, @@ -362,3 +364,17 @@ def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest): message="Successfully", data=result, ) + + +@router.post( + "/exist_mem_cube_id", + summary="Check if mem cube id exists", + response_model=ExistMemCubeIdResponse, +) +def exist_mem_cube_id(request: ExistMemCubeIdRequest): + """Check if mem cube id exists.""" + return ExistMemCubeIdResponse( + code=200, + message="Successfully", + data=graph_db.exist_user_name(user_name=request.mem_cube_id), + ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py index bcd47b078..54caa20f7 
100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py @@ -27,7 +27,7 @@ "tags": [...], "goal_type": "retrieval | qa | generation", "rephrased_instruction": "...", # return an empty string if the original instruction is easy enough to understand - "internet_search": True/False, + "internet_search": true/false, "memories": ["...", "...", ...] } """ From 0abb555677fb0494aed694953ea6f5f349fef10f Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Tue, 6 Jan 2026 16:36:03 +0800 Subject: [PATCH 37/48] Feat/optimize cloud service api (#825) * add get_user_names_by_memory_ids api * modify delete api * modify bug * add extract limit in implicit memory * close internet search in chat api, modify implicit pref prompt * modify bug * add a new internal method for check cube id exist * modify code --------- Co-authored-by: yuan.wang --- src/memos/api/routers/server_router.py | 8 ++++++++ src/memos/memories/textual/preference.py | 2 +- src/memos/memories/textual/simple_preference.py | 8 ++++---- src/memos/vec_dbs/milvus.py | 6 ++---- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index c60e84253..a4052d313 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -90,6 +90,7 @@ status_tracker = TaskStatusTracker(redis_client=redis_client) embedder = components["embedder"] graph_db = components["graph_db"] +vector_db = components["vector_db"] # ============================================================================= @@ -359,6 +360,13 @@ def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest): ), ) result = graph_db.get_user_names_by_memory_ids(memory_ids=request.memory_ids) + if vector_db: + prefs = [] + for collection_name in ["explicit_preference", "implicit_preference"]: + prefs.extend( + vector_db.get_by_ids(collection_name=collection_name, ids=request.memory_ids) + ) + result.update({pref.id: pref.payload.get("mem_cube_id", None) for pref in prefs}) return GetUserNamesByMemoryIdsResponse( code=200, message="Successfully", diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index a34315918..78f4d6e28 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -248,7 +248,7 @@ def get_all(self) -> list[TextualMemoryItem]: Returns: list[TextualMemoryItem]: List of all memories. 
""" - all_collections = self.vector_db.list_collections() + all_collections = ["explicit_preference", "implicit_preference"] all_memories = {} for collection_name in all_collections: items = self.vector_db.get_all(collection_name) diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index ee37d638c..cc1781f06 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -90,7 +90,7 @@ def get_with_collection_name( return None return TextualMemoryItem( id=res.id, - memory=res.payload.get("dialog_str", ""), + memory=res.memory, metadata=PreferenceTextualMemoryMetadata(**res.payload), ) except Exception as e: @@ -116,7 +116,7 @@ def get_by_ids_with_collection_name( return [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in res @@ -132,14 +132,14 @@ def get_all(self) -> list[TextualMemoryItem]: Returns: list[TextualMemoryItem]: List of all memories. """ - all_collections = self.vector_db.list_collections() + all_collections = ["explicit_preference", "implicit_preference"] all_memories = {} for collection_name in all_collections: items = self.vector_db.get_all(collection_name) all_memories[collection_name] = [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in items diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 5dacf0499..cc8909d34 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -457,14 +457,13 @@ def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None: return None entity = results[0] - payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} return MilvusVecDBItem( id=entity["id"], memory=entity.get("memory"), original_text=entity.get("original_text"), vector=entity.get("vector"), - payload=payload, + payload=entity.get("payload", {}), ) def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBItem]: @@ -479,14 +478,13 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt items = [] for entity in results: - payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} items.append( MilvusVecDBItem( id=entity["id"], memory=entity.get("memory"), original_text=entity.get("original_text"), vector=entity.get("vector"), - payload=payload, + payload=entity.get("payload", {}), ) ) From 85860ce9bedce9afc6c18a8da6aa111afed03108 Mon Sep 17 00:00:00 2001 From: Hustzdy <67457465+wustzdy@users.noreply.github.com> Date: Tue, 6 Jan 2026 16:53:52 +0800 Subject: [PATCH 38/48] feat: add filter time query (#826) --- src/memos/graph_dbs/polardb.py | 62 +++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index d1c2716c8..e67f866ac 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -4333,15 +4333,29 @@ def build_cypher_filter_condition(condition_dict: dict) -> str: cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} cypher_op = cypher_op_map[op] + # Check if key is a datetime field + is_datetime = key in ("created_at", "updated_at") or key.endswith( + "_at" + ) + # Check if key starts with "info." 
prefix (for nested fields like info.A, info.B) if key.startswith("info."): # Nested field access: n.info.field_name info_field = key[5:] # Remove "info." prefix + is_info_datetime = info_field in ( + "created_at", + "updated_at", + ) or info_field.endswith("_at") if isinstance(op_value, str): escaped_value = escape_cypher_string(op_value) - condition_parts.append( - f"n.info.{info_field} {cypher_op} '{escaped_value}'" - ) + if is_info_datetime: + condition_parts.append( + f"n.info.{info_field}::timestamp {cypher_op} '{escaped_value}'::timestamp" + ) + else: + condition_parts.append( + f"n.info.{info_field} {cypher_op} '{escaped_value}'" + ) else: condition_parts.append( f"n.info.{info_field} {cypher_op} {op_value}" @@ -4350,9 +4364,14 @@ def build_cypher_filter_condition(condition_dict: dict) -> str: # Direct property access (e.g., "created_at" is directly in n, not in n.info) if isinstance(op_value, str): escaped_value = escape_cypher_string(op_value) - condition_parts.append( - f"n.{key} {cypher_op} '{escaped_value}'" - ) + if is_datetime: + condition_parts.append( + f"n.{key}::timestamp {cypher_op} '{escaped_value}'::timestamp" + ) + else: + condition_parts.append( + f"n.{key} {cypher_op} '{escaped_value}'" + ) else: condition_parts.append(f"n.{key} {cypher_op} {op_value}") elif op == "=": @@ -4676,15 +4695,29 @@ def build_filter_condition(condition_dict: dict) -> str: sql_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="} sql_op = sql_op_map[op] + # Check if key is a datetime field + is_datetime = key in ("created_at", "updated_at") or key.endswith( + "_at" + ) + # Check if key starts with "info." prefix (for nested fields like info.A, info.B) if key.startswith("info."): # Nested field access: properties->'info'->'field_name' info_field = key[5:] # Remove "info." 
prefix
+                            is_info_datetime = info_field in (
+                                "created_at",
+                                "updated_at",
+                            ) or info_field.endswith("_at")
                             if isinstance(op_value, str):
                                 escaped_value = escape_sql_string(op_value)
+                                if is_info_datetime:
+                                    condition_parts.append(
+                                        f"TRIM(BOTH '\"' FROM ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text)::timestamp {sql_op} '{escaped_value}'::timestamp"
+                                    )
+                                else:
+                                    condition_parts.append(
+                                        f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} '\"{escaped_value}\"'::agtype"
+                                    )
                             else:
                                 # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
                                 value_json = json.dumps(op_value)
@@ -4695,9 +4728,14 @@ def build_filter_condition(condition_dict: dict) -> str:
                         # Direct property access (e.g., "created_at" is directly in properties, not in properties.info)
                         if isinstance(op_value, str):
                             escaped_value = escape_sql_string(op_value)
-                            condition_parts.append(
-                                f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} '\"{escaped_value}\"'::agtype"
-                            )
+                            if is_datetime:
+                                condition_parts.append(
+                                    f"TRIM(BOTH '\"' FROM ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text)::timestamp {sql_op} '{escaped_value}'::timestamp"
+                                )
+                            else:
+                                condition_parts.append(
+                                    f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} '\"{escaped_value}\"'::agtype"
+                                )
                         else:
                             # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
                             value_json = json.dumps(op_value)
From d632ddede2ff857de0b1441846a4c10dc9112df6 Mon Sep 17 00:00:00 2001
From: Hustzdy <67457465+wustzdy@users.noreply.github.com>
Date: Tue, 6 Jan 2026 17:46:06 +0800
Subject: [PATCH 39/48] add getMemory sdk (#827)

* fix: update get_memory sdk

* fix: update get_memory sdk

* fix: update get_memory sdk

* fix: update get_memory sdk

---------

Co-authored-by: Elvis <1693372324@qq.com>
---
 src/memos/api/client.py         |  6 +++++-
 src/memos/api/product_models.py | 13 ++++---------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/memos/api/client.py b/src/memos/api/client.py
index 1129ddddf..91bc86829 100644
--- a/src/memos/api/client.py
+++ b/src/memos/api/client.py
@@ -177,7 +177,9 @@ def search_memory(
         if retry == MAX_RETRY_COUNT - 1:
             raise

-    def get_memory(self, user_id: str, include_preference: str) -> MemOSGetMemoryResponse | None:
+    def get_memory(
+        self, user_id: str, include_preference: bool = True, page: int = 1, size: int = 10
+    ) -> MemOSGetMemoryResponse | None:
         """get memories"""
         # Validate required parameters
         self._validate_required_params(include_preference=include_preference, user_id=user_id)
@@ -186,6 +188,8 @@ def get_memory(self, user_id: str, include_preference: str) -> MemOSGetMemoryRes
         payload = {
             "include_preference": include_preference,
             "user_id": user_id,
+            "page": page,
+            "size": size,
         }

         for retry in range(MAX_RETRY_COUNT):
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index ee7a45c2d..d5f301c9d 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -872,8 +872,8 @@ class GetMemoryData(BaseModel):
     memory_detail_list: list[MemoryDetail] = Field(
default_factory=list, alias="memory_detail_list", description="List of memory details" ) - message_detail_list: list[MessageDetail] | None = Field( - None, alias="message_detail_list", description="List of message details (usually None)" + preference_detail_list: list[MessageDetail] | None = Field( + None, alias="preference_detail_list", description="List of preference detail" ) @@ -1025,7 +1025,7 @@ class MemOSGetMemoryResponse(BaseModel): code: int = Field(..., description="Response status code") message: str = Field(..., description="Response message") - data: SearchMemoryData = Field(..., description="Get results data") + data: GetMemoryData = Field(..., description="Get results data") @property def memories(self) -> list[MemoryDetail]: @@ -1033,15 +1033,10 @@ def memories(self) -> list[MemoryDetail]: return self.data.memory_detail_list @property - def preferences(self) -> list[MemoryDetail]: + def preferences(self) -> list[MessageDetail] | None: """Convenient access to preference list.""" return self.data.preference_detail_list - @property - def tool_memories(self) -> list[MemoryDetail]: - """Convenient access to tool_memory list.""" - return self.data.tool_memory_detail_list - class MemOSGetKnowledgebaseFileResponse(BaseModel): """Response model for get KnowledgebaseFile operation based on actual API.""" From bbca35f67ad238ebcbf9eca497c86c8f2cec312c Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 7 Jan 2026 14:48:21 +0800 Subject: [PATCH 40/48] feat: Merge from main (some hot-fix) (#832) * fix: fix bugs when running local queue for memos * fix: remove an unnecessary function * fix: update README.md * update requirement,Dockerfile * fix: update README.md * fix: update README.md * feat: update readme * feat: fix NACOS * feat: add timer log * update requirements * fix: 12.26 update README.md * change local server name * update docker-compose.yml * fix: issues caused by no reading default use_redis from env * feat: fix requirements * add neo4j * fix: logs context and empty embedding * fix reranker * fix: conflict --------- Co-authored-by: chentang Co-authored-by: Elvis <1693372324@qq.com> Co-authored-by: pursues <15180521816@163.com> Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> Co-authored-by: liji <532311301@qq.com> Co-authored-by: harvey_xiang Co-authored-by: lijicode <34564964+lijicode@users.noreply.github.com> Co-authored-by: Zehao Lin Co-authored-by: fridayL --- README.md | 217 +++++++----------- docker/.env.example | 153 ++++++------ docker/docker-compose.yml | 3 +- docker/requirements-full.txt | 4 +- docker/requirements.txt | 39 ++-- src/memos/api/config.py | 2 +- src/memos/api/server_api.py | 4 +- .../task_schedule_modules/dispatcher.py | 1 + .../task_schedule_modules/redis_queue.py | 8 +- .../mem_scheduler/utils/status_tracker.py | 3 + .../tree_text_memory/retrieve/searcher.py | 45 ++++ src/memos/utils.py | 2 +- 12 files changed, 252 insertions(+), 229 deletions(-) diff --git a/README.md b/README.md index 29a50c1da..edd4e8905 100644 --- a/README.md +++ b/README.md @@ -118,89 +118,31 @@ showcasing its capabilities in **information extraction**, **temporal and cross- - **🔌 Extensible**: Easily extend and customize memory modules, data sources, and LLM integrations. -## 📦 Installation +## 🚀 Quickstart Guide -### Install via pip - -```bash -pip install MemoryOS -``` - -### Optional Dependencies - -MemOS provides several optional dependency groups for different features. You can install them based on your needs. 
- -| Feature | Package Name | -| --------------------- | ------------------------- | -| Tree Memory | `MemoryOS[tree-mem]` | -| Memory Reader | `MemoryOS[mem-reader]` | -| Memory Scheduler | `MemoryOS[mem-scheduler]` | - -Example installation commands: - -```bash -pip install MemoryOS[tree-mem] -pip install MemoryOS[tree-mem,mem-reader] -pip install MemoryOS[mem-scheduler] -pip install MemoryOS[tree-mem,mem-reader,mem-scheduler] -``` - -### External Dependencies - -#### Ollama Support - -To use MemOS with [Ollama](https://ollama.com/), first install the Ollama CLI: - -```bash -curl -fsSL https://ollama.com/install.sh | sh -``` - -#### Transformers Support - -To use functionalities based on the `transformers` library, ensure you have [PyTorch](https://pytorch.org/get-started/locally/) installed (CUDA version recommended for GPU acceleration). - -#### Download Examples +### Get API Key + - Sign up and get started on[`MemOS dashboard`](https://memos-dashboard.openmem.net/cn/quickstart/?source=landing) + - Open the API Keys Console in the MemOS dashboard and copy the API Key into the initialization code -To download example code, data and configurations, run the following command: - -```bash -memos download_examples -``` - - -## 🚀 Getting Started - -### ⭐️ MemOS online API -The easiest way to use MemOS. Equip your agent with memory **in minutes**! - -Sign up and get started on[`MemOS dashboard`](https://memos-dashboard.openmem.net/cn/quickstart/?source=landing). - - -### Self-Hosted Server -1. Get the repository. -```bash -git clone https://github.com/MemTensor/MemOS.git -cd MemOS -pip install -r ./docker/requirements.txt -``` +### Install via pip -2. Configure `docker/.env.example` and copy to `MemOS/.env` -3. Start the service. ```bash -uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8001 --workers 8 +pip install MemoryOS -U ``` -### Interface SDK -#### Here is a quick example showing how to create all interface SDK +### Basic Usage -This interface is used to add messages, supporting multiple types of content and batch additions. MemOS will automatically parse the messages and handle memory for reference in subsequent conversations. +- Initialize MemOS client with API Key to start sending requests ```python # Please make sure MemoS is installed (pip install MemoryOS -U) from memos.api.client import MemOSClient # Initialize the client using the API Key client = MemOSClient(api_key="YOUR_API_KEY") +``` +- This API allows you to add one or more messages to a specific conversation. As illustrated in the examples bellow, you can add messages in real time during a user-assistant interaction, import historical messages in bulk, or enrich the conversation with user preferences and behavior data. All added messages are transformed into memories by MemOS, enabling their retrieval in future conversations to support chat history management, user behavior tracking, and personalized interactions. +```python messages = [ {"role": "user", "content": "I have planned to travel to Guangzhou during the summer vacation. What chain hotels are available for accommodation?"}, {"role": "assistant", "content": "You can consider [7 Days, All Seasons, Hilton], and so on."}, @@ -214,79 +156,90 @@ res = client.add_message(messages=messages, user_id=user_id, conversation_id=con print(f"result: {res}") ``` -This interface is used to retrieve the memories of a specified user, returning the memory fragments most relevant to the input query for Agent use. 
The recalled memory fragments include 'factual memory', 'preference memory', and 'tool memory'. +- This API allows you to query a user’s memory and returns the fragments most relevant to the input. These can serve as references for the model when generating responses. As shown in the examples bellow, You can retrieve memory in real time during a user’s conversation with the AI, or perform a global search across their entire memory to create user profiles or support personalized recommendations, improving both dialogue coherence and personalization. +In the latest update, in addition to “Fact Memory”, the system now supports “Preference Memory”, enabling LLM to respond in a way that better understands the user. ```python -# Please make sure MemoS is installed (pip install MemoryOS -U) -from memos.api.client import MemOSClient - -# Initialize the client using the API Key -client = MemOSClient(api_key="YOUR_API_KEY") - query = "I want to go out to play during National Day. Can you recommend a city I haven't been to and a hotel brand I haven't stayed at?" user_id = "memos_user_123" -conversation_id = "0928" +conversation_id = "0610" res = client.search_memory(query=query, user_id=user_id, conversation_id=conversation_id) print(f"result: {res}") ``` -This interface is used to delete the memory of specified users and supports batch deletion. -```python -# Please make sure MemoS is installed (pip install MemoryOS -U) -from memos.api.client import MemOSClient - -# Initialize the client using the API Key -client = MemOSClient(api_key="YOUR_API_KEY") - -user_ids = ["memos_user_123"] -# Replace with the memory ID -memory_ids = ["6b23b583-f4c4-4a8f-b345-58d0c48fea04"] -res = client.delete_memory(user_ids=user_ids, memory_ids=memory_ids) - -print(f"result: {res}") -``` - -This interface is used to add feedback to messages in the current session, allowing MemOS to correct its memory based on user feedback. -```python -# Please make sure MemoS is installed (pip install MemoryOS -U) -from memos.api.client import MemOSClient - -# Initialize the client using the API Key -client = MemOSClient(api_key="YOUR_API_KEY") - -user_id = "memos_user_123" -conversation_id = "memos_feedback_conv" -feedback_content = "No, let's change it now to a meal allowance of 150 yuan per day and a lodging subsidy of 700 yuan per day for first-tier cities; for second- and third-tier cities, it remains the same as before." -# Replace with the knowledgebase ID -allow_knowledgebase_ids = ["basee5ec9050-c964-484f-abf1-ce3e8e2aa5b7"] - -res = client.add_feedback( - user_id=user_id, - conversation_id=conversation_id, - feedback_content=feedback_content, - allow_knowledgebase_ids=allow_knowledgebase_ids -) - -print(f"result: {res}") -``` - -This interface is used to create a knowledgebase associated with a project -```python -# Please make sure MemoS is installed (pip install MemoryOS -U) -from memos.api.client import MemOSClient - -# Initialize the client using the API Key -client = MemOSClient(api_key="YOUR_API_KEY") -knowledgebase_name = "Financial Reimbursement Knowledge Base" -knowledgebase_description = "A compilation of all knowledge related to the company's financial reimbursements." +### Self-Hosted Server +1. Get the repository. + ```bash + git clone https://github.com/MemTensor/MemOS.git + cd MemOS + pip install -r ./docker/requirements.txt + ``` +2. 
Configure `docker/.env.example` and copy to `MemOS/.env`
+ - The `OPENAI_API_KEY`, `MOS_EMBEDDER_API_KEY`, `MEMRADER_API_KEY`, and other keys can be applied for through [`BaiLian`](https://bailian.console.aliyun.com/?spm=a2c4g.11186623.0.0.2f2165b08fRk4l&tab=api#/api).
+ - Fill in the corresponding configuration in the `MemOS/.env` file.
+3. Start the service.

-res = client.create_knowledgebase(
- knowledgebase_name=knowledgebase_name,
- knowledgebase_description=knowledgebase_description
-)
-print(f"result: {res}")
-```

+- Launch via Docker
+ ###### Tips: Please ensure that Docker Compose is installed successfully and that you have navigated to the docker directory (via `cd docker`) before executing the following command.
+ ```bash
+ # Run from the docker directory
+ docker compose up
+ ```
+ ##### For the full Docker-based deployment guide, see the [`Docker Reference`](https://docs.openmem.net/open_source/getting_started/rest_api_server/#method-1-docker-use-repository-dependency-package-imagestart-recommended-use).
+
+- Launch via the uvicorn command line interface (CLI)
+ ###### Tips: Please ensure that Neo4j and Qdrant are running before executing the following command.
+ ```bash
+ uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8001 --workers 1
+ ```
+ ##### For detailed integration steps, see the [`CLI Reference`](https://docs.openmem.net/open_source/getting_started/rest_api_server/#method-3client-install-with-CLI).
+
+
+
Example
 - Add User Message
 ```python
 import requests
 import json

 data = {
 "user_id": "8736b16e-1d20-4163-980b-a5063c3facdc",
 "mem_cube_id": "b32d0977-435d-4828-a86f-4f47f8b55bca",
 "messages": [
 {
 "role": "user",
 "content": "I like strawberry"
 }
 ],
 "async_mode": "sync"
 }
 headers = {
 "Content-Type": "application/json"
 }
 url = "http://localhost:8000/product/add"

 res = requests.post(url=url, headers=headers, data=json.dumps(data))
 print(f"result: {res.json()}")
 ```
 - Search User Memory
 ```python
 import requests
 import json

 data = {
 "query": "What do I like",
 "user_id": "8736b16e-1d20-4163-980b-a5063c3facdc",
 "mem_cube_id": "b32d0977-435d-4828-a86f-4f47f8b55bca"
 }
 headers = {
 "Content-Type": "application/json"
 }
 url = "http://localhost:8000/product/search"

 res = requests.post(url=url, headers=headers, data=json.dumps(data))
 print(f"result: {res.json()}")
 ```

## 💬 Community & Support

diff --git a/docker/.env.example b/docker/.env.example
index dc4252133..ee26c7bcd 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -3,32 +3,31 @@
## Base
TZ=Asia/Shanghai
-ENV_NAME=PLAYGROUND_OFFLINE # Tag shown in DingTalk notifications (e.g., PROD_ONLINE/TEST); no runtime effect unless ENABLE_DINGDING_BOT=true
MOS_CUBE_PATH=/tmp/data_test # local data path
MEMOS_BASE_PATH=.
# CLI/SDK cache path
MOS_ENABLE_DEFAULT_CUBE_CONFIG=true # enable default cube config
MOS_ENABLE_REORGANIZE=false # enable memory reorg
+# MOS Text Memory Type
MOS_TEXT_MEM_TYPE=general_text # general_text | tree_text
ASYNC_MODE=sync # async/sync, used in default cube config

## User/session defaults
-MOS_USER_ID=root
-MOS_SESSION_ID=default_session
-MOS_MAX_TURNS_WINDOW=20
+# Top-K for LLM in the Product API (old version)
MOS_TOP_K=50

## Chat LLM (main dialogue)
+# LLM model name for the Product API
MOS_CHAT_MODEL=gpt-4o-mini
+# Temperature for LLM in the Product API
MOS_CHAT_TEMPERATURE=0.8
+# Max tokens for LLM in the Product API
MOS_MAX_TOKENS=2048
+# Top-P for LLM in the Product API
MOS_TOP_P=0.9
+# LLM provider for the Product API backend
MOS_CHAT_MODEL_PROVIDER=openai # openai | huggingface | vllm
-MOS_MODEL_SCHEMA=memos.configs.llm.VLLMLLMConfig # vllm only: config class path; keep default unless you extend it
OPENAI_API_KEY=sk-xxx # [required] when provider=openai
OPENAI_API_BASE=https://api.openai.com/v1 # [required] base for the key
-OPENAI_BASE_URL= # compatibility for eval/scheduler
-VLLM_API_KEY= # required when provider=vllm
-VLLM_API_BASE=http://localhost:8088/v1 # required when provider=vllm

## MemReader / retrieval LLM
MEMRADER_MODEL=gpt-4o-mini
@@ -37,40 +36,61 @@ MEMRADER_API_BASE=http://localhost:3000/v1 # [required] base for the key
MEMRADER_MAX_TOKENS=5000

## Embedding & rerank
+# embedding dim
EMBEDDING_DIMENSION=1024
+# set default embedding backend
MOS_EMBEDDER_BACKEND=universal_api # universal_api | ollama
+# set OpenAI-style provider
MOS_EMBEDDER_PROVIDER=openai # required when universal_api
+# embedding model
MOS_EMBEDDER_MODEL=bge-m3 # siliconflow → use BAAI/bge-m3
+# embedding URL
MOS_EMBEDDER_API_BASE=http://localhost:8000/v1 # required when universal_api
+# embedding model key
MOS_EMBEDDER_API_KEY=EMPTY # required when universal_api
OLLAMA_API_BASE=http://localhost:11434 # required when backend=ollama

+# reranker config
MOS_RERANKER_BACKEND=http_bge # http_bge | http_bge_strategy | cosine_local
+# reranker URL
MOS_RERANKER_URL=http://localhost:8001 # required when backend=http_bge*
+# reranker model
MOS_RERANKER_MODEL=bge-reranker-v2-m3 # siliconflow → use BAAI/bge-reranker-v2-m3
MOS_RERANKER_HEADERS_EXTRA= # extra headers, JSON string, e.g. {"Authorization":"Bearer your_token"}
+# use source
MOS_RERANKER_STRATEGY=single_turn
-MOS_RERANK_SOURCE= # optional rerank scope, e.g., history/stream/custom

 # External Services (for evaluation scripts)
+# API key for reproducing Zep (competitor product) evaluation
ZEP_API_KEY=your_zep_api_key_here
+# API key for reproducing Mem0 (competitor product) evaluation
MEM0_API_KEY=your_mem0_api_key_here
+# API key for reproducing MemU (competitor product) evaluation
+MEMU_API_KEY=your_memu_api_key_here
+# API key for reproducing MEMOBASE (competitor product) evaluation
+MEMOBASE_API_KEY=your_memobase_api_key_here
+# Project URL for reproducing MEMOBASE (competitor product) evaluation
+MEMOBASE_PROJECT_URL=your_memobase_project_url_here
+# LLM for evaluation
MODEL=gpt-4o-mini
+# embedding model for evaluation
EMBEDDING_MODEL=nomic-embed-text:latest
+

## Internet search & preference memory
+# Enable web search
ENABLE_INTERNET=false
+# API key for BOCHA Search
BOCHA_API_KEY= # required if ENABLE_INTERNET=true
-XINYU_API_KEY=
-XINYU_SEARCH_ENGINE_ID=
+# default search mode
SEARCH_MODE=fast # fast | fine | mixture
-FAST_GRAPH=false
-BM25_CALL=false
-VEC_COT_CALL=false
+# Slow (fine) retrieval strategy configuration; 'rewrite' is the query-rewrite strategy
FINE_STRATEGY=rewrite # rewrite | recreate | deep_search
-ENABLE_ACTIVATION_MEMORY=false
+# Whether to enable preference memory
ENABLE_PREFERENCE_MEMORY=true
+# Preference Memory Add Mode
PREFERENCE_ADDER_MODE=fast # fast | safe
+# Whether to deduplicate explicit preferences based on factual memory
DEDUP_PREF_EXP_BY_TEXTUAL=false

## Reader chunking
@@ -81,66 +101,71 @@ MEM_READER_CHAT_CHUNK_SESS_SIZE=10 # sessions per chunk (default mode)
MEM_READER_CHAT_CHUNK_OVERLAP=2 # overlap between chunks

## Scheduler (MemScheduler / API)
+# Enable or disable the main switch for configuring the memory scheduler during MemOS class initialization
MOS_ENABLE_SCHEDULER=false
+# Determine the number of most relevant memory entries that the scheduler retrieves or processes during runtime (such as reordering or updating working memory)
MOS_SCHEDULER_TOP_K=10
+# The time interval (in seconds) for updating "Activation Memory" (usually referring to caching or short-term memory mechanisms)
MOS_SCHEDULER_ACT_MEM_UPDATE_INTERVAL=300
+# The size of the context window considered by the scheduler when processing tasks (such as the number of recent messages or conversation rounds)
MOS_SCHEDULER_CONTEXT_WINDOW_SIZE=5
+# The maximum number of working threads allowed in the scheduler thread pool for concurrent task execution
MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS=10000
+# The polling interval (in seconds) for the scheduler to consume new messages/tasks from the queue. The smaller the value, the faster the response, but the CPU usage may be higher
MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS=0.01
+# Whether to enable the parallel dispatch function of the scheduler to improve the throughput of concurrent operations
MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH=true
+# The specific switch to enable or disable the "Activation Memory" function in the scheduler logic
MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY=false
+# Control whether the scheduler instance is actually started during server initialization. If false, the scheduler object may be created but its background loop will not be started
API_SCHEDULER_ON=true
+# Specifically define the window size for API search operations in OptimizedScheduler.
It is passed to the scheduler API module to control the scope of the search context
API_SEARCH_WINDOW_SIZE=5
+# Specify how many rounds of previous conversations (history) to retrieve and consider during the 'hybrid search' (fast search + asynchronous fine search). This helps provide context-aware search results
API_SEARCH_HISTORY_TURNS=5

## Graph / vector stores
+# Neo4j database selection mode
NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb
+# Neo4j database URL
NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j*
+# Neo4j database user
NEO4J_USER=neo4j # required when backend=neo4j*
+# Neo4j database password
NEO4J_PASSWORD=12345678 # required when backend=neo4j*
+# Neo4j database name
NEO4J_DB_NAME=neo4j # required for shared-db mode
-MOS_NEO4J_SHARED_DB=true # if true, all users share one DB; if false, each user gets their own DB
-NEO4J_AUTO_CREATE=false # [IMPORTANT] set to false for Neo4j Community Edition
-NEO4J_USE_MULTI_DB=false # alternative to MOS_NEO4J_SHARED_DB (logic is inverse)
+# Whether all users share one Neo4j database with MemOS
+MOS_NEO4J_SHARED_DB=false

QDRANT_HOST=localhost
QDRANT_PORT=6333
# For Qdrant Cloud / remote endpoint (takes priority if set):
QDRANT_URL=your_qdrant_url
QDRANT_API_KEY=your_qdrant_key
+# Milvus server URI
MILVUS_URI=http://localhost:19530 # required when ENABLE_PREFERENCE_MEMORY=true
MILVUS_USER_NAME=root # same as above
MILVUS_PASSWORD=12345678 # same as above
-NEBULAR_HOSTS=["localhost"]
-NEBULAR_USER=root
-NEBULAR_PASSWORD=xxxxxx
-NEBULAR_SPACE=shared-tree-textual-memory
-NEBULAR_WORKING_MEMORY=20
-NEBULAR_LONGTERM_MEMORY=1000000
-NEBULAR_USER_MEMORY=1000000
-
-## Relational DB (user manager / PolarDB)
-MOS_USER_MANAGER_BACKEND=sqlite # sqlite | mysql
-MYSQL_HOST=localhost # required when backend=mysql
-MYSQL_PORT=3306
-MYSQL_USERNAME=root
-MYSQL_PASSWORD=12345678
-MYSQL_DATABASE=memos_users
-MYSQL_CHARSET=utf8mb4
+
+# PolarDB endpoint/host
POLAR_DB_HOST=localhost
+# PolarDB port
POLAR_DB_PORT=5432
+# PolarDB username
POLAR_DB_USER=root
+# PolarDB password
POLAR_DB_PASSWORD=123456
+# PolarDB database name
POLAR_DB_DB_NAME=shared_memos_db
+# PolarDB Server Mode:
+# If set to true, use Multi-Database Mode where each user has their own independent database (physical isolation).
+# If set to false (default), use Shared Database Mode where all users share one database with logical isolation via username.
POLAR_DB_USE_MULTI_DB=false
+# PolarDB connection pool size
POLARDB_POOL_MAX_CONN=100

-## Redis (scheduler queue) — fill only if you want scheduler queues in Redis; otherwise in-memory queue is used
-REDIS_HOST=localhost # global Redis endpoint (preferred over MEMSCHEDULER_*)
-REDIS_PORT=6379
-REDIS_DB=0
-REDIS_PASSWORD=
-REDIS_SOCKET_TIMEOUT=
-REDIS_SOCKET_CONNECT_TIMEOUT=
+## Related configurations of Redis
+# Redis MQ sends scheduling messages and synchronizes some shared variables
MEMSCHEDULER_REDIS_HOST= # fallback keys if not using the global ones
MEMSCHEDULER_REDIS_PORT=
MEMSCHEDULER_REDIS_DB=
@@ -148,41 +173,26 @@ MEMSCHEDULER_REDIS_PASSWORD=
MEMSCHEDULER_REDIS_TIMEOUT=
MEMSCHEDULER_REDIS_CONNECT_TIMEOUT=

-## MemScheduler LLM
-MEMSCHEDULER_OPENAI_API_KEY= # LLM key for scheduler’s own calls (OpenAI-compatible); leave empty if scheduler not using LLM
-MEMSCHEDULER_OPENAI_BASE_URL= # Base URL for the above; can reuse OPENAI_API_BASE
-MEMSCHEDULER_OPENAI_DEFAULT_MODEL=gpt-4o-mini

## Nacos (optional config center)
+# Toggle the Nacos long-polling config watch (defaults to true)
NACOS_ENABLE_WATCH=false
+# Long-polling watch interval in seconds (60 here); leave unset to use the default of 30
NACOS_WATCH_INTERVAL=60
+# Nacos server address
NACOS_SERVER_ADDR=
+# Nacos data ID
NACOS_DATA_ID=
+# Nacos group
NACOS_GROUP=DEFAULT_GROUP
+# Nacos namespace
NACOS_NAMESPACE=
+# Nacos access key (AK)
AK=
+# Nacos secret key (SK)
SK=

-## DingTalk bot & OSS upload
-ENABLE_DINGDING_BOT=false # set true -> fields below required
-DINGDING_ACCESS_TOKEN_USER=
-DINGDING_SECRET_USER=
-DINGDING_ACCESS_TOKEN_ERROR=
-DINGDING_SECRET_ERROR=
-DINGDING_ROBOT_CODE=
-DINGDING_APP_KEY=
-DINGDING_APP_SECRET=
-OSS_ENDPOINT= # bot image upload depends on OSS
-OSS_REGION=
-OSS_BUCKET_NAME=
-OSS_ACCESS_KEY_ID=
-OSS_ACCESS_KEY_SECRET=
-OSS_PUBLIC_BASE_URL=
-
-## SDK / external client
-MEMOS_API_KEY=
-MEMOS_BASE_URL=https://memos.memtensor.cn/api/openmem/v1
-
+# Chat model list for the chat API
CHAT_MODEL_LIST='[{
 "backend": "deepseek",
 "api_base": "http://localhost:1234",
@@ -190,3 +200,16 @@ CHAT_MODEL_LIST='[{
 "model_name_or_path": "deepseek-r1",
 "support_models": ["deepseek-r1"]
}]'
+
+# RabbitMQ host name for message-log pipeline
+MEMSCHEDULER_RABBITMQ_HOST_NAME=
+# RabbitMQ user name for message-log pipeline
+MEMSCHEDULER_RABBITMQ_USER_NAME=
+# RabbitMQ password for message-log pipeline
+MEMSCHEDULER_RABBITMQ_PASSWORD=
+# RabbitMQ virtual host for message-log pipeline
+MEMSCHEDULER_RABBITMQ_VIRTUAL_HOST=memos
+# Erase connection state on connect for message-log pipeline
+MEMSCHEDULER_RABBITMQ_ERASE_ON_CONNECT=true
+# RabbitMQ port for message-log pipeline
+MEMSCHEDULER_RABBITMQ_PORT=5672
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 0f680505f..0a8e2c634 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -53,7 +53,7 @@ services:
 - "6333:6333" # REST API
 - "6334:6334" # gRPC API
 volumes:
- - ./qdrant_data:/qdrant/storage
+ - qdrant_data:/qdrant/storage
 environment:
 QDRANT__SERVICE__GRPC_PORT: 6334
 QDRANT__SERVICE__HTTP_PORT: 6333
@@ -64,6 +64,7 @@ volumes:
 neo4j_data:
 neo4j_logs:
+ qdrant_data:

networks:
 memos_network:
diff --git a/docker/requirements-full.txt b/docker/requirements-full.txt
index 57c26067f..be9ed2068 100644
--- a/docker/requirements-full.txt
+++ b/docker/requirements-full.txt
@@ -159,7 +159,7 @@ tzdata==2025.2
ujson==5.10.0
urllib3==2.5.0
uvicorn==0.35.0
-uvloop==0.21.0
+uvloop==0.22.1; sys_platform != 'win32'
volcengine-python-sdk==4.0.6
watchfiles==1.1.0 websockets==15.0.1 @@ -179,7 +179,7 @@ pathable==0.4.4 pathvalidate==3.3.1 platformdirs==4.5.0 pluggy==1.6.0 -psycopg2-binary==2.9.9 +psycopg2-binary==2.9.11 py-key-value-aio==0.2.8 py-key-value-shared==0.2.8 PyJWT==2.10.1 diff --git a/docker/requirements.txt b/docker/requirements.txt index aa01fa626..f89617c10 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -1,21 +1,18 @@ annotated-types==0.7.0 -anyio==4.9.0 -async-timeout==5.0.1 -attrs==25.3.0 -authlib==1.6.0 -beautifulsoup4==4.13.4 -certifi==2025.7.14 -cffi==1.17.1 -charset-normalizer==3.4.2 -chonkie==1.1.1 -click==8.2.1 -cobble==0.1.4 -colorama==0.4.6 -coloredlogs==15.0.1 +anyio==4.11.0 +attrs==25.4.0 +Authlib==1.6.5 +beartype==0.22.5 +cachetools==6.2.2 +certifi==2025.11.12 +cffi==2.0.0 +charset-normalizer==3.4.4 +chonkie==1.1.0 +click==8.3.0 concurrent-log-handler==0.9.28 -cryptography==45.0.5 -cyclopts==3.22.2 -defusedxml==0.7.1 +cryptography==46.0.3 +cyclopts==4.2.3 +diskcache==5.6.3 distro==1.9.0 dnspython==2.8.0 docstring_parser==0.17.0 @@ -29,7 +26,6 @@ fastmcp==2.13.0.2 filelock==3.20.0 fsspec==2025.10.0 grpcio==1.76.0 -neo4j==5.28.1 h11==0.16.0 hf-xet==1.2.0 httpcore==1.0.9 @@ -56,6 +52,7 @@ MarkupSafe==3.0.3 mcp==1.21.1 mdurl==0.1.2 more-itertools==10.8.0 +neo4j==5.28.1 numpy==2.3.4 ollama==0.4.9 openai==1.109.1 @@ -68,10 +65,10 @@ pathvalidate==3.3.1 pika==1.3.2 platformdirs==4.5.0 pluggy==1.6.0 -portalocker==3.2.0 +portalocker==2.8.0 prometheus_client==0.23.1 protobuf==6.33.1 -psycopg2-binary==2.9.9 +psycopg2-binary==2.9.11 py-key-value-aio==0.2.8 py-key-value-shared==0.2.8 pycparser==2.23 @@ -90,7 +87,7 @@ python-dotenv==1.2.1 python-multipart==0.0.20 pytz==2025.2 PyYAML==6.0.3 -qdrant-client +qdrant-client==1.14.3 redis==6.4.0 referencing==0.36.2 regex==2025.11.3 @@ -123,6 +120,6 @@ tzdata==2025.2 ujson==5.11.0 urllib3==2.5.0 uvicorn==0.38.0 -uvloop==0.22.1 +uvloop==0.22.1; sys_platform != 'win32' watchfiles==1.1.1 websockets==15.0.1 diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 7298658ff..daf9b6cfe 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -204,7 +204,7 @@ def init(cls) -> None: sk = os.getenv("SK") if not (server_addr and data_id and ak and sk): - logger.warning("❌ missing NACOS_SERVER_ADDR / AK / SK / DATA_ID") + logger.warning("missing NACOS_SERVER_ADDR / AK / SK / DATA_ID") return base_url = f"http://{server_addr}/nacos/v1/cs/configs" diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 0dfef99d9..ac9ed8d88 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -13,8 +13,8 @@ logger = logging.getLogger(__name__) app = FastAPI( - title="MemOS Product REST APIs", - description="A REST API for managing multiple users with MemOS Product.", + title="MemOS Server REST APIs", + description="A REST API for managing multiple users with MemOS Server.", version="1.0.1", ) diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index e2c1621d4..cdd491183 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -128,6 +128,7 @@ def status_tracker(self) -> TaskStatusTracker | None: if self._status_tracker is None: try: self._status_tracker = TaskStatusTracker(self.redis) + # Propagate to submodules when created lazily if self.memos_message_queue: self.memos_message_queue.set_status_tracker(self._status_tracker) except Exception as e: 
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 941c52164..557a45466 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -1216,7 +1216,7 @@ def _update_stream_cache_with_log( self._stream_keys_cache = active_stream_keys self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) - logger.info( - f"Refreshed stream keys cache: {cache_count} active keys, " - f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." - ) + logger.info( + f"Refreshed stream keys cache: {cache_count} active keys, " + f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." + ) diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py index 2a995b239..4977cfc3c 100644 --- a/src/memos/mem_scheduler/utils/status_tracker.py +++ b/src/memos/mem_scheduler/utils/status_tracker.py @@ -17,6 +17,9 @@ def __init__(self, redis_client: "redis.Redis | None"): self.redis = redis_client def _get_key(self, user_id: str) -> str: + if not self.redis: + return + return f"memos:task_meta:{user_id}" def _get_task_items_key(self, user_id: str, task_id: str) -> str: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 7e28c174b..3612d37eb 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -290,6 +290,51 @@ def _parse_task( return parsed_goal, query_embedding, context, query + @timed + def _retrieve_simple( + self, + query: str, + top_k: int, + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ): + """Retrieve from by keywords and embedding""" + query_words = [] + if self.tokenizer: + query_words = self.tokenizer.tokenize_mixed(query) + else: + query_words = query.strip().split() + query_words = [query, *query_words] + logger.info(f"[SIMPLESEARCH] Query words: {query_words}") + query_embeddings = self.embedder.embed(query_words) + + items = self.graph_retriever.retrieve_from_mixed( + top_k=top_k * 2, + memory_scope=None, + query_embedding=query_embeddings, + search_filter=search_filter, + user_name=user_name, + use_fast_graph=self.use_fast_graph, + ) + logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") + documents = [getattr(item, "memory", "") for item in items] + if not documents: + return [] + documents_embeddings = self.embedder.embed(documents) + similarity_matrix = cosine_similarity_matrix(documents_embeddings) + selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) + selected_items = [items[i] for i in selected_indices] + logger.info( + f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}" + ) + return self.reranker.rerank( + query=query, + query_embedding=query_embeddings[0], + graph_results=selected_items, + top_k=top_k, + ) + @timed def _retrieve_paths( self, diff --git a/src/memos/utils.py b/src/memos/utils.py index 4f2666efd..594180e8f 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -79,7 +79,6 @@ def wrapper(*args, **kwargs): status = "SUCCESS" if success_flag else "FAILED" status_info = f", status: {status}" - if not success_flag and exc_type is not None: status_info += ( f", error_type: {exc_type.__name__}, error_message: 
{exc_message}" @@ -88,6 +87,7 @@ def wrapper(*args, **kwargs): msg = ( f"[TIMER_WITH_STATUS] {log_prefix or fn.__name__} " f"took {elapsed_ms:.0f} ms{status_info}, args: {ctx_str}" + f", result: {result}" ) logger.info(msg) From 9c363b49643495e00869b40ab444bf2f848c3d30 Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Wed, 7 Jan 2026 15:22:40 +0800 Subject: [PATCH 41/48] feat: add OpenAI token log (#831) * feat: timer false * feat: add openai token log * fix: format error --------- Co-authored-by: harvey_xiang Co-authored-by: CaralHsi --- src/memos/llms/openai.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index ea488329d..f49f1d7d1 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -1,4 +1,5 @@ import json +import time from collections.abc import Generator @@ -46,9 +47,16 @@ def generate(self, messages: MessageList, **kwargs) -> str: "extra_body": kwargs.get("extra_body", self.config.extra_body), "tools": kwargs.get("tools", NOT_GIVEN), } + start_time = time.perf_counter() logger.info(f"OpenAI LLM Request body: {request_body}") + response = self.client.chat.completions.create(**request_body) - logger.info(f"Response from OpenAI: {response.model_dump_json()}") + + cost_time = time.perf_counter() - start_time + logger.info( + f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}" + ) + tool_calls = getattr(response.choices[0].message, "tool_calls", None) if isinstance(tool_calls, list) and len(tool_calls) > 0: return self.tool_call_parser(tool_calls) From 8d630600544a382dec0d36360e78e7e5f2311889 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:25:23 +0800 Subject: [PATCH 42/48] fix: feedback llm output fail to load json (#833) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords * fix messages queue * add seach_by_keywords_LIKE * add doc filter * add retrieve query * add retrieve queies * patch info filter * add log and make embedding safety net * add log and make embedding safety net * deduplicate add objects * use _add_memories_parallel * delete Special characters * delete Special characters * delete Special characters * delete Special characters * add source_doc_id * add source_doc_id * add reranker in init com.. 
* fix circle import * add feedback judgement * add feedback judgement * add pref feedback * add pref feedback * patch: get_memory func filter user id and make page chunk * add total num * add total num * add milvus pagination * fix merge implicit explicit pref * fix merge implicit explicit pref * fix merge implicit explicit pref * fix json load bug --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/mem_feedback/feedback.py | 99 ++++++++++++++++++------------ src/memos/mem_feedback/utils.py | 81 ++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 39 deletions(-) diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index fad15a7cd..15d7c336a 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -18,6 +18,8 @@ from memos.log import get_logger from memos.mem_feedback.base import BaseMemFeedback from memos.mem_feedback.utils import ( + extract_bracket_content, + extract_square_brackets_content, general_split_into_chunks, make_mem_item, should_keep_update, @@ -118,7 +120,7 @@ def _retry_db_operation(self, operation): return operation() except Exception as e: logger.error( - f"[1224 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True + f"[0107 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True ) raise @@ -132,7 +134,7 @@ def _batch_embed(self, texts: list[str], embed_bs: int = 5): results.extend(self._embed_once(batch)) except Exception as e: logger.error( - f"[1224 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" + f"[0107 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}" ) results.extend([[0.0] * dim for _ in range(len(batch))]) return results @@ -148,7 +150,7 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i lambda: self.memory_manager.add(to_add_memories, user_name=user_name, use_batch=False) ) logger.info( - f"[1224 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." + f"[0107 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}." 
) return { "record": { @@ -180,12 +182,12 @@ def _keyword_replace_judgement(self, feedback_content: str) -> dict | None: user_feedback=feedback_content, ) - judge_res = self._get_llm_response(prompt) + judge_res = self._get_llm_response(prompt, load_type="bracket") if judge_res: return judge_res else: logger.warning( - "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[0107 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return {} @@ -205,12 +207,12 @@ def _feedback_judgement( feedback_time=feedback_time, ) - judge_res = self._get_llm_response(prompt) + judge_res = self._get_llm_response(prompt, load_type="square_bracket") if judge_res: return judge_res else: logger.warning( - "[1224 Feedback Core: _feedback_judgement] feedback judgement failed, return []" + "[0107 Feedback Core: _feedback_judgement] feedback judgement failed, return []" ) return [] @@ -276,7 +278,7 @@ def _single_update_operation( """ if "preference" in old_memory_item.metadata.__dict__: logger.info( - f"[1224 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" + f"[0107 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" ) return self._single_update_pref( old_memory_item, new_memory_item, user_id, user_name, operation @@ -408,11 +410,11 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> self.graph_store.delete_node(mid, user_name=user_name) logger.info( - f"[1224 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" + f"[0107 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" ) except Exception as e: logger.warning( - f"[1224 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" + f"[0107 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) def semantics_feedback( @@ -473,7 +475,7 @@ def semantics_feedback( chat_history=history_str, ) - future = executor.submit(self._get_llm_response, prompt) + future = executor.submit(self._get_llm_response, prompt, load_type="bracket") future_to_chunk_idx[future] = chunk for future in concurrent.futures.as_completed(future_to_chunk_idx): try: @@ -486,7 +488,7 @@ def semantics_feedback( all_operations.extend(chunk_operations["operations"]) except Exception as e: logger.error( - f"[1224 Feedback Core: semantics_feedback] Operation failed: {e}" + f"[0107 Feedback Core: semantics_feedback] Operation failed: {e}" ) standard_operations = self.standard_operations(all_operations, current_memories) @@ -536,7 +538,7 @@ def semantics_feedback( update_results.append(result) except Exception as e: logger.error( - f"[1224 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", + f"[0107 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) if update_results: @@ -564,7 +566,7 @@ def _feedback_memory( ] if filterd_ids: logger.warning( - f"[1224 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[0107 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." 
) current_memories = [ @@ -596,7 +598,7 @@ def _feedback_memory( results[i] = node except Exception as e: logger.error( - f"[1224 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", + f"[0107 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}", exc_info=True, ) mem_res = [r for r in results if r] @@ -660,7 +662,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): if not retrieved_ids: logger.info( - f"[1224 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." + f"[0107 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." ) filterd_ids = [ @@ -668,7 +670,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): ] if filterd_ids: logger.warning( - f"[1224 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." + f"[0107 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}." ) return [ TextualMemoryItem(**item) @@ -676,22 +678,41 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): if "mode:fast" not in item["metadata"]["tags"] ] - def _get_llm_response(self, prompt: str, dsl: bool = True) -> dict: + def _get_llm_response( + self, + prompt: str, + dsl: bool = True, + load_type: Literal["bracket", "square_bracket"] | None = None, + ) -> dict: messages = [{"role": "user", "content": prompt}] + response_text = "" try: response_text = self.llm.generate(messages, temperature=0.3, timeout=60) - if dsl: + if not dsl: + return response_text + try: response_text = response_text.replace("```", "").replace("json", "") cleaned_text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", response_text) response_json = json.loads(cleaned_text) - else: - return response_text + return response_json + except (json.JSONDecodeError, ValueError) as e: + if load_type == "bracket": + response_json = extract_bracket_content(response_text) + return response_json + elif load_type == "square_bracket": + response_json = extract_square_brackets_content(response_text) + return response_json + else: + logger.error( + f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}" + ) + return None + except Exception as e: logger.error( f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}" ) - response_json = None - return response_json + return None def filter_fault_update(self, operations: list[dict]): """To address the randomness of large model outputs, it is necessary to conduct validity evaluation on the texts used for memory override operations.""" @@ -710,7 +731,7 @@ def filter_fault_update(self, operations: list[dict]): raw_operations_str = {"operations": chunk} prompt = template.format(raw_operations=str(raw_operations_str)) - future = executor.submit(self._get_llm_response, prompt) + future = executor.submit(self._get_llm_response, prompt, load_type="bracket") future_to_chunk_idx[future] = chunk for future in concurrent.futures.as_completed(future_to_chunk_idx): try: @@ -722,9 +743,9 @@ def filter_fault_update(self, operations: list[dict]): ): all_judge.extend(judge_res["operations_judgement"]) except Exception as e: - logger.error(f"[1224 Feedback Core: filter_fault_update] Judgement failed: {e}") + logger.error(f"[0107 Feedback Core: filter_fault_update] Judgement failed: {e}") - logger.info(f"[1224 
Feedback Core: filter_fault_update] LLM judgement: {all_judge}") + logger.info(f"[0107 Feedback Core: filter_fault_update] LLM judgement: {all_judge}") id2op = {item["id"]: item for item in updated_operations} valid_updates = [] for judge in all_judge: @@ -735,7 +756,7 @@ def filter_fault_update(self, operations: list[dict]): valid_updates.append(valid_update) logger.info( - f"[1224 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" + f"[0107 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}" ) return valid_updates + [item for item in operations if item["operation"] != "UPDATE"] @@ -767,7 +788,7 @@ def correct_item(data): if not should_keep_update(data["text"], data["old_memory"]): logger.warning( - f"[1224 Feedback Core: correct_item] Due to the excessive proportion of changes, skip update: {data}" + f"[0107 Feedback Core: correct_item] Due to the excessive proportion of changes, skip update: {data}" ) return None @@ -787,14 +808,14 @@ def correct_item(data): return data except Exception: logger.error( - f"[1224 Feedback Core: standard_operations] Error processing operation item: {data}", + f"[0107 Feedback Core: standard_operations] Error processing operation item: {data}", exc_info=True, ) return None dehallu_res = [correct_item(item) for item in operations] dehalluded_operations = [item for item in dehallu_res if item] - logger.info(f"[1224 Feedback Core: dehalluded_operations] {dehalluded_operations}") + logger.info(f"[0107 Feedback Core: dehalluded_operations] {dehalluded_operations}") # c add objects add_texts = [] @@ -808,7 +829,7 @@ def correct_item(data): elif item["operation"].lower() == "update": llm_operations.append(item) logger.info( - f"[1224 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" + f"[0107 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories" ) # Update takes precedence over add @@ -822,7 +843,7 @@ def correct_item(data): ] if filtered_items: logger.info( - f"[1224 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" + f"[0107 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}" ) return update_items else: @@ -870,7 +891,7 @@ def _doc_filter(self, doc_scope: str, memories: list[TextualMemoryItem]): memid for inscope_file in inscope_docs for memid in filename2_memid[inscope_file] ] logger.info( - f"[1224 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" + f"[0107 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}" ) filter_memories = [mem for mem in memories if mem.id in inscope_ids] return filter_memories @@ -924,7 +945,7 @@ def process_keyword_replace( retrieved_memories = self._doc_filter(doc_scope, retrieved_memories) logger.info( - f"[1224 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." + f"[0107 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories." 
) if not retrieved_memories: @@ -1009,7 +1030,7 @@ def check_validity(item): info.update({"user_id": user_id, "user_name": user_name, "session_id": session_id}) logger.info( - f"[1224 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" + f"[0107 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" ) # feedback keywords update kwp_judge = self._keyword_replace_judgement(feedback_content) @@ -1042,7 +1063,7 @@ def check_validity(item): if not valid_feedback: logger.warning( - f"[1224 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." + f"[0107 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}." ) return {"record": {"add": [], "update": []}} @@ -1090,13 +1111,13 @@ def check_validity(item): add_memories = mem_record["record"]["add"] update_memories = mem_record["record"]["update"] logger.info( - f"[1224 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." + f"[0107 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}." ) return mem_record except Exception as e: logger.error( - f"[1224 Feedback Core: process_feedback_core] Error for user {user_name}: {e}" + f"[0107 Feedback Core: process_feedback_core] Error for user {user_name}: {e}" ) return {"record": {"add": [], "update": []}} diff --git a/src/memos/mem_feedback/utils.py b/src/memos/mem_feedback/utils.py index 8cb7f97a3..8e3b2f34c 100644 --- a/src/memos/mem_feedback/utils.py +++ b/src/memos/mem_feedback/utils.py @@ -1,3 +1,6 @@ +import json +import re + from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata @@ -147,3 +150,81 @@ def make_mem_item(text: str, **kwargs) -> TextualMemoryItem: info=info_, ), ) + + +def extract_bracket_content(text): + """ + Extract and parse JSON content enclosed in curly braces {} from text. + """ + # Strategy 1: Greedy match to capture the outermost complete brace pair + greedy_match = re.search(r"\{.*\}", text, re.DOTALL) + if greedy_match is None: + error_msg = f"No curly brace content found in text: {text}" + raise ValueError(error_msg) + + greedy_content = greedy_match.group(0) + + # Strategy 2: Non-greedy match to find all brace pairs, use the last one + non_greedy_matches = re.findall(r"\{.*?\}", text, re.DOTALL) + if not non_greedy_matches: + error_msg = f"No curly brace content found in text: {text}" + raise ValueError(error_msg) + + non_greedy_content = non_greedy_matches[-1] + + for content in [greedy_content, non_greedy_content]: + try: + parsed_data = json.loads(content) + return parsed_data + except json.JSONDecodeError: + continue + + for content in [greedy_content, non_greedy_content]: + try: + fixed_content = content.replace("{{", "{").replace("}}", "}") + parsed_data = json.loads(fixed_content) + return parsed_data + except json.JSONDecodeError: + continue + + error_msg = f"Failed to parse JSON content from curly braces. Text preview: {text}" + raise ValueError(error_msg) + + +def extract_square_brackets_content(text): + """ + Extract and parse JSON content enclosed in square brackets [] from text. 
+ """ + # Strategy 1: Greedy match to capture the outermost complete bracket pair + greedy_match = re.search(r"\[.*\]", text, re.DOTALL) + if greedy_match is None: + error_msg = f"No square bracket content found in text: {text}" + raise ValueError(error_msg) + + greedy_content = greedy_match.group(0) + + # Strategy 2: Non-greedy match to find all bracket pairs, use the last one + non_greedy_matches = re.findall(r"\[.*?\]", text, re.DOTALL) + if not non_greedy_matches: + error_msg = f"No square bracket content found in text: {text}" + raise ValueError(error_msg) + + non_greedy_content = non_greedy_matches[-1] + + for content in [greedy_content, non_greedy_content]: + try: + parsed_data = json.loads(content) + return parsed_data + except json.JSONDecodeError: + continue + + for content in [greedy_content, non_greedy_content]: + try: + fixed_content = content.replace("{{", "{").replace("}}", "}") + parsed_data = json.loads(fixed_content) + return parsed_data + except json.JSONDecodeError: + continue + + error_msg = f"Failed to parse JSON content from square brackets. Text preview: {text}" + raise ValueError(error_msg) From 0e41b643e38ef1a053fd19ff75e3d555c1a0519a Mon Sep 17 00:00:00 2001 From: Zehao Lin Date: Wed, 7 Jan 2026 19:50:26 +0800 Subject: [PATCH 43/48] fix: Use env exchange overrides for all scheduler messages (#834) * Use env exchange name/type overrides for scheduler * Default routing key to empty with env exchange override * Format rabbitmq_service after env override changes --------- Co-authored-by: glin1993@outlook.com <> --- .../webservice_modules/rabbitmq_service.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 46b2ad3d1..5a94d2af2 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -14,7 +14,6 @@ from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE -from memos.mem_scheduler.utils.misc_utils import is_cloud_env logger = get_logger(__name__) @@ -132,6 +131,15 @@ def initialize_rabbitmq( self.rabbitmq_exchange_type = self.rabbitmq_config.exchange_type logger.info(f"Using configured exchange type: {self.rabbitmq_exchange_type}") + env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") + env_exchange_type = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_TYPE") + if env_exchange_name: + self.rabbitmq_exchange_name = env_exchange_name + logger.info(f"Using env exchange name override: {self.rabbitmq_exchange_name}") + if env_exchange_type: + self.rabbitmq_exchange_type = env_exchange_type + logger.info(f"Using env exchange type override: {self.rabbitmq_exchange_type}") + # Start connection process parameters = self.get_rabbitmq_connection_param() self.rabbitmq_connection = SelectConnection( @@ -313,15 +321,16 @@ def rabbitmq_publish_message(self, message: dict): if label == "knowledgeBaseUpdate": routing_key = "" - # Cloud environment override: applies to specific message types if MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set + # Env override: apply to all message types when MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") - if is_cloud_env() and 
env_exchange_name and label in ["taskStatus", "knowledgeBaseUpdate"]: + env_routing_key = os.getenv("MEMSCHEDULER_RABBITMQ_ROUTING_KEY") + if env_exchange_name: exchange_name = env_exchange_name - routing_key = "" # Routing key is always empty in cloud environment for these types - - # Specific diagnostic logging for messages affected by cloud environment settings + routing_key = ( + env_routing_key if env_routing_key is not None and env_routing_key != "" else "" + ) logger.info( - f"[DIAGNOSTIC] Publishing {label} message in Cloud Env. " + f"[DIAGNOSTIC] Publishing {label} message with env exchange override. " f"Exchange: {exchange_name}, Routing Key: '{routing_key}'." ) logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}") From 3ee82f31aacdb3d07dd3592c3b7166fbd6f88b5a Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Wed, 7 Jan 2026 20:27:47 +0800 Subject: [PATCH 44/48] feat: support single-assistant mem-reader (#835) feat: support 'single-assistant' mem reader --- src/memos/templates/mem_reader_prompts.py | 39 +++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 26795a2b1..9432d6303 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -69,6 +69,26 @@ "summary": "Tom is currently focused on managing a new project with a tight schedule. After a team meeting on June 25, 2025, he realized the original deadline of December 15 might not be feasible due to backend delays. Concerned about insufficient testing time, he welcomed Jerry’s suggestion of proposing an extension. Tom plans to raise the idea of shifting the deadline to January 5, 2026 in the next morning’s meeting. His actions reflect both stress about timelines and a proactive, team-oriented problem-solving approach." } +Dialogue: +assistant: [10:30 AM, August 15, 2025]: The book Deep Work you mentioned is +indeed very suitable for your current situation. The book explains … (omitted). The author suggests setting aside 2–3 hours of focused work blocks each day and turning off all notifications during that time. Considering that you need to submit a report next week, you could try using the 9:00–11:00 AM time slot for focused work. + +Output: +{ + "memory list": [ + { + "key": "Deep Work Book Recommendation", + "memory_type": "LongTermMemory", + "value": "On August 15, 2025, the assistant recommended the book 'Deep Work' to the user and introduced its suggestion of reserving 2–3 hours per day for focused work while turning off all notifications. Based on the user's need to submit a report the following week, the assistant also suggested trying 9:00–11:00 AM as a focused work time block.", + "tags": ["book recommendation", "deep work", "time management", "report"] + } + ], + "summary": "The assistant recommended the book 'Deep Work' to the user and introduced the work methods discussed in the book." +} + +Note: When the dialogue contains only assistant messages, phrasing such as +“assistant recommended” or “assistant suggested” should be used, rather than incorrectly attributing the content to the user’s statements or plans. 
+ Another Example in Chinese (注意: 当user的语言为中文时,你就需要也输出中文): { "memory list": [ @@ -163,6 +183,25 @@ "summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。" } +对话: +assistant: [2025年8月15日上午10:30]: +你提到的那本《深度工作》确实很适合你现在的情况。这本书讲了......(略),作者建议每天留出2-3 +小时的专注时间块,期间关闭所有通知。考虑到你下周要交的报告,可以试试早上9点到11点这个时段。 + +输出: +{ + "memory list": [ + { + "key": "深度工作书籍推荐", + "memory_type": "LongTermMemory", + "value": "2025年8月15日助手向用户推荐了《深度工作》一书,并介绍了书中建议的每天留出2-3小时专注时间块、关闭所有通知的方法。助手还根据用户下周需要提交报告的情况,建议用户尝试早上9点到11点作为专注时段。", + "tags": ["书籍推荐", "深度工作", "时间管理", "报告"] + } + ], + "summary": "助手向用户推荐了《深度工作》一书,并介绍了了其中的工作方法" +} +注意:当对话仅有助手消息时,应使用"助手推荐"、"助手建议"等表述,而非将其错误归因为用户的陈述或计划。 + 另一个中文示例(注意:当用户语言为中文时,您也需输出中文): { "memory list": [ From 51a782b7db318a6862ab1012f49a715ed99bf048 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Thu, 8 Jan 2026 16:25:02 +0800 Subject: [PATCH 45/48] fix: knowledge base adopt raw text (#836) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords * fix messages queue * add seach_by_keywords_LIKE * add doc filter * add retrieve query * add retrieve queies * patch info filter * add log and make embedding safety net * add log and make embedding safety net * deduplicate add objects * use _add_memories_parallel * delete Special characters * delete Special characters * delete Special characters * delete Special characters * add source_doc_id * add source_doc_id * add reranker in init com.. 
* fix circle import * add feedback judgement * add feedback judgement * add pref feedback * add pref feedback * patch: get_memory func filter user id and make page chunk * add total num * add total num * add milvus pagination * fix merge implicit explicit pref * fix merge implicit explicit pref * fix merge implicit explicit pref * fix json load bug * knowledge raw_text replace memory * knowledge raw_text replace memory * knowledge raw_text replace memory --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/api/handlers/formatters_handler.py | 104 ++++++++++++++++++- src/memos/api/handlers/memory_handler.py | 2 +- src/memos/api/handlers/search_handler.py | 15 +++ src/memos/reranker/concat.py | 12 ++- src/memos/reranker/http_bge.py | 16 +-- 5 files changed, 138 insertions(+), 11 deletions(-) diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 94988295b..ca87d95d2 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -7,9 +7,13 @@ from typing import Any +from memos.log import get_logger from memos.templates.instruction_completion import instruct_completion +logger = get_logger(__name__) + + def to_iter(running: Any) -> list[Any]: """ Normalize running tasks to a list of task objects. @@ -29,7 +33,9 @@ def to_iter(running: Any) -> list[Any]: return list(running) if running else [] -def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]: +def format_memory_item( + memory_data: Any, include_embedding: bool = False, save_sources: bool = True +) -> dict[str, Any]: """ Format a single memory item for API response. @@ -49,7 +55,8 @@ def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dic memory["ref_id"] = ref_id if not include_embedding: memory["metadata"]["embedding"] = [] - memory["metadata"]["sources"] = [] + if not save_sources: + memory["metadata"]["sources"] = [] memory["metadata"]["usage"] = [] memory["metadata"]["ref_id"] = ref_id memory["metadata"]["id"] = memory_id @@ -125,3 +132,96 @@ def post_process_textual_mem( } ) return memories_result + + +def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): + """ + Separate knowledge and conversation memories from retrieval results. + """ + knowledge_mem = [] + conversation_mem = [] + for item in memories: + sources = item["metadata"]["sources"] + if ( + len(sources) > 0 + and "type" in sources[0] + and sources[0]["type"] == "file" + and "content" in sources[0] + and sources[0]["content"] != "" + ): # TODO change to memory_type + knowledge_mem.append(item) + else: + conversation_mem.append(item) + + logger.info( + f"Retrieval results number of knowledge_mem: {len(knowledge_mem)}, conversation_mem: {len(conversation_mem)}" + ) + return knowledge_mem, conversation_mem + + +def rerank_knowledge_mem( + reranker: Any, + query: str, + text_mem: list[dict[str, Any]], + top_k: int, + file_mem_proportion: float = 0.5, +) -> list[dict[str, Any]]: + """ + Rerank knowledge memories and keep conversation memories. 
+ """ + memid2cubeid = {} + memories_list = [] + for memory_group in text_mem: + cube_id = memory_group["cube_id"] + memories = memory_group["memories"] + memories_list.extend(memories) + for memory in memories: + memid2cubeid[memory["id"]] = cube_id + + knowledge_mem, conversation_mem = separate_knowledge_and_conversation_mem(memories_list) + knowledge_mem_top_k = max(int(top_k * file_mem_proportion), int(top_k - len(conversation_mem))) + reranked_knowledge_mem = reranker.rerank(query, knowledge_mem, top_k=len(knowledge_mem)) + reranked_knowledge_mem = [item[0] for item in reranked_knowledge_mem] + + # TODO revoke sources replace memory value + for item in reranked_knowledge_mem: + item["memory"] = item["metadata"]["sources"][0]["content"] + item["metadata"]["sources"] = [] + + for item in conversation_mem: + item["metadata"]["sources"] = [] + + # deduplicate: remove items with duplicate memory content + original_count = len(reranked_knowledge_mem) + seen_memories = set[Any]() + deduplicated_knowledge_mem = [] + for item in reranked_knowledge_mem: + memory_content = item.get("memory", "") + if memory_content and memory_content not in seen_memories: + seen_memories.add(memory_content) + deduplicated_knowledge_mem.append(item) + deduplicated_count = len(deduplicated_knowledge_mem) + logger.info( + f"After filtering duplicate knowledge base text from sources, count changed from {original_count} to {deduplicated_count}" + ) + + reranked_knowledge_mem = deduplicated_knowledge_mem[:knowledge_mem_top_k] + conversation_mem_top_k = top_k - len(reranked_knowledge_mem) + cubeid2memories = {} + text_mem_res = [] + + for memory in reranked_knowledge_mem + conversation_mem[:conversation_mem_top_k]: + cube_id = memid2cubeid[memory["id"]] + if cube_id not in cubeid2memories: + cubeid2memories[cube_id] = [] + cubeid2memories[cube_id].append(memory) + + for cube_id, memories in cubeid2memories.items(): + text_mem_res.append( + { + "cube_id": cube_id, + "memories": memories, + } + ) + + return text_mem_res diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index ef829d757..14bb8eec5 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -204,7 +204,7 @@ def handle_get_memories( preferences, total_pref = naive_mem_cube.pref_mem.get_memory_by_filter( filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size ) - format_preferences = [format_memory_item(item) for item in preferences] + format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] return GetMemoryResponse( message="Memories retrieved successfully", diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 3774410dc..32a970b22 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -5,9 +5,12 @@ using dependency injection for better modularity and testability. 
""" +import time + from typing import Any from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.formatters_handler import rerank_knowledge_mem from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( @@ -69,6 +72,18 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Restore original top_k for downstream logic or response metadata search_req.top_k = original_top_k + start_time = time.time() + text_mem = results["text_mem"] + results["text_mem"] = rerank_knowledge_mem( + self.reranker, + query=search_req.query, + text_mem=text_mem, + top_k=original_top_k, + file_mem_proportion=0.5, + ) + rerank_time = time.time() - start_time + self.logger.info(f"[Knowledge_replace_memory_time] Rerank time: {rerank_time} seconds") + self.logger.info( f"[SearchHandler] Final search results: count={len(results)} results={results}" ) diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py index 502af18b6..b39496a1c 100644 --- a/src/memos/reranker/concat.py +++ b/src/memos/reranker/concat.py @@ -83,10 +83,18 @@ def concat_original_source( merge_field = ["sources"] if rerank_source is None else rerank_source.split(",") documents = [] for item in graph_results: - memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m + m = item.get("memory") if isinstance(item, dict) else getattr(item, "memory", None) + + memory = _TAG1.sub("", m) if isinstance(m, str) else m + sources = [] for field in merge_field: - source = getattr(item.metadata, field, None) + if isinstance(item, dict): + metadata = item.get("metadata", {}) + source = metadata.get(field) if isinstance(metadata, dict) else None + else: + source = getattr(item.metadata, field, None) if hasattr(item, "metadata") else None + if source is None: continue sources.append((memory, source)) diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index 4e9054f1e..32034cf6d 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -129,7 +129,7 @@ def __init__( def rerank( self, query: str, - graph_results: list[TextualMemoryItem], + graph_results: list[TextualMemoryItem] | list[dict[str, Any]], top_k: int, search_priority: dict | None = None, **kwargs, @@ -164,11 +164,15 @@ def rerank( if self.rerank_source: documents = concat_original_source(graph_results, self.rerank_source) else: - documents = [ - (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m) - for item in graph_results - ] - documents = [d for d in documents if isinstance(d, str) and d] + documents = [] + filtered_graph_results = [] + for item in graph_results: + m = item.get("memory") if isinstance(item, dict) else getattr(item, "memory", None) + + if isinstance(m, str) and m: + documents.append(_TAG1.sub("", m)) + filtered_graph_results.append(item) + graph_results = filtered_graph_results logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...") From 7ffdb523e2fa6634903667c18120c495eb467bf9 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 8 Jan 2026 17:46:24 +0800 Subject: [PATCH 46/48] Feat/optimize cloud service api (#839) * add get_user_names_by_memory_ids api * modify delete api * modify bug * add extract limit in implicit memory * close internet search in chat api, modify implicit pref prompt * 
modify bug

* add a new internal method to check cube id existence

* modify code

* add get memory by memory id api

---------

Co-authored-by: yuan.wang
---
 src/memos/api/handlers/memory_handler.py | 43 ++++++++++++++++++++++++
 src/memos/api/routers/server_router.py   |  8 +++++
 2 files changed, 51 insertions(+)

diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index 14bb8eec5..7110fae09 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -176,6 +176,49 @@ def handle_get_subgraph(
         raise
 
 
+def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemoryResponse:
+    """
+    Handler for getting a single memory by its ID.
+
+    Tries to retrieve from text memory first, then preference memory if not found.
+
+    Args:
+        memory_id: The ID of the memory to retrieve
+        naive_mem_cube: Memory cube instance
+
+    Returns:
+        GetMemoryResponse with the memory data
+    """
+
+    try:
+        memory = naive_mem_cube.text_mem.get(memory_id)
+    except Exception:
+        memory = None
+
+    # If not found in text memory, try preference memory
+    pref = None
+    if memory is None and naive_mem_cube.pref_mem is not None:
+        collection_names = ["explicit_preference", "implicit_preference"]
+        for collection_name in collection_names:
+            try:
+                pref = naive_mem_cube.pref_mem.get_with_collection_name(collection_name, memory_id)
+                if pref is not None:
+                    break
+            except Exception:
+                continue
+
+    # Get the data from whichever memory source succeeded
+    data = (memory or pref).model_dump() if (memory or pref) else None
+
+    return GetMemoryResponse(
+        message="Memory retrieved successfully"
+        if data
+        else f"Memory with ID {memory_id} not found",
+        code=200,
+        data=data,
+    )
+
+
 def handle_get_memories(
     get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
 ) -> GetMemoryResponse:
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index a4052d313..8371c41b9 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -314,6 +314,14 @@ def get_memories(memory_req: GetMemoryRequest):
     )
 
 
+@router.get("/get_memory/{memory_id}", summary="Get memory by id", response_model=GetMemoryResponse)
+def get_memory_by_id(memory_id: str):
+    return handlers.memory_handler.handle_get_memory(
+        memory_id=memory_id,
+        naive_mem_cube=naive_mem_cube,
+    )
+
+
 @router.post(
     "/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse
 )

From 5f811d4ca5737cdb197fddc06743cfe79da36861 Mon Sep 17 00:00:00 2001
From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com>
Date: Thu, 8 Jan 2026 18:42:02 +0800
Subject: [PATCH 47/48] Feat/fix playground bug (#841)

* fix playground bug, internet search judge

* fix playground internet bug

* modify delete mem

* modify tool resp bug in multi cube

* fix bug in playground chat handler and internet search

* modify prompt

* fix bug in playground

* fix bug in playground

* fix bug

* fix code

* fix model bug in playground

* modify plan b

* llm param modify

* add logger in playground

* modify code

* fix bug

* modify code

* modify code

* fix bug

* fix search bug in playground

* fix bug

* move scheduler to background

* modify pref location

* modify fast net search

* add tags and new package

* modify prompt, fix bug

* remove nltk due to image problem

* prompt modify

* modify bug, remove redundant field

* modify bug

* fix playground bug

* fix bug

* boost internet topk

* boost to 50

* fix bug cite

* modify search

* remote query add in playground

*
modify bug

* modify pref bug

* move add position

* modify chat prompt

* modify overthinking

* add logger in playground chat

* modify mem

* remove must in prompt

* add logger

* add logger

* remove dedup in playground

---------

Co-authored-by: yuan.wang
Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com>
Co-authored-by: CaralHsi
---
 src/memos/api/handlers/memory_handler.py | 13 -------------
 src/memos/api/routers/server_router.py   |  2 --
 2 files changed, 15 deletions(-)

diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index 7110fae09..a4f500560 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -23,10 +23,6 @@
     remove_embedding_recursive,
     sort_children_by_memory_type,
 )
-from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
-    cosine_similarity_matrix,
-    find_best_unrelated_subgroup,
-)
 
 
 if TYPE_CHECKING:
@@ -41,7 +37,6 @@ def handle_get_all_memories(
     mem_cube_id: str,
     memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"],
     naive_mem_cube: Any,
-    embedder: Any,
 ) -> MemoryResponse:
     """
     Main handler for getting all memories.
@@ -64,14 +59,6 @@
         # Get all text memories from the graph database
         memories = naive_mem_cube.text_mem.get_all(user_name=mem_cube_id)
 
-        mems = [mem.get("memory", "") for mem in memories.get("nodes", [])]
-        embeddings = embedder.embed(mems)
-        similarity_matrix = cosine_similarity_matrix(embeddings)
-        selected_indices, _ = find_best_unrelated_subgroup(
-            embeddings, similarity_matrix, bar=0.9
-        )
-        memories["nodes"] = [memories["nodes"][i] for i in selected_indices]
-
         # Format and convert to tree structure
         memories_cleaned = remove_embedding_recursive(memories)
         custom_type_ratios = {
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index 8371c41b9..86b75d73e 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -88,7 +88,6 @@
 naive_mem_cube = components["naive_mem_cube"]
 redis_client = components["redis_client"]
 status_tracker = TaskStatusTracker(redis_client=redis_client)
-embedder = components["embedder"]
 graph_db = components["graph_db"]
 vector_db = components["vector_db"]
 
@@ -302,7 +301,6 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
         ),
         memory_type=memory_req.memory_type or "text_mem",
         naive_mem_cube=naive_mem_cube,
-        embedder=embedder,
     )
 
 
From 8b30a4414848b8b91ad857113db43984eefd8cb7 Mon Sep 17 00:00:00 2001
From: Travis Tang
Date: Fri, 9 Jan 2026 10:43:18 +0800
Subject: [PATCH 48/48] refactor&fix: fix a range of bugs in scheduler and
 revise fine add apis (#840)

* fix bugs: try to fix bugs in _submit_web_logs

* fix bugs: try to address bugs

* fix bugs

* refactor: modify examples

* revise add operation and fix an unbelievable bug

* address the bug issues

* the doc file has a format problem which has been fixed in this commit

* add a range of new feats for the add operation

* address the incompatibility issue of the local scheduler

* feat(scheduler): optimize redis queue consumer group management

- Proactively ensure consumer groups exist in '_refresh_stream_keys' for newly discovered streams.
- Remove redundant consumer group checks in '_read_new_messages_batch' to improve read performance.
- Clean up 'seen_streams' cache when streams are deleted to ensure correct group recreation.
- This change reduces unnecessary Redis calls during high-frequency polling.
* fix(tests): resolve AttributeError in SimpleStructMemReader tests

- Import 'parse_json_result' from 'memos.mem_reader.utils' instead of accessing it as an instance attribute.
- Fixes 'AttributeError: 'SimpleStructMemReader' object has no attribute 'parse_json_result'' in 'test_parse_json_result_success' and 'test_parse_json_result_failure'.
- Remove incorrect mock assignment of 'parse_json_result' in 'test_process_chat_data'.

* fix(mem_reader): pass info dict to add_before_search for correct user_id usage

- Update 'add_before_search' signature in 'SimpleStructMemReader' to accept 'info' dict.
- Pass 'info' (containing 'user_id' and 'session_id') to 'self.searcher.search' instead of using empty strings.
- Add 'test_add_before_search' to 'TestSimpleStructMemReader' to verify the fix and ensure 'searcher.search' receives the correct 'info'.
- This ensures that memory searches are scoped to the correct user and session.

* refactor add_before_search from mem_reader to SingleCubeView

* address bugs

* fix: fix the qsize bug of task queue, and accept change from hotfix/scheduler

* fix: address some issues to run old scheduler example and kv cache example

* fix: address the issue of Top-level import of unavailable module 'torch'

* fix: resolve linting errors and make optional dependencies lazy loaded

- Fix ambiguous characters and commented-out code in examples/mem_scheduler/quick_start_examples.py
- Fix nested if statements in src/memos/mem_os/core.py
- Move torch and transformers imports to method scope in src/memos/llms/hf.py to support optional dependencies
- Update tests/llms/test_hf.py to patch transformers module directly

* refactor: revise the rewrite prompt to make it better

* refactor: update examples

* refactor: update examples for scheduler

* fix bugs: address the unsupported xautoclaim command when the redis version is lower than 6.2.0 by adding a new manual auto-claim feature that combines xpending + xclaim

* refactor: review settings

* refactor: adjust examples so they run better for code debugging

* refactor: review slow add apis to get better performance on Halumen

* fix bugs: address the issue where the status_tracker was still used when use_redis_queue was set to false

* refactor: allow the code to run without rabbitmq

* refactor: create a _parse_pending_entry for redis queue

* refactor: add a try/catch for status_tracker
---
 .gitignore                                     |   1 +
 docker/.env.example                            |   1 +
 .../mem_scheduler/quick_start_examples.py      |   8 +-
 .../scheduler_for_async_tasks.py               |   2 +-
 src/memos/mem_reader/simple_struct.py          |   4 +-
 src/memos/mem_scheduler/base_scheduler.py      |   9 +-
 .../task_schedule_modules/dispatcher.py        |  31 --
 .../task_schedule_modules/redis_queue.py       | 276 +++++++++++++++---
 .../webservice_modules/rabbitmq_service.py     |  42 ++-
 src/memos/templates/mem_reader_prompts.py      |  79 ++---
 10 files changed, 313 insertions(+), 140 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8319a4d2f..ac31eb41a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -204,6 +204,7 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file.  For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
+.trae
 
 # VSCode
 .vscode*
diff --git a/docker/.env.example b/docker/.env.example
index ee26c7bcd..3674cd69b 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -123,6 +123,7 @@ API_SCHEDULER_ON=true
 API_SEARCH_WINDOW_SIZE=5
 # Specify how many rounds of previous conversations (history) to retrieve and consider during the 'hybrid search' (fast search+asynchronous fine search). This helps provide context aware search results
 API_SEARCH_HISTORY_TURNS=5
+MEMSCHEDULER_USE_REDIS_QUEUE=false
 
 ## Graph / vector stores
 # Neo4j database selection mode
diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py
index fbfef4d76..724663be6 100644
--- a/examples/mem_scheduler/quick_start_examples.py
+++ b/examples/mem_scheduler/quick_start_examples.py
@@ -146,7 +146,9 @@ def kv_cache_only():
 
 def run_scheduler_example():
     # Load the main MOS (Memory-Oriented System) config file that uses MemScheduler
-    config = parse_yaml("./examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml")
+    config = parse_yaml(
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+    )
     # Pass the parsed config dict into the MOSConfig constructor to build the config object
     mos_config = MOSConfig(**config)
     # Initialize the MOS system instance with the config object
@@ -159,12 +161,12 @@ def run_scheduler_example():
 
     # Load the general MemCube (memory cube) configuration from a YAML file
     config = GeneralMemCubeConfig.from_yaml_file(
-        "./examples/data/config/mem_scheduler/mem_cube_config.yaml"
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
     )
     # Define the unique identifier of the MemCube
     mem_cube_id = "mem_cube_5"
     # Define the local storage path of the MemCube (the path contains the user ID and MemCube ID)
-    mem_cube_name_or_path = f"./outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
 
     # If the path already exists, remove the old directory first
     if Path(mem_cube_name_or_path).exists():
diff --git a/examples/mem_scheduler/scheduler_for_async_tasks.py b/examples/mem_scheduler/scheduler_for_async_tasks.py
index a767b57c4..7f544c3da 100644
--- a/examples/mem_scheduler/scheduler_for_async_tasks.py
+++ b/examples/mem_scheduler/scheduler_for_async_tasks.py
@@ -57,7 +57,7 @@ def submit_tasks():
 TEST_HANDLER_LABEL = "test_handler"
 mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})
 
-# 10s to restart
+# 5s to restart
 mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 5_000
 
 tmp_dir = Path("./tmp")
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 61a7d2b6d..fa72bd063 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -614,11 +614,9 @@ def _read_memory(
             serialized_origin_memories = json.dumps(
                 [one.memory for one in original_memory_group], indent=2
             )
-            revised_memory_list = self.rewrite_memories(
+            revised_memory_list = self.filter_hallucination_in_memories(
                 messages=combined_messages,
                 memory_list=original_memory_group,
-                user_only=os.getenv("SIMPLE_STRUCT_REWRITE_USER_ONLY", "true").lower()
-                == "false",
             )
             serialized_revised_memories = json.dumps(
                 [one.memory for one in revised_memory_list], indent=2
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 3f5c90b67..4c9310cbb 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -225,7 +225,7 @@ def initialize_modules(
             process_llm = chat_llm
 
         try:
-            if redis_client:
+            if redis_client and self.use_redis_queue:
                 self.status_tracker = TaskStatusTracker(redis_client)
                 if self.dispatcher:
                     self.dispatcher.status_tracker = self.status_tracker
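Both the hunk above and the property hunk below gate the Redis-backed tracker on use_redis_queue. As a minimal sketch of that gating pattern: the NullTracker fallback and build_tracker helper here are illustrative assumptions, not code from this patch (the patch itself simply leaves the tracker as None when the flag is off).

# Illustrative sketch only; mirrors the gating this patch applies, not actual MemOS code.
class NullTracker:
    """No-op stand-in used when Redis-backed task status tracking is disabled."""

    def mark_done(self, task_id: str) -> None:
        pass


def build_tracker(redis_client, use_redis_queue: bool, tracker_cls):
    # Construct the real tracker only when both the Redis client and the feature
    # flag are present; otherwise return a harmless no-op so downstream modules
    # never touch a half-initialized tracker.
    if redis_client is not None and use_redis_queue:
        return tracker_cls(redis_client)
    return NullTracker()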
@@ -305,7 +305,7 @@ def status_tracker(self) -> TaskStatusTracker | None: available via RedisSchedulerModule. This mirrors the lazy pattern used by `mem_cube` so downstream modules can safely access the tracker. """ - if self._status_tracker is None: + if self._status_tracker is None and self.use_redis_queue: try: self._status_tracker = TaskStatusTracker(self.redis) # Propagate to submodules when created lazily @@ -314,7 +314,8 @@ def status_tracker(self) -> TaskStatusTracker | None: if self.memos_message_queue: self.memos_message_queue.set_status_tracker(self._status_tracker) except Exception as e: - logger.warning(f"Failed to lazily initialize status_tracker: {e}", exc_info=True) + logger.warning(f"Failed to lazy-initialize status_tracker: {e}", exc_info=True) + return self._status_tracker @status_tracker.setter @@ -869,6 +870,8 @@ def _submit_web_logs( messages = [messages] # transform single message to list for message in messages: + if self.rabbitmq_config is None: + return try: # Always call publish; the publisher now caches when offline and flushes after reconnect logger.info( diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index cdd491183..2099da5a1 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -108,8 +108,6 @@ def __init__( ) self.metrics = metrics - self._status_tracker: TaskStatusTracker | None = None - # Use setter to allow propagation and keep a single source of truth self.status_tracker = status_tracker self.submit_web_logs = submit_web_logs # ADDED @@ -118,35 +116,6 @@ def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None: return # This is handled in BaseScheduler now - @property - def status_tracker(self) -> TaskStatusTracker | None: - """Lazy-initialized status tracker for the dispatcher. - - If the tracker is None, attempt to initialize from the Redis-backed - components available to the dispatcher (queue or orchestrator). - """ - if self._status_tracker is None: - try: - self._status_tracker = TaskStatusTracker(self.redis) - # Propagate to submodules when created lazily - if self.memos_message_queue: - self.memos_message_queue.set_status_tracker(self._status_tracker) - except Exception as e: - logger.warning(f"Failed to lazily initialize status_tracker: {e}", exc_info=True) - return self._status_tracker - - @status_tracker.setter - def status_tracker(self, value: TaskStatusTracker | None) -> None: - self._status_tracker = value - # Propagate to the queue if possible - try: - if self.memos_message_queue and hasattr(self.memos_message_queue, "status_tracker"): - self.memos_message_queue.status_tracker = value - except Exception as e: - logger.warning( - f"Failed to propagate dispatcher status_tracker to queue: {e}", exc_info=True - ) - def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): """ Create a wrapper around the handler to track task execution and capture results. 
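The redis_queue.py diff below implements a version-gated claim strategy: use native XAUTOCLAIM on Redis 6.2+, and emulate it with XPENDING plus XCLAIM on older servers. Here is a compact sketch of the idea, assuming a redis-py style client; the claim_stale_messages name and its parameters are illustrative, not part of the patch.

import redis


def claim_stale_messages(
    r: redis.Redis, stream: str, group: str, consumer: str, min_idle_ms: int, count: int
):
    """Claim messages pending longer than min_idle_ms, portably across Redis server versions."""
    major, minor, *_ = (int(p) for p in r.info("server")["redis_version"].split(".")[:2])
    if (major, minor) >= (6, 2):
        # Native path: XAUTOCLAIM scans the pending entries list and claims in one round trip.
        result = r.xautoclaim(
            stream, group, consumer, min_idle_time=min_idle_ms, start_id="0-0", count=count
        )
        return result[1]  # (next_id, claimed[, deleted_ids]) depending on server version
    # Fallback path: list pending entries via XPENDING, then claim the idle ones via XCLAIM.
    pending = r.xpending_range(stream, group, "-", "+", count * 3)
    ids = [e["message_id"] for e in pending if e["time_since_delivered"] >= min_idle_ms]
    return r.xclaim(stream, group, consumer, min_idle_ms, ids[:count]) if ids else []

The 3x over-fetch in the fallback mirrors the patch's own heuristic: XPENDING returns entries in ID order rather than idle order, so fetching extra entries raises the chance of collecting enough sufficiently idle ones to claim.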
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 557a45466..1c9683542 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -81,6 +81,7 @@ def __init__( # Consumer state self._is_listening = False self._message_handler: Callable[[ScheduleMessageItem], None] | None = None + self.supports_xautoclaim = False # Connection state self._is_connected = False @@ -105,6 +106,7 @@ def __init__( # Auto-initialize Redis connection if self.auto_initialize_redis(): self._is_connected = True + self._check_xautoclaim_support() self.seen_streams = set() @@ -143,6 +145,33 @@ def __init__( logger.debug(f"Initial stream keys refresh failed: {e}") self._start_stream_keys_refresh_thread() + def _check_xautoclaim_support(self): + """Check if the Redis server supports xautoclaim (v6.2+).""" + if not self._redis_conn: + return + + try: + info = self._redis_conn.info("server") + version_str = info.get("redis_version", "0.0.0") + # Simple version parsing + parts = [int(p) for p in version_str.split(".") if p.isdigit()] + while len(parts) < 3: + parts.append(0) + + major, minor, _ = parts[:3] + if major > 6 or (major == 6 and minor >= 2): + self.supports_xautoclaim = True + else: + self.supports_xautoclaim = False + + logger.info( + f"[REDIS_QUEUE] Redis version {version_str}. " + f"Supports xautoclaim: {self.supports_xautoclaim}" + ) + except Exception as e: + logger.warning(f"Failed to check Redis version: {e}") + self.supports_xautoclaim = False + def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}" return stream_key @@ -623,41 +652,67 @@ def _compute_pending_need( need_pending = max(0, batch_size - new_count) return need_pending if need_pending > 0 else 0 + def _parse_pending_entry(self, entry) -> tuple[str, int]: + """Extract message_id and idle_time from a pending entry (dict, tuple, or object).""" + if isinstance(entry, dict): + return entry.get("message_id"), entry.get("time_since_delivered") + elif isinstance(entry, tuple | list): + return entry[0], entry[2] + else: + # Assume object (redis-py 5.x+ PendingMessage) + return getattr(entry, "message_id", None), getattr(entry, "time_since_delivered", 0) + + def _manual_xautoclaim( + self, stream_key: str, min_idle_time: int, count: int + ) -> tuple[str, list[tuple[str, dict]], list[str]]: + """ + Simulate xautoclaim using xpending and xclaim for compatibility with older Redis versions. + """ + # 1. Get pending entries (fetch slightly more to increase chance of finding idle ones) + fetch_count = count * 3 + pending_entries = self._redis_conn.xpending_range( + stream_key, self.consumer_group, "-", "+", fetch_count + ) + + if not pending_entries: + return "0-0", [], [] + + claim_ids = [] + for entry in pending_entries: + # entry structure depends on redis-py version/decoding + # Assuming list of dicts: {'message_id': '...', 'time_since_delivered': ms, ...} + # or list of tuples + msg_id, idle_time = self._parse_pending_entry(entry) + + if idle_time >= min_idle_time: + claim_ids.append(msg_id) + if len(claim_ids) >= count: + break + + if not claim_ids: + return "0-0", [], [] + + # 2. 
Claim messages + claimed_messages = self._redis_conn.xclaim( + stream_key, self.consumer_group, self.consumer_name, min_idle_time, claim_ids + ) + + return "0-0", claimed_messages, [] + def _claim_pending_messages( self, stream_key: str, need_pending_count: int, task_label: str ) -> list[tuple[str, list[tuple[str, dict]]]]: """Claim pending messages exceeding idle threshold, with group existence handling.""" - try: - claimed_result = self._redis_conn.xautoclaim( - name=stream_key, - groupname=self.consumer_group, - consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label), - start_id="0-0", - count=need_pending_count, - justid=False, - ) - if len(claimed_result) == 2: - next_id, claimed = claimed_result - deleted_ids = [] - elif len(claimed_result) == 3: - next_id, claimed, deleted_ids = claimed_result - else: - raise ValueError(f"Unexpected xautoclaim response length: {len(claimed_result)}") + min_idle = self.orchestrator.get_task_idle_min(task_label=task_label) - return [(stream_key, claimed)] if claimed else [] - except Exception as read_err: - err_msg = str(read_err).lower() - if "nogroup" in err_msg or "no such key" in err_msg: - logger.warning( - f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)." - ) - self._ensure_consumer_group(stream_key=stream_key) + # Use native xautoclaim if supported (Redis 6.2+) + if self.supports_xautoclaim: + try: claimed_result = self._redis_conn.xautoclaim( name=stream_key, groupname=self.consumer_group, consumername=self.consumer_name, - min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label), + min_idle_time=min_idle, start_id="0-0", count=need_pending_count, justid=False, @@ -670,25 +725,64 @@ def _claim_pending_messages( else: raise ValueError( f"Unexpected xautoclaim response length: {len(claimed_result)}" - ) from read_err + ) return [(stream_key, claimed)] if claimed else [] - return [] - - def _batch_claim_pending_messages( - self, claims_spec: list[tuple[str, int, str]] - ) -> list[tuple[str, list[tuple[str, dict]]]]: - """Batch-claim pending messages across multiple streams. + except Exception as read_err: + err_msg = str(read_err).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + logger.warning( + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)." + ) + self._ensure_consumer_group(stream_key=stream_key) + claimed_result = self._redis_conn.xautoclaim( + name=stream_key, + groupname=self.consumer_group, + consumername=self.consumer_name, + min_idle_time=min_idle, + start_id="0-0", + count=need_pending_count, + justid=False, + ) + if len(claimed_result) == 2: + next_id, claimed = claimed_result + deleted_ids = [] + elif len(claimed_result) == 3: + next_id, claimed, deleted_ids = claimed_result + else: + raise ValueError( + f"Unexpected xautoclaim response length: {len(claimed_result)}" + ) from read_err - Args: - claims_spec: List of tuples (stream_key, need_pending_count, task_label) + return [(stream_key, claimed)] if claimed else [] + return [] - Returns: - A list of (stream_key, claimed_entries) pairs for all successful claims. 
- """ - if not self._redis_conn or not claims_spec: + # Fallback to manual xautoclaim for older Redis versions + try: + _next, claimed, _deleted = self._manual_xautoclaim( + stream_key, min_idle, need_pending_count + ) + return [(stream_key, claimed)] if claimed else [] + except Exception as read_err: + err_msg = str(read_err).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + logger.warning( + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (manual xautoclaim)." + ) + self._ensure_consumer_group(stream_key=stream_key) + try: + _next, claimed, _deleted = self._manual_xautoclaim( + stream_key, min_idle, need_pending_count + ) + return [(stream_key, claimed)] if claimed else [] + except Exception: + return [] return [] + def _batch_claim_native( + self, claims_spec: list[tuple[str, int, str]] + ) -> list[tuple[str, list[tuple[str, dict]]]]: + """Batch-claim pending messages using Redis xautoclaim pipeline (Redis 6.2+).""" pipe = self._redis_conn.pipeline(transaction=False) for stream_key, need_count, label in claims_spec: pipe.xautoclaim( @@ -702,14 +796,11 @@ def _batch_claim_pending_messages( ) try: - # Execute with raise_on_error=False so we get exceptions in the results list - # instead of aborting the whole batch. results = pipe.execute(raise_on_error=False) except Exception as e: logger.error(f"Pipeline execution critical failure: {e}") results = [e] * len(claims_spec) - # Handle individual failures (e.g. NOGROUP) by retrying just that stream final_results = [] for i, res in enumerate(results): if isinstance(res, Exception): @@ -736,12 +827,8 @@ def _batch_claim_pending_messages( else: final_results.append(res) - results = final_results - - claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] - for (stream_key, _need_count, _label), claimed_result in zip( - claims_spec, results, strict=False - ): + claimed_pairs = [] + for (stream_key, _, _), claimed_result in zip(claims_spec, final_results, strict=False): try: if not claimed_result: continue @@ -760,6 +847,98 @@ def _batch_claim_pending_messages( return claimed_pairs + def _batch_claim_manual( + self, claims_spec: list[tuple[str, int, str]] + ) -> list[tuple[str, list[tuple[str, dict]]]]: + """Batch-claim pending messages using 2-phase pipeline (Redis < 6.2).""" + # Phase 1: Fetch pending messages for all streams + pending_pipe = self._redis_conn.pipeline(transaction=False) + for stream_key, need_count, _label in claims_spec: + fetch_count = need_count * 3 + pending_pipe.xpending_range(stream_key, self.consumer_group, "-", "+", fetch_count) + + try: + pending_results = pending_pipe.execute(raise_on_error=False) + except Exception as e: + logger.error(f"Pending fetch pipeline failed: {e}") + return [] + + # Phase 2: Filter and prepare claim pipeline + claim_pipe = self._redis_conn.pipeline(transaction=False) + streams_to_claim_indices = [] + claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = [] + + for i, (stream_key, need_count, label) in enumerate(claims_spec): + pending_res = pending_results[i] + min_idle = self.orchestrator.get_task_idle_min(task_label=label) + + if isinstance(pending_res, Exception): + err_msg = str(pending_res).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + try: + self._ensure_consumer_group(stream_key) + _next, claimed, _ = self._manual_xautoclaim( + stream_key, min_idle, need_count + ) + if claimed: + claimed_pairs.append((stream_key, claimed)) + except Exception as retry_err: + logger.warning(f"Retry 
manual claim failed for {stream_key}: {retry_err}") + continue + + if not pending_res: + continue + + claim_ids = [] + for entry in pending_res: + msg_id, idle_time = self._parse_pending_entry(entry) + if idle_time >= min_idle: + claim_ids.append(msg_id) + if len(claim_ids) >= need_count: + break + + if claim_ids: + claim_pipe.xclaim( + stream_key, + self.consumer_group, + self.consumer_name, + min_idle, + claim_ids, + ) + streams_to_claim_indices.append(i) + + if streams_to_claim_indices: + try: + claim_results = claim_pipe.execute(raise_on_error=False) + for idx_in_results, original_idx in enumerate(streams_to_claim_indices): + res = claim_results[idx_in_results] + stream_key = claims_spec[original_idx][0] + if isinstance(res, list) and res: + claimed_pairs.append((stream_key, res)) + except Exception as e: + logger.error(f"Claim pipeline failed: {e}") + + return claimed_pairs + + def _batch_claim_pending_messages( + self, claims_spec: list[tuple[str, int, str]] + ) -> list[tuple[str, list[tuple[str, dict]]]]: + """Batch-claim pending messages across multiple streams. + + Args: + claims_spec: List of tuples (stream_key, need_pending_count, task_label) + + Returns: + A list of (stream_key, claimed_entries) pairs for all successful claims. + """ + if not self._redis_conn or not claims_spec: + return [] + + if self.supports_xautoclaim: + return self._batch_claim_native(claims_spec) + + return self._batch_claim_manual(claims_spec) + def _convert_messages( self, messages: list[tuple[str, list[tuple[str, dict]]]] ) -> list[ScheduleMessageItem]: @@ -994,6 +1173,7 @@ def connect(self) -> None: # Test the connection self._redis_conn.ping() self._is_connected = True + self._check_xautoclaim_support() logger.debug("Redis connection established successfully") # Start stream keys refresher when connected self._start_stream_keys_refresh_thread() diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py index 5a94d2af2..a07934b8e 100644 --- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py @@ -30,6 +30,7 @@ def __init__(self): Initialize RabbitMQ connection settings. """ super().__init__() + self.auth_config = None # RabbitMQ settings self.rabbitmq_config: RabbitMQConfig | None = None @@ -99,22 +100,35 @@ def initialize_rabbitmq( ) return + if self.is_rabbitmq_connected(): + logger.warning("RabbitMQ is already connected. 
Skipping initialization.") + return + from pika.adapters.select_connection import SelectConnection - if config is None: - if config_path is None and AuthConfig.default_config_exists(): - auth_config = AuthConfig.from_local_config() - elif Path(config_path).exists(): - auth_config = AuthConfig.from_local_config(config_path=config_path) + if config is not None: + if isinstance(config, RabbitMQConfig): + self.rabbitmq_config = config + elif isinstance(config, dict): + self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq else: - auth_config = AuthConfig.from_local_env() - self.rabbitmq_config = auth_config.rabbitmq - elif isinstance(config, RabbitMQConfig): - self.rabbitmq_config = config - elif isinstance(config, dict): - self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq + logger.error(f"Unsupported config type: {type(config)}") + return + else: - logger.error("Not implemented") + if config_path is not None and Path(config_path).exists(): + self.auth_config = AuthConfig.from_local_config(config_path=config_path) + elif AuthConfig.default_config_exists(): + self.auth_config = AuthConfig.from_local_config() + else: + self.auth_config = AuthConfig.from_local_env() + self.rabbitmq_config = self.auth_config.rabbitmq + + if self.rabbitmq_config is None: + logger.error( + "Failed to load RabbitMQ configuration. Please check your config file or environment variables." + ) + return # Load exchange configuration from config if self.rabbitmq_config: @@ -140,7 +154,7 @@ def initialize_rabbitmq( self.rabbitmq_exchange_type = env_exchange_type logger.info(f"Using env exchange type override: {self.rabbitmq_exchange_type}") - # Start connection process + # Start connection process parameters = self.get_rabbitmq_connection_param() self.rabbitmq_connection = SelectConnection( parameters, @@ -156,7 +170,7 @@ def initialize_rabbitmq( self._io_loop_thread.start() logger.info("RabbitMQ connection process started") except Exception: - logger.error("Fail to initialize auth_config", exc_info=True) + logger.error("Failed to initialize RabbitMQ connection", exc_info=True) finally: with self._rabbitmq_lock: self._rabbitmq_initializing = False diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index 9432d6303..20f8150b7 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -796,43 +796,48 @@ """ SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """ -You are a strict memory validator. -Your task is to identify and delete hallucinated memories that are not explicitly stated by the user in the provided messages. - -Rules: -1. **User-Only Origin**: Verify facts against USER messages ONLY. If the Assistant repeats a User fact, it is VALID. If the Assistant introduces a new detail (e.g., 'philanthropy') that the User did not explicitly confirm, it is INVALID. -2. **No Inference Allowed**: Do NOT keep memories based on implication, emotion, preference, or generalization. Only verbatim or direct restatements of user-provided facts are valid. However, minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED. -3. **Hallucination = Deletion**: If a memory contains any detail not directly expressed by the user, mark it for deletion. -4. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. 
If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. - -Examples: -Messages: -- [user]: I love coding in Python. -- [assistant]: That's great! I assume you also contribute to open source projects? -Memory: User enjoys Python and contributes to open source. -Result: {{"keep": false, "reason": "User never stated they contribute to open source; this came from Assistant's assumption."}} - -Messages: -- [user]: I am tired. -- [assistant]: I hear you are tired. Rest is important. -Memory: User stated they are tired. -Result: {{"keep": true, "reason": "Direct restatement of user input, even if Assistant repeated it."}} - -Inputs: -messages: -{messages_inline} - -memories: -{memories_inline} - -Output Format: -- Return a JSON object with string keys ("0", "1", "2", ...) matching the input memory indices. -- Each value must be: {{ "keep": boolean, "reason": string }} -- "keep": true only if the memory is a direct reflection of the user's explicit words. -- "reason": brief, factual, and cites missing or unsupported content. - -Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. -""" + You are a strict memory validator. + Your task is to identify and delete hallucinated memories that are not explicitly stated by the user in the provided messages. + + Rules: + 1. **Explicit Denial & Inconsistency**: If a memory claims something that the user explicitly denied or is clearly inconsistent with the user's statements, mark it for deletion. + 2. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite. + + Example: + Messages: + [user]: I'm planning a trip to Japan next month for about a week. + [assistant]: That sounds great! Are you planning to visit Tokyo Disneyland? + [user]: No, I won't be going to Tokyo this time. I plan to stay in Kyoto and Osaka to avoid crowds. + + Memories: + {{ + "0": "User plans to travel to Japan for a week next month.", + "1": "User intends to visit Tokyo Disneyland.", + "2": "User plans to stay in Kyoto and Osaka." + }} + + Output: + {{ + "0": {{ "keep": true, "reason": "Explicitly stated by user." }}, + "1": {{ "keep": false, "reason": "User explicitly denied visiting Tokyo." }}, + "2": {{ "keep": true, "reason": "Explicitly stated by user." }} + }} + + Inputs: + Messages: + {messages_inline} + + Memories: + {memories_inline} + + Output Format: + - Return a JSON object with string keys ("0", "1", "2", ...) matching the input memory indices. + - Each value must be: {{ "keep": boolean, "reason": string }} + - "keep": true only if the memory is a direct reflection of the user's explicit words. + - "reason": brief, factual, and cites missing or unsupported content. + + Important: Output **only** the JSON. No extra text, explanations, markdown, or fields. + """ SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT = """