diff --git a/.gitignore b/.gitignore
index 8319a4d2f..ac31eb41a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -204,6 +204,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
+.trae
# VSCode
.vscode*
diff --git a/README.md b/README.md
index a2f713b51..f19b97cc1 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,6 @@
MemOS is an open-source **Agent Memory framework** that empowers AI agents with **long-term memory, personality consistency, and contextual recall**. It enables agents to **remember past interactions**, **learn over time**, and **build evolving identities** across sessions.
Designed for **AI companions, role-playing NPCs, and multi-agent systems**, MemOS provides a unified API for **memory representation, retrieval, and update** — making it the foundation for next-generation **memory-augmented AI agents**.
-
-🆕 **MemOS 2.0** introduces **knowledge base system**, **multi-modal memory** (images & documents), **tool memory** for Agent optimization, **memory feedback mechanism** for precise control, and **enterprise-grade architecture** with Redis Streams scheduler and advanced DB optimizations.
@@ -117,17 +115,6 @@ showcasing its capabilities in **information extraction**, **temporal and cross-
- **Textual Memory**: For storing and retrieving unstructured or structured text knowledge.
- **Activation Memory**: Caches key-value pairs (`KVCacheMemory`) to accelerate LLM inference and context reuse.
- **Parametric Memory**: Stores model adaptation parameters (e.g., LoRA weights).
- - **Tool Memory** 🆕: Records Agent tool call trajectories and experiences to improve planning capabilities.
-- **📚 Knowledge Base System** 🆕: Build multi-dimensional knowledge bases with automatic document/URL parsing, splitting, and cross-project sharing capabilities.
-- **🔧 Memory Controllability** 🆕:
- - **Feedback Mechanism**: Use `add_feedback` API to correct, supplement, or replace existing memories with natural language.
- - **Precise Deletion**: Delete specific memories by User ID or Memory ID via API or MCP tools.
-- **👁️ Multi-Modal Support** 🆕: Support for image understanding and memory, including chart parsing in documents.
-- **⚡ Advanced Architecture**:
- - **DB Optimization**: Enhanced connection management and batch insertion for high-concurrency scenarios.
- - **Advanced Retrieval**: Custom tag and info field filtering with complex logical operations.
- - **Redis Streams Scheduler**: Multi-level queue architecture with intelligent orchestration for fair multi-tenant scheduling.
- - **Stream & Non-Stream Chat**: Ready-to-use streaming and non-streaming chat interfaces.
- **🔌 Extensible**: Easily extend and customize memory modules, data sources, and LLM integrations.
- **🏂 Lightweight Deployment** 🆕: Support for quick mode and complete mode deployment options.
@@ -181,6 +168,7 @@ res = client.search_memory(query=query, user_id=user_id, conversation_id=convers
print(f"result: {res}")
```
+
### Self-Hosted Server
1. Get the repository.
```bash
@@ -215,7 +203,7 @@ Example
```python
import requests
import json
-
+
data = {
"user_id": "8736b16e-1d20-4163-980b-a5063c3facdc",
"mem_cube_id": "b32d0977-435d-4828-a86f-4f47f8b55bca",
@@ -231,7 +219,7 @@ Example
"Content-Type": "application/json"
}
url = "http://localhost:8000/product/add"
-
+
res = requests.post(url=url, headers=headers, data=json.dumps(data))
print(f"result: {res.json()}")
```
@@ -239,7 +227,7 @@ Example
```python
import requests
import json
-
+
data = {
"query": "What do I like",
"user_id": "8736b16e-1d20-4163-980b-a5063c3facdc",
@@ -249,7 +237,7 @@ Example
"Content-Type": "application/json"
}
url = "http://localhost:8000/product/search"
-
+
res = requests.post(url=url, headers=headers, data=json.dumps(data))
print(f"result: {res.json()}")
```
diff --git a/docker/.env.example b/docker/.env.example
index f1979fe4c..3674cd69b 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -76,6 +76,7 @@ MODEL=gpt-4o-mini
# embedding model for evaluation
EMBEDDING_MODEL=nomic-embed-text:latest
+
## Internet search & preference memory
# Enable web search
ENABLE_INTERNET=false
@@ -122,6 +123,7 @@ API_SCHEDULER_ON=true
API_SEARCH_WINDOW_SIZE=5
# Specify how many rounds of previous conversations (history) to retrieve and consider during the 'hybrid search' (fast search+asynchronous fine search). This helps provide context aware search results
API_SEARCH_HISTORY_TURNS=5
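+# Whether the memory scheduler uses a Redis-backed message queue (assumed: false keeps the default in-process queue)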
+MEMSCHEDULER_USE_REDIS_QUEUE=false
## Graph / vector stores
# Neo4j database selection mode
@@ -211,4 +213,4 @@ MEMSCHEDULER_RABBITMQ_VIRTUAL_HOST=memos
# Erase connection state on connect for message-log pipeline
MEMSCHEDULER_RABBITMQ_ERASE_ON_CONNECT=true
# RabbitMQ port for message-log pipeline
-MEMSCHEDULER_RABBITMQ_PORT=5672
\ No newline at end of file
+MEMSCHEDULER_RABBITMQ_PORT=5672
diff --git a/docker/requirements-full.txt b/docker/requirements-full.txt
index e8911cbb9..be9ed2068 100644
--- a/docker/requirements-full.txt
+++ b/docker/requirements-full.txt
@@ -183,4 +183,4 @@ psycopg2-binary==2.9.11
py-key-value-aio==0.2.8
py-key-value-shared==0.2.8
PyJWT==2.10.1
-pytest==9.0.2
\ No newline at end of file
+pytest==9.0.2
diff --git a/docker/requirements.txt b/docker/requirements.txt
index 939b13678..f89617c10 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -1,6 +1,3 @@
-# Docker optimized requirements - Core dependencies only
-# Excludes Windows-specific and heavy GPU packages for faster builds
-
annotated-types==0.7.0
anyio==4.11.0
attrs==25.4.0
@@ -125,4 +122,4 @@ urllib3==2.5.0
uvicorn==0.38.0
uvloop==0.22.1; sys_platform != 'win32'
watchfiles==1.1.1
-websockets==15.0.1
\ No newline at end of file
+websockets==15.0.1
diff --git a/docs/README.md b/docs/README.md
index bf5fea70d..8be17ffb7 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -1,3 +1,3 @@
All documentation has been moved to a separate repository: https://github.com/MemTensor/MemOS-Docs. Please edit documentation there.
-所有文档已迁移至独立仓库:https://github.com/MemTensor/MemOS-Docs。请在该仓库中编辑文档。
+所有文档已迁移至独立仓库 https://github.com/MemTensor/MemOS-Docs 。请在该仓库中编辑文档。
diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh
index 6f5f3b7b0..f65873cb9 100755
--- a/evaluation/scripts/run_prefeval_eval.sh
+++ b/evaluation/scripts/run_prefeval_eval.sh
@@ -108,7 +108,9 @@ python $LIB_SCRIPT search \
--input $IDS_FILE \
--output $SEARCH_FILE \
--top-k $TOP_K \
- --max-workers $WORKERS
+ --max-workers $WORKERS \
+ --lib $LIB \
+ --version $VERSION
if [ $? -ne 0 ]; then
echo "Error: $LIB_SCRIPT 'search' mode failed."
@@ -121,7 +123,9 @@ echo "Running $LIB_SCRIPT in 'response' mode..."
python $LIB_SCRIPT response \
--input $SEARCH_FILE \
--output $RESPONSE_FILE \
- --max-workers $WORKERS
+ --max-workers $WORKERS \
+ --lib $LIB \
+ --version $VERSION
if [ $? -ne 0 ]; then
echo "Error: $LIB_SCRIPT 'response' mode failed."
diff --git a/examples/data/config/mem_scheduler/mem_cube_config.yaml b/examples/data/config/mem_scheduler/mem_cube_config.yaml
new file mode 100644
index 000000000..398d8dbb3
--- /dev/null
+++ b/examples/data/config/mem_scheduler/mem_cube_config.yaml
@@ -0,0 +1,21 @@
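+# MemCube configuration used by the scheduler quick-start example
+# (examples/mem_scheduler/quick_start_examples.py)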
+user_id: "user_test"
+cube_id: "user_test/mem_cube_naive"
+text_mem:
+ backend: "naive_text"
+ config:
+ extractor_llm:
+ backend: "huggingface_singleton"
+ config:
+ model_name_or_path: "Qwen/Qwen3-0.6B"
+ temperature: 0.1
+ max_tokens: 1024
+act_mem:
+ backend: "kv_cache"
+ config:
+ memory_filename: "activation_memory.pickle"
+ extractor_llm:
+ backend: "huggingface_singleton"
+ config:
+ model_name_or_path: "Qwen/Qwen3-0.6B"
+ temperature: 0.8
+ max_tokens: 1024
diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
index bd9910300..a5e91dc4e 100644
--- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
+++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml
@@ -10,16 +10,12 @@ mem_reader:
backend: "simple_struct"
config:
llm:
- backend: "openai"
+ backend: "huggingface_singleton"
config:
- model_name_or_path: "gpt-4o-mini"
- temperature: 0.8
- max_tokens: 4096
- top_p: 0.9
- top_k: 50
+ model_name_or_path: "Qwen/Qwen3-1.7B"
+ temperature: 0.1
remove_think_prefix: true
- api_key: "sk-xxxxxx"
- api_base: "https://api.openai.com/v1"
+ max_tokens: 4096
embedder:
backend: "ollama"
config:
diff --git a/examples/mem_scheduler/quick_start_examples.py b/examples/mem_scheduler/quick_start_examples.py
new file mode 100644
index 000000000..724663be6
--- /dev/null
+++ b/examples/mem_scheduler/quick_start_examples.py
@@ -0,0 +1,312 @@
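+"""Quick-start examples for the MemOS memory scheduler.
+
+`kv_cache_only` exercises KVCacheMemory on its own; `run_scheduler_example`
+wires custom query/answer/mem-update handlers into a MOS + MemScheduler loop.
+"""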
+import json
+import shutil
+import sys
+import uuid
+
+from pathlib import Path
+
+from transformers import DynamicCache
+
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.configs.memory import MemoryConfigFactory
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.main import MOS
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.schemas.task_schemas import (
+ ANSWER_TASK_LABEL,
+ MEM_UPDATE_TASK_LABEL,
+ QUERY_TASK_LABEL,
+)
+from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.mem_scheduler.utils.misc_utils import parse_yaml
+from memos.memories.activation.item import KVCacheItem
+from memos.memories.factory import MemoryFactory
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+
+def get_cache_info(cache):
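+    # Report a cache's layer count and total KV tensor size; handles both the
+    # layer-object layout and the plain key_cache/value_cache list layout.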
+ if not cache:
+ return None
+
+ num_layers = 0
+ total_size_bytes = 0
+
+ if hasattr(cache, "layers"):
+ num_layers = len(cache.layers)
+ for layer in cache.layers:
+ if hasattr(layer, "key_cache") and layer.key_cache is not None:
+ total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size()
+ if hasattr(layer, "value_cache") and layer.value_cache is not None:
+ total_size_bytes += layer.value_cache.nelement() * layer.value_cache.element_size()
+
+ if hasattr(layer, "keys") and layer.keys is not None:
+ total_size_bytes += layer.keys.nelement() * layer.keys.element_size()
+ if hasattr(layer, "values") and layer.values is not None:
+ total_size_bytes += layer.values.nelement() * layer.values.element_size()
+
+ elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
+ num_layers = len(cache.key_cache)
+ for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
+ if k is not None:
+ total_size_bytes += k.nelement() * k.element_size()
+ if v is not None:
+ total_size_bytes += v.nelement() * v.element_size()
+
+ return {
+ "num_layers": num_layers,
+ "size_bytes": total_size_bytes,
+ "size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
+ }
+
+
+def serialize_item(obj):
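+    # Make KVCacheItem / DynamicCache objects JSON-serializable for printing.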
+ if isinstance(obj, list):
+ return [serialize_item(x) for x in obj]
+
+ if isinstance(obj, KVCacheItem):
+ return {
+ "id": obj.id,
+ "metadata": obj.metadata,
+ "records": obj.records.model_dump()
+ if hasattr(obj.records, "model_dump")
+ else obj.records,
+ "memory": get_cache_info(obj.memory),
+ }
+
+ if isinstance(obj, DynamicCache):
+ return get_cache_info(obj)
+
+ return str(obj)
+
+
+def kv_cache_only():
+    # Create a configuration for KVCacheMemory (HuggingFace backend)
+ config = MemoryConfigFactory(
+ backend="kv_cache",
+ config={
+ "extractor_llm": {
+ "backend": "huggingface",
+ "config": {
+ "model_name_or_path": "Qwen/Qwen3-0.6B",
+ "max_tokens": 32,
+ "add_generation_prompt": True,
+ "remove_think_prefix": True,
+ },
+ },
+ },
+ )
+
+    # Instantiate KVCacheMemory
+ kv_mem = MemoryFactory.from_config(config)
+
+    # Extract a KVCacheItem (DynamicCache)
+ prompt = [
+ {"role": "user", "content": "What is MemOS?"},
+ {"role": "assistant", "content": "MemOS is a memory operating system for LLMs."},
+ ]
+ print("===== Extract KVCacheItem =====")
+ cache_item = kv_mem.extract(prompt)
+ print(json.dumps(serialize_item(cache_item), indent=2, default=str))
+
+    # Add the cache to memory
+ kv_mem.add([cache_item])
+ print("All caches:")
+ print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+    # Get by ID
+ retrieved = kv_mem.get(cache_item.id)
+ print("Retrieved:")
+ print(json.dumps(serialize_item(retrieved), indent=2, default=str))
+
+    # Merge caches
+ item2 = kv_mem.extract([{"role": "user", "content": "Tell me a joke."}])
+ kv_mem.add([item2])
+ merged = kv_mem.get_cache([cache_item.id, item2.id])
+ print("Merged cache:")
+ print(json.dumps(serialize_item(merged), indent=2, default=str))
+
+    # Delete one of them
+ kv_mem.delete([cache_item.id])
+ print("After delete:")
+ print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+    # Dump and reload caches
+ kv_mem.dump("tmp/kv_mem")
+ print("Dumped to tmp/kv_mem")
+ kv_mem.delete_all()
+ kv_mem.load("tmp/kv_mem")
+ print("Loaded caches:")
+ print(json.dumps(serialize_item(kv_mem.get_all()), indent=2, default=str))
+
+
+def run_scheduler_example():
+    # Load the main MOS (Memory-Oriented System) config that enables MemScheduler
+    config = parse_yaml(
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+    )
+    # Pass the parsed config dict to the MOSConfig constructor to build the config object
+    mos_config = MOSConfig(**config)
+    # Initialize the MOS system instance from the config object
+    mos = MOS(mos_config)
+
+    # Generate a unique dynamic user ID (UUID4)
+    user_id = str(uuid.uuid4())
+    # Create an account for this user in the MOS system
+    mos.create_user(user_id=user_id)
+
+    # Load the general MemCube configuration from a YAML file
+    config = GeneralMemCubeConfig.from_yaml_file(
+        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
+    )
+    # Define a unique identifier for the MemCube
+    mem_cube_id = "mem_cube_5"
+    # Define the local storage path for the MemCube (the path embeds the user ID and MemCube ID)
+    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
+
+    # If the path already exists, remove the old directory first
+    if Path(mem_cube_name_or_path).exists():
+        shutil.rmtree(mem_cube_name_or_path)
+        print(f"{mem_cube_name_or_path} was not empty and has been removed.")
+
+    # Create a new MemCube instance from the loaded configuration
+    mem_cube = GeneralMemCube(config)
+    # Serialize the MemCube instance and save it to the given path
+    mem_cube.dump(mem_cube_name_or_path)
+
+    # Register this MemCube for the current user in the MOS system
+ mos.register_mem_cube(
+ mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
+ )
+
+    # Define a helper that reports memory usage info for a cache (e.g. a KV cache)
+    def get_cache_info(cache):
+        # Return None directly for an empty cache
+        if not cache:
+            return None
+
+        num_layers = 0  # number of cache layers
+        total_size_bytes = 0  # total size in bytes
+
+        # Case 1: the cache exposes a `layers` attribute (e.g. HuggingFace cache layouts)
+        if hasattr(cache, "layers"):
+            num_layers = len(cache.layers)
+            for layer in cache.layers:
+                # Count key_cache memory usage (if present)
+                if hasattr(layer, "key_cache") and layer.key_cache is not None:
+                    total_size_bytes += layer.key_cache.nelement() * layer.key_cache.element_size()
+                # Count value_cache memory usage (if present)
+                if hasattr(layer, "value_cache") and layer.value_cache is not None:
+                    total_size_bytes += (
+                        layer.value_cache.nelement() * layer.value_cache.element_size()
+                    )
+
+                # Also handle alternative attribute names (e.g. keys/values)
+                if hasattr(layer, "keys") and layer.keys is not None:
+                    total_size_bytes += layer.keys.nelement() * layer.keys.element_size()
+                if hasattr(layer, "values") and layer.values is not None:
+                    total_size_bytes += layer.values.nelement() * layer.values.element_size()
+
+        # Case 2: the cache holds key_cache and value_cache lists directly (some custom layouts)
+        elif hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
+            num_layers = len(cache.key_cache)
+            for k, v in zip(cache.key_cache, cache.value_cache, strict=False):
+                if k is not None:
+                    total_size_bytes += k.nelement() * k.element_size()
+                if v is not None:
+                    total_size_bytes += v.nelement() * v.element_size()
+
+        # Return structured cache info: layer count, byte count, and a human-readable MB value
+ return {
+ "num_layers": num_layers,
+ "size_bytes": total_size_bytes,
+ "size_mb": f"{total_size_bytes / (1024 * 1024):.2f} MB",
+ }
+
+    # Define a custom query handler
+    def custom_query_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            # Print the user's query
+            print(f"\n[scheduler] User query received: {msg.content}")
+            # Manually build a new message labeled MEM_UPDATE to trigger a memory update
+            new_msg = msg.model_copy(update={"label": MEM_UPDATE_TASK_LABEL})
+            # Submit the message to the scheduler for processing
+            mos.mem_scheduler.submit_messages([new_msg])
+
+    # Define a custom answer handler
+    def custom_answer_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            # Print the LLM's reply
+            print(f"\n[scheduler] LLM answered: {msg.content}")
+
+    # Define a custom memory-update (mem_update) handler
+    def custom_mem_update_handler(messages: list[ScheduleMessageItem]):
+        for msg in messages:
+            mem_cube = mos.mem_cubes.get(msg.mem_cube_id)
+            # If this MemCube has textual memory configured (TreeTextMemory / NaiveTextMemory)
+            if mem_cube and mem_cube.text_mem:
+                kv_mem = mem_cube.act_mem
+                # Search textual memory for entries related to the current content (top_k=3)
+                results = mem_cube.text_mem.search(msg.content, top_k=3)
+                for mem in results:
+                    print(f"\n[scheduler] Retrieved memory: {mem.memory}")
+                print("\n[scheduler] Converting to activation memory...")
+                # Extract the corresponding KV cache item from the textual memory
+                cache_item = kv_mem.extract(mem.memory)
+                # Attach metadata
+                cache_item.records.text_memories = [mem.memory]
+                cache_item.records.timestamp = get_utc_now()
+                # Add the cache item to activation memory
+                kv_mem.add([cache_item])
+                print("\n[scheduler] Done!")
+
+    # Register the three custom handlers with the scheduler's dispatcher, one per task label
+    mos.mem_scheduler.dispatcher.register_handlers(
+        {
+            QUERY_TASK_LABEL: custom_query_handler,  # query tasks
+            ANSWER_TASK_LABEL: custom_answer_handler,  # answer tasks
+            MEM_UPDATE_TASK_LABEL: custom_mem_update_handler,  # memory-update tasks
+ }
+ )
+
+    # Seed the system with two test messages (a user/assistant exchange)
+ messages = [
+ {"role": "user", "content": "I like playing football."},
+ {"role": "assistant", "content": "I like playing football too."},
+ ]
+ mos.add(messages, user_id=user_id, mem_cube_id=mem_cube_id)
+
+    # Enter the chat loop: show TreeTextMemory's memory nodes plus the KV cache state
+    while True:
+        # Read user input and strip surrounding whitespace
+        user_input = input("👤 [You] ").strip()
+        print()
+        # Ask the MOS system for a chat response
+        response = mos.chat(user_input, user_id=user_id)
+        # Fetch all memories in this user's current MemCube
+        retrieved_memories = mos.get_all(mem_cube_id=mem_cube_id, user_id=user_id)
+
+        # Print the assistant's reply
+        print(f"🤖 [Assistant] {response}")
+
+        # Textual memory part - TreeTextMemory
+        memories = retrieved_memories["text_mem"][0]["memories"]
+        for mem in memories:
+            print(f"[Text memory] {mem.memory}")
+
+        # Get the corresponding MemCube and its activation memory (KV cache)
+        mem_cube = mos.mem_scheduler.mem_cube
+        kv_mem = mem_cube.act_mem
+        # Iterate over all activation memory items, printing cache info and records
+        for cache_item in kv_mem.get_all():
+            print(f"[Activation memory] {get_cache_info(cache_item.memory)} (records: {cache_item.records})")
+
+
+if __name__ == "__main__":
+ kv_cache_only()
+
+ run_scheduler_example()
diff --git a/examples/mem_scheduler/task_stop_rerun.py b/examples/mem_scheduler/scheduler_for_async_tasks.py
similarity index 98%
rename from examples/mem_scheduler/task_stop_rerun.py
rename to examples/mem_scheduler/scheduler_for_async_tasks.py
index b5e62ff8f..7f544c3da 100644
--- a/examples/mem_scheduler/task_stop_rerun.py
+++ b/examples/mem_scheduler/scheduler_for_async_tasks.py
@@ -25,7 +25,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
task_id = str(msg.item_id)
file_path = tmp_dir / f"{task_id}.txt"
try:
- sleep(1)
+ sleep(5)
file_path.write_text(f"Task {task_id} processed.\n")
print(f"writing {file_path} done")
except Exception as e:
@@ -57,8 +57,8 @@ def submit_tasks():
TEST_HANDLER_LABEL = "test_handler"
mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})
-# 10s to restart
-mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 10_000
+# 5s to restart
+mem_scheduler.orchestrator.tasks_min_idle_ms[TEST_HANDLER_LABEL] = 5_000
tmp_dir = Path("./tmp")
tmp_dir.mkdir(exist_ok=True)
@@ -88,6 +88,6 @@ def submit_tasks():
print(f"[Result] Final files in tmp: {len(list(tmp_dir.glob('*.txt')))})")
# 7. Stop the scheduler
+sleep(20)
print("Stopping the scheduler...")
-sleep(5)
mem_scheduler.stop()
diff --git a/poetry.lock b/poetry.lock
index 187b6c4aa..fb818e665 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -1098,7 +1098,7 @@ description = "Lightweight in-process concurrent programming"
optional = false
python-versions = ">=3.9"
groups = ["main", "eval"]
-markers = "(platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\") and python_version < \"3.14\""
+markers = "python_version < \"3.14\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"
files = [
{file = "greenlet-3.2.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:1afd685acd5597349ee6d7a88a8bec83ce13c106ac78c196ee9dde7c04fe87be"},
{file = "greenlet-3.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:761917cac215c61e9dc7324b2606107b3b292a8349bdebb31503ab4de3f559ac"},
@@ -2641,7 +2641,7 @@ files = [
{file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668"},
{file = "nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-cuda-cupti-cu12"
@@ -2657,7 +2657,7 @@ files = [
{file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73"},
{file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-cuda-nvrtc-cu12"
@@ -2671,7 +2671,7 @@ files = [
{file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53"},
{file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-cuda-runtime-cu12"
@@ -2687,7 +2687,7 @@ files = [
{file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8"},
{file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-cudnn-cu12"
@@ -2701,7 +2701,7 @@ files = [
{file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2"},
{file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[package.dependencies]
nvidia-cublas-cu12 = "*"
@@ -2720,7 +2720,7 @@ files = [
{file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca"},
{file = "nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[package.dependencies]
nvidia-nvjitlink-cu12 = "*"
@@ -2736,7 +2736,7 @@ files = [
{file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159"},
{file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-curand-cu12"
@@ -2752,7 +2752,7 @@ files = [
{file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e"},
{file = "nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-cusolver-cu12"
@@ -2768,7 +2768,7 @@ files = [
{file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e"},
{file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[package.dependencies]
nvidia-cublas-cu12 = "*"
@@ -2789,7 +2789,7 @@ files = [
{file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f"},
{file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[package.dependencies]
nvidia-nvjitlink-cu12 = "*"
@@ -2806,7 +2806,7 @@ files = [
{file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46"},
{file = "nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-nccl-cu12"
@@ -2819,7 +2819,7 @@ files = [
{file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522"},
{file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-nvjitlink-cu12"
@@ -2833,7 +2833,7 @@ files = [
{file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"},
{file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "nvidia-nvtx-cu12"
@@ -2849,7 +2849,7 @@ files = [
{file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1"},
{file = "nvidia_nvtx_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:2fb11a4af04a5e6c84073e6404d26588a34afd35379f0855a99797897efa75c0"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[[package]]
name = "ollama"
@@ -3163,8 +3163,8 @@ markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pre
[package.dependencies]
numpy = [
- {version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
+ {version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
python-dateutil = ">=2.8.2"
@@ -4003,14 +4003,14 @@ files = [
[[package]]
name = "qdrant-client"
-version = "1.14.3"
+version = "1.16.2"
description = "Client library for the Qdrant vector search engine"
optional = false
-python-versions = ">=3.9"
+python-versions = ">=3.10"
groups = ["main", "eval"]
files = [
- {file = "qdrant_client-1.14.3-py3-none-any.whl", hash = "sha256:66faaeae00f9b5326946851fe4ca4ddb1ad226490712e2f05142266f68dfc04d"},
- {file = "qdrant_client-1.14.3.tar.gz", hash = "sha256:bb899e3e065b79c04f5e47053d59176150c0a5dabc09d7f476c8ce8e52f4d281"},
+ {file = "qdrant_client-1.16.2-py3-none-any.whl", hash = "sha256:442c7ef32ae0f005e88b5d3c0783c63d4912b97ae756eb5e052523be682f17d3"},
+ {file = "qdrant_client-1.16.2.tar.gz", hash = "sha256:ca4ef5f9be7b5eadeec89a085d96d5c723585a391eb8b2be8192919ab63185f0"},
]
markers = {main = "extra == \"all\""}
@@ -4018,11 +4018,13 @@ markers = {main = "extra == \"all\""}
grpcio = ">=1.41.0"
httpx = {version = ">=0.20.0", extras = ["http2"]}
numpy = [
- {version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""},
+ {version = ">=1.21", markers = "python_version == \"3.11\""},
+ {version = ">=1.21,<2.3.0", markers = "python_version == \"3.10\""},
{version = ">=1.26", markers = "python_version == \"3.12\""},
- {version = ">=2.1.0", markers = "python_version >= \"3.13\""},
+ {version = ">=2.1.0", markers = "python_version == \"3.13\""},
+ {version = ">=2.3.0", markers = "python_version >= \"3.14\""},
]
-portalocker = ">=2.7.0,<3.0.0"
+portalocker = ">=2.7.0,<4.0"
protobuf = ">=3.20.0"
pydantic = ">=1.10.8,<2.0.dev0 || >2.2.0"
urllib3 = ">=1.26.14,<3"
@@ -5441,7 +5443,7 @@ files = [
{file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"},
{file = "triton-3.3.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6139aeb04a146b0b8e0fbbd89ad1e65861c57cfed881f21d62d3cb94a36bab7"},
]
-markers = {main = "platform_machine == \"x86_64\" and extra == \"all\" and platform_system == \"Linux\"", eval = "platform_machine == \"x86_64\" and platform_system == \"Linux\""}
+markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
[package.dependencies]
setuptools = ">=40.8.0"
@@ -5650,7 +5652,7 @@ description = "Fast implementation of asyncio event loop on top of libuv"
optional = false
python-versions = ">=3.8.0"
groups = ["main"]
-markers = "platform_python_implementation != \"PyPy\" and sys_platform != \"win32\" and sys_platform != \"cygwin\""
+markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\""
files = [
{file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"},
{file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d"},
@@ -6242,4 +6244,4 @@ tree-mem = ["neo4j", "schedule"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
-content-hash = "dab8e54c6f4c51597adbd0fa34be7a8adb3b3a9c733508f3cc2b93c0ed434ec1"
+content-hash = "22bfcac5ed0be1e3aea294e3da96ff1a4bd9d7b62865ad827e1508f5ade6b708"
diff --git a/pyproject.toml b/pyproject.toml
index 3c2eecf18..f869f7642 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,7 +119,7 @@ all = [
# We kindof don't want users to install them.
"torch (>=2.7.1,<3.0.0)",
"sentence-transformers (>=4.1.0,<5.0.0)",
- "qdrant-client (>=1.14.2,<2.0.0)",
+ "qdrant-client (>=1.16.0,<2.0.0)",
"volcengine-python-sdk (>=4.0.4,<5.0.0)",
"nltk (>=3.9.1,<4.0.0)",
"rake-nltk (>=1.0.6,<1.1.0)",
diff --git a/src/memos/api/client.py b/src/memos/api/client.py
index 1129ddddf..91bc86829 100644
--- a/src/memos/api/client.py
+++ b/src/memos/api/client.py
@@ -177,7 +177,9 @@ def search_memory(
if retry == MAX_RETRY_COUNT - 1:
raise
- def get_memory(self, user_id: str, include_preference: str) -> MemOSGetMemoryResponse | None:
+ def get_memory(
+ self, user_id: str, include_preference: bool = True, page: int = 1, size: int = 10
+ ) -> MemOSGetMemoryResponse | None:
"""get memories"""
# Validate required parameters
self._validate_required_params(include_preference=include_preference, user_id=user_id)
@@ -186,6 +188,8 @@ def get_memory(self, user_id: str, include_preference: str) -> MemOSGetMemoryRes
payload = {
"include_preference": include_preference,
"user_id": user_id,
+ "page": page,
+ "size": size,
}
for retry in range(MAX_RETRY_COUNT):
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 5fef51ca0..daf9b6cfe 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -7,16 +7,19 @@
import re
import time
-from typing import Any
+from typing import TYPE_CHECKING, Any
import requests
from dotenv import load_dotenv
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
from memos.context.context import ContextThread
-from memos.mem_cube.general import GeneralMemCube
+
+
+if TYPE_CHECKING:
+ from memos.configs.mem_cube import GeneralMemCubeConfig
+ from memos.configs.mem_os import MOSConfig
+ from memos.mem_cube.general import GeneralMemCube
# Load environment variables
@@ -805,8 +808,12 @@ def get_start_default_config() -> dict[str, Any]:
return config
@staticmethod
- def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]:
+ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "GeneralMemCube"]:
"""Create configuration for a specific user."""
+ from memos.configs.mem_cube import GeneralMemCubeConfig
+ from memos.configs.mem_os import MOSConfig
+ from memos.mem_cube.general import GeneralMemCube
+
openai_config = APIConfig.get_openai_config()
qwen_config = APIConfig.qwen_config()
vllm_config = APIConfig.vllm_config()
@@ -933,12 +940,14 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
return default_config, default_mem_cube
@staticmethod
- def get_default_cube_config() -> GeneralMemCubeConfig | None:
+ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
"""Get default cube configuration for product initialization.
Returns:
GeneralMemCubeConfig | None: Default cube configuration if enabled, None otherwise.
"""
+ from memos.configs.mem_cube import GeneralMemCubeConfig
+
if not APIConfig.is_default_cube_config_enabled():
return None
diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py
index bcc3669b6..812cf2793 100644
--- a/src/memos/api/handlers/chat_handler.py
+++ b/src/memos/api/handlers/chat_handler.py
@@ -99,15 +99,13 @@ def __init__(
def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]:
"""
- Chat with MemOS for complete response (non-streaming).
-
- This implementation directly uses search/add handlers instead of mos_server.
+        Chat with MemOS and return a complete (non-streaming) response.
Args:
chat_req: Chat complete request
Returns:
- Dictionary with response and references
+        Dictionary with the complete chat response and reasoning
Raises:
HTTPException: If chat fails
@@ -140,6 +138,13 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
if text_mem_results and text_mem_results[0].get("memories"):
memories_list = text_mem_results[0]["memories"]
+            # Forcibly drop internet-sourced memories (OuterMemory)
+ memories_list = [
+ mem
+ for mem in memories_list
+ if mem.get("metadata", {}).get("memory_type") != "OuterMemory"
+ ]
+
# Filter memories by threshold
filtered_memories = self._filter_memories_by_threshold(
memories_list, chat_req.threshold or 0.5
@@ -161,7 +166,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
{"role": "user", "content": chat_req.query},
]
- self.logger.info("Starting to generate complete response...")
+ self.logger.info("[Cloud Service] Starting to generate chat complete response...")
# Step 3: Generate complete response from LLM
if chat_req.model_name_or_path and chat_req.model_name_or_path not in self.chat_llms:
@@ -172,11 +177,23 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
- self.logger.info(f"[Cloud Service Chat Complete Model]: {model}")
+ self.logger.info(f"[Cloud Service] Chat Complete Model: {model}")
start = time.time()
response = self.chat_llms[model].generate(current_messages, model_name_or_path=model)
end = time.time()
- self.logger.info(f"[Cloud Service Chat Complete Time]: {end - strat} seconds")
+ self.logger.info(f"[Cloud Service] Chat Complete Time: {end - strat} seconds")
+
+ if not response:
+ self.logger.error(
+ f"[Cloud Service] Chat Complete Failed, LLM response is {response}"
+ )
+ raise HTTPException(
+ status_code=500, detail="Chat complete failed, LLM response is None"
+ )
+
+ self.logger.info(
+ f"[Cloud Service] Chat Complete LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Complete LLM Response: {response}"
+ )
# Step 4: start add after chat asynchronously
if chat_req.add_message_on_answer:
@@ -192,7 +209,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
async_mode="async",
)
end = time.time()
- self.logger.info(f"[Cloud Service Chat Add Time]: {end - start} seconds")
+ self.logger.info(f"[Cloud Service] Chat Add Time: {end - start} seconds")
match = re.search(r"<think>([\s\S]*?)</think>", response)
reasoning_text = match.group(1) if match else None
@@ -208,14 +225,12 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
except ValueError as err:
raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
except Exception as err:
- self.logger.error(f"Failed to complete chat: {traceback.format_exc()}")
+ self.logger.error(f"[Cloud Service] Failed to chat complete: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse:
"""
- Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers.
-
- This implementation directly uses search_handler and add_handler.
+        Chat with MemOS via a Server-Sent Events (SSE) stream.
Args:
chat_req: Chat stream request
@@ -229,7 +244,7 @@ def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse:
try:
def generate_chat_response() -> Generator[str, None, None]:
- """Generate chat response as SSE stream."""
+ """Generate chat stream response as SSE stream."""
try:
# Resolve readable cube IDs (for search)
readable_cube_ids = chat_req.readable_cube_ids or (
@@ -269,6 +284,13 @@ def generate_chat_response() -> Generator[str, None, None]:
if text_mem_results and text_mem_results[0].get("memories"):
memories_list = text_mem_results[0]["memories"]
+            # Forcibly drop internet-sourced memories (OuterMemory)
+ memories_list = [
+ mem
+ for mem in memories_list
+ if mem.get("metadata", {}).get("memory_type") != "OuterMemory"
+ ]
+
# Filter memories by threshold
filtered_memories = self._filter_memories_by_threshold(memories_list)
@@ -289,7 +311,7 @@ def generate_chat_response() -> Generator[str, None, None]:
]
self.logger.info(
- f"user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, "
+ f"[Cloud Service] chat stream user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, "
f"current_system_prompt: {system_prompt}"
)
@@ -304,14 +326,12 @@ def generate_chat_response() -> Generator[str, None, None]:
)
model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys()))
- self.logger.info(f"[Cloud Service Chat Stream Model]: {model}")
+ self.logger.info(f"[Cloud Service] Chat Stream Model: {model}")
start = time.time()
response_stream = self.chat_llms[model].generate_stream(
current_messages, model_name_or_path=model
)
- end = time.time()
- self.logger.info(f"[Cloud Service Chat Stream Time]: {end - start} seconds")
# Stream the response
buffer = ""
@@ -337,6 +357,13 @@ def generate_chat_response() -> Generator[str, None, None]:
chunk_data = f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n"
yield chunk_data
+ end = time.time()
+ self.logger.info(f"[Cloud Service] Chat Stream Time: {end - start} seconds")
+
+ self.logger.info(
+ f"[Cloud Service] Chat Stream LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Stream LLM Response: {full_response}"
+ )
+
current_messages.append({"role": "assistant", "content": full_response})
if chat_req.add_message_on_answer:
# Resolve writable cube IDs (for add)
@@ -354,10 +381,10 @@ def generate_chat_response() -> Generator[str, None, None]:
)
end = time.time()
self.logger.info(
- f"[Cloud Service Chat Stream Add Time]: {end - start} seconds"
+ f"[Cloud Service] Chat Stream Add Time: {end - start} seconds"
)
except Exception as e:
- self.logger.error(f"Error in chat stream: {e}", exc_info=True)
+ self.logger.error(f"[Cloud Service] Error in chat stream: {e}", exc_info=True)
error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
yield error_data
@@ -377,14 +404,14 @@ def generate_chat_response() -> Generator[str, None, None]:
except ValueError as err:
raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
except Exception as err:
- self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
+ self.logger.error(
+ f"[Cloud Service] Failed to start chat stream: {traceback.format_exc()}"
+ )
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> StreamingResponse:
"""
- Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers.
-
- This implementation directly uses search_handler and add_handler.
+        Chat with MemOS via a Server-Sent Events (SSE) stream for the playground.
Args:
chat_req: Chat stream request
@@ -398,7 +425,7 @@ def handle_chat_stream_playground(self, chat_req: ChatPlaygroundRequest) -> Stre
try:
def generate_chat_response() -> Generator[str, None, None]:
- """Generate chat response as SSE stream."""
+ """Generate playground chat stream response as SSE stream."""
try:
import time
@@ -434,7 +461,9 @@ def generate_chat_response() -> Generator[str, None, None]:
start_time = time.time()
search_response = self.search_handler.handle_search_memories(search_req)
end_time = time.time()
- self.logger.info(f"first search time: {end_time - start_time}")
+ self.logger.info(
+ f"[PLAYGROUND CHAT] first search time: {end_time - start_time}"
+ )
yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
@@ -453,7 +482,7 @@ def generate_chat_response() -> Generator[str, None, None]:
# get preference string
pref_string = search_response.data.get("pref_string", "")
- yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
+ yield f"data: {json.dumps({'type': 'reference', 'data': reference}, ensure_ascii=False)}\n\n"
# Prepare preference markdown string
if chat_req.include_preference:
@@ -481,7 +510,7 @@ def generate_chat_response() -> Generator[str, None, None]:
conversation=chat_req.history,
mode="fine",
)
- self.logger.info(f"[PLAYGROUND chat parsed_goal]: {parsed_goal}")
+ self.logger.info(f"[PLAYGROUND CHAT] parsed_goal: {parsed_goal}")
if chat_req.beginner_guide_step == "first":
chat_req.internet_search = False
@@ -512,12 +541,14 @@ def generate_chat_response() -> Generator[str, None, None]:
search_tool_memory=False,
)
- self.logger.info(f"[PLAYGROUND second search query]: {search_req.query}")
+ self.logger.info(f"[PLAYGROUND CHAT] second search query: {search_req.query}")
start_time = time.time()
search_response = self.search_handler.handle_search_memories(search_req)
end_time = time.time()
- self.logger.info(f"second search time: {end_time - start_time}")
+ self.logger.info(
+ f"[PLAYGROUND CHAT] second search time: {end_time - start_time}"
+ )
# for playground, add the query to memory without response
self._start_add_to_memory(
@@ -555,7 +586,7 @@ def generate_chat_response() -> Generator[str, None, None]:
internet_reference = self._get_internet_reference(
search_response.data.get("text_mem")[0]["memories"]
)
- yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
+ yield f"data: {json.dumps({'type': 'reference', 'data': reference}, ensure_ascii=False)}\n\n"
# Step 2: Build system prompt with memories
lang = detect_lang(chat_req.query)
@@ -578,13 +609,15 @@ def generate_chat_response() -> Generator[str, None, None]:
]
self.logger.info(
- f"user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, "
+ f"[PLAYGROUND CHAT] user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, "
f"current_system_prompt: {system_prompt}"
)
# Step 3: Generate streaming response from LLM
try:
model = next(iter(self.chat_llms.keys()))
+ self.logger.info(f"[PLAYGROUND CHAT] Chat Playground Stream Model: {model}")
+ start = time.time()
response_stream = self.chat_llms[model].generate_stream(
current_messages, model_name_or_path=model
)
@@ -629,10 +662,19 @@ def generate_chat_response() -> Generator[str, None, None]:
chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
yield chunk_data
+ end = time.time()
+ self.logger.info(
+ f"[PLAYGROUND CHAT] Chat Playground Stream Time: {end - start} seconds"
+ )
+ self.logger.info(
+ f"[PLAYGROUND CHAT] Chat Playground Stream LLM Input: {json.dumps(current_messages, ensure_ascii=False)} Chat Playground Stream LLM Response: {full_response}"
+ )
+
except Exception as llm_error:
# Log the error
self.logger.error(
- f"Error during LLM generation: {llm_error}", exc_info=True
+ f"[PLAYGROUND CHAT] Error during LLM generation: {llm_error}",
+ exc_info=True,
)
# Send error message to client
error_msg = f"模型生成错误: {llm_error!s}"
@@ -642,7 +684,7 @@ def generate_chat_response() -> Generator[str, None, None]:
if chat_req.internet_search or parsed_goal.internet_search:
# Yield internet reference after text response
- yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n"
+ yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference}, ensure_ascii=False)}\n\n"
# Calculate timing
time_end = time.time()
@@ -654,8 +696,8 @@ def generate_chat_response() -> Generator[str, None, None]:
# Get further suggestion
current_messages.append({"role": "assistant", "content": full_response})
further_suggestion = self._get_further_suggestion(current_messages)
- self.logger.info(f"further_suggestion: {further_suggestion}")
- yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n"
+ self.logger.info(f"[PLAYGROUND CHAT] further_suggestion: {further_suggestion}")
+ yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'end'})}\n\n"
@@ -685,7 +727,9 @@ def generate_chat_response() -> Generator[str, None, None]:
)
except Exception as e:
- self.logger.error(f"Error in chat stream: {e}", exc_info=True)
+ self.logger.error(
+ f"[PLAYGROUND CHAT] Error in playground chat stream: {e}", exc_info=True
+ )
error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
yield error_data
@@ -705,7 +749,9 @@ def generate_chat_response() -> Generator[str, None, None]:
except ValueError as err:
raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
except Exception as err:
- self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
+ self.logger.error(
+ f"[PLAYGROUND CHAT] Failed to start playground chat stream: {traceback.format_exc()}"
+ )
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
def _dedup_and_supplement_memories(
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index f968ea7b9..56f8ac195 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -177,7 +177,11 @@ def init_server() -> dict[str, Any]:
else None
)
llm = LLMFactory.from_config(llm_config)
- chat_llms = _init_chat_llms(chat_llm_config)
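+    # Only initialize chat LLMs when the chat API is explicitly enabled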
+ chat_llms = (
+ _init_chat_llms(chat_llm_config)
+ if os.getenv("ENABLE_CHAT_API", "false") == "true"
+ else None
+ )
embedder = EmbedderFactory.from_config(embedder_config)
mem_reader = MemReaderFactory.from_config(mem_reader_config)
reranker = RerankerFactory.from_config(reranker_config)
@@ -308,6 +312,7 @@ def init_server() -> dict[str, Any]:
mem_reader=mem_reader,
searcher=searcher,
reranker=feedback_reranker,
+ pref_mem=pref_mem,
)
# Initialize Scheduler
diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py
index 88875cacc..ca87d95d2 100644
--- a/src/memos/api/handlers/formatters_handler.py
+++ b/src/memos/api/handlers/formatters_handler.py
@@ -7,9 +7,13 @@
from typing import Any
+from memos.log import get_logger
from memos.templates.instruction_completion import instruct_completion
+logger = get_logger(__name__)
+
+
def to_iter(running: Any) -> list[Any]:
"""
Normalize running tasks to a list of task objects.
@@ -29,7 +33,9 @@ def to_iter(running: Any) -> list[Any]:
return list(running) if running else []
-def format_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_memory_item(
+ memory_data: Any, include_embedding: bool = False, save_sources: bool = True
+) -> dict[str, Any]:
"""
Format a single memory item for API response.
@@ -47,8 +53,10 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
ref_id = f"[{memory_id.split('-')[0]}]"
memory["ref_id"] = ref_id
- memory["metadata"]["embedding"] = []
- memory["metadata"]["sources"] = []
+ if not include_embedding:
+ memory["metadata"]["embedding"] = []
+ if not save_sources:
+ memory["metadata"]["sources"] = []
memory["metadata"]["usage"] = []
memory["metadata"]["ref_id"] = ref_id
memory["metadata"]["id"] = memory_id
@@ -124,3 +132,96 @@ def post_process_textual_mem(
}
)
return memories_result
+
+
+def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]):
+ """
+ Separate knowledge and conversation memories from retrieval results.
+ """
+ knowledge_mem = []
+ conversation_mem = []
+ for item in memories:
+ sources = item["metadata"]["sources"]
+ if (
+ len(sources) > 0
+ and "type" in sources[0]
+ and sources[0]["type"] == "file"
+ and "content" in sources[0]
+ and sources[0]["content"] != ""
+ ): # TODO change to memory_type
+ knowledge_mem.append(item)
+ else:
+ conversation_mem.append(item)
+
+ logger.info(
+ f"Retrieval results number of knowledge_mem: {len(knowledge_mem)}, conversation_mem: {len(conversation_mem)}"
+ )
+ return knowledge_mem, conversation_mem
+
+
+def rerank_knowledge_mem(
+ reranker: Any,
+ query: str,
+ text_mem: list[dict[str, Any]],
+ top_k: int,
+ file_mem_proportion: float = 0.5,
+) -> list[dict[str, Any]]:
+ """
+ Rerank knowledge memories and keep conversation memories.
+ """
+ memid2cubeid = {}
+ memories_list = []
+ for memory_group in text_mem:
+ cube_id = memory_group["cube_id"]
+ memories = memory_group["memories"]
+ memories_list.extend(memories)
+ for memory in memories:
+ memid2cubeid[memory["id"]] = cube_id
+
+ knowledge_mem, conversation_mem = separate_knowledge_and_conversation_mem(memories_list)
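+    # Reserve at least file_mem_proportion of the top_k slots for knowledge
+    # memories; when conversation memories are scarce, knowledge fills the rest.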
+ knowledge_mem_top_k = max(int(top_k * file_mem_proportion), int(top_k - len(conversation_mem)))
+ reranked_knowledge_mem = reranker.rerank(query, knowledge_mem, top_k=len(knowledge_mem))
+ reranked_knowledge_mem = [item[0] for item in reranked_knowledge_mem]
+
+    # TODO: revisit replacing the memory value with the first source's content
+ for item in reranked_knowledge_mem:
+ item["memory"] = item["metadata"]["sources"][0]["content"]
+ item["metadata"]["sources"] = []
+
+ for item in conversation_mem:
+ item["metadata"]["sources"] = []
+
+ # deduplicate: remove items with duplicate memory content
+ original_count = len(reranked_knowledge_mem)
+ seen_memories = set[Any]()
+ deduplicated_knowledge_mem = []
+ for item in reranked_knowledge_mem:
+ memory_content = item.get("memory", "")
+ if memory_content and memory_content not in seen_memories:
+ seen_memories.add(memory_content)
+ deduplicated_knowledge_mem.append(item)
+ deduplicated_count = len(deduplicated_knowledge_mem)
+ logger.info(
+ f"After filtering duplicate knowledge base text from sources, count changed from {original_count} to {deduplicated_count}"
+ )
+
+ reranked_knowledge_mem = deduplicated_knowledge_mem[:knowledge_mem_top_k]
+ conversation_mem_top_k = top_k - len(reranked_knowledge_mem)
+ cubeid2memories = {}
+ text_mem_res = []
+
+ for memory in reranked_knowledge_mem + conversation_mem[:conversation_mem_top_k]:
+ cube_id = memid2cubeid[memory["id"]]
+ if cube_id not in cubeid2memories:
+ cubeid2memories[cube_id] = []
+ cubeid2memories[cube_id].append(memory)
+
+ for cube_id, memories in cubeid2memories.items():
+ text_mem_res.append(
+ {
+ "cube_id": cube_id,
+ "memories": memories,
+ }
+ )
+
+ return text_mem_res
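
A standalone sketch of the quota arithmetic above (the helper name split_quota is hypothetical): knowledge memories receive at least their file_mem_proportion share of top_k, plus any slots the conversation pool cannot fill; in the real code the conversation quota is computed after deduplication, which this simplification omits.

def split_quota(top_k: int, n_conversation: int, file_mem_proportion: float = 0.5) -> tuple[int, int]:
    # Knowledge gets the larger of its proportional share and the leftover
    # slots that conversation memories cannot fill.
    knowledge_top_k = max(int(top_k * file_mem_proportion), int(top_k - n_conversation))
    conversation_top_k = top_k - knowledge_top_k
    return knowledge_top_k, conversation_top_k

# With top_k=10 and only 3 conversation hits, knowledge takes 7 slots.
assert split_quota(10, 3) == (7, 3)
# With plenty of conversation hits, the 0.5 proportion floor applies.
assert split_quota(10, 20) == (5, 5)
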
diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index a33ee9254..a4f500560 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -163,25 +163,97 @@ def handle_get_subgraph(
raise
+def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemoryResponse:
+ """
+ Handler for getting a single memory by its ID.
+
+ Tries to retrieve from text memory first, then preference memory if not found.
+
+ Args:
+ memory_id: The ID of the memory to retrieve
+ naive_mem_cube: Memory cube instance
+
+ Returns:
+ GetMemoryResponse with the memory data
+ """
+
+ try:
+ memory = naive_mem_cube.text_mem.get(memory_id)
+ except Exception:
+ memory = None
+
+ # If not found in text memory, try preference memory
+ pref = None
+ if memory is None and naive_mem_cube.pref_mem is not None:
+ collection_names = ["explicit_preference", "implicit_preference"]
+ for collection_name in collection_names:
+ try:
+ pref = naive_mem_cube.pref_mem.get_with_collection_name(collection_name, memory_id)
+ if pref is not None:
+ break
+ except Exception:
+ continue
+
+ # Get the data from whichever memory source succeeded
+ data = (memory or pref).model_dump() if (memory or pref) else None
+
+ return GetMemoryResponse(
+ message="Memory retrieved successfully"
+ if data
+ else f"Memory with ID {memory_id} not found",
+ code=200,
+ data=data,
+ )
+
+
def handle_get_memories(
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
) -> GetMemoryResponse:
# TODO: Implement get memory with filter
- memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"]
+ memories = naive_mem_cube.text_mem.get_all(
+ user_name=get_mem_req.mem_cube_id,
+ user_id=get_mem_req.user_id,
+ page=get_mem_req.page,
+ page_size=get_mem_req.page_size,
+ )
+ total_nodes = memories["total_nodes"]
+ total_edges = memories["total_edges"]
+ del memories["total_nodes"]
+ del memories["total_edges"]
+
preferences: list[TextualMemoryItem] = []
+ total_pref = 0
+
if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
filter_params: dict[str, Any] = {}
if get_mem_req.user_id is not None:
filter_params["user_id"] = get_mem_req.user_id
if get_mem_req.mem_cube_id is not None:
filter_params["mem_cube_id"] = get_mem_req.mem_cube_id
- preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params)
- preferences = [format_memory_item(mem) for mem in preferences]
+
+ preferences, total_pref = naive_mem_cube.pref_mem.get_memory_by_filter(
+ filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size
+ )
+ format_preferences = [format_memory_item(item, save_sources=False) for item in preferences]
+
return GetMemoryResponse(
message="Memories retrieved successfully",
data={
- "text_mem": [{"cube_id": get_mem_req.mem_cube_id, "memories": memories}],
- "pref_mem": [{"cube_id": get_mem_req.mem_cube_id, "memories": preferences}],
+ "text_mem": [
+ {
+ "cube_id": get_mem_req.mem_cube_id,
+ "memories": memories,
+ "total_nodes": total_nodes,
+ "total_edges": total_edges,
+ }
+ ],
+ "pref_mem": [
+ {
+ "cube_id": get_mem_req.mem_cube_id,
+ "memories": format_preferences,
+ "total_nodes": total_pref,
+ }
+ ],
},
)
@@ -204,8 +276,7 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube:
try:
if delete_mem_req.memory_ids is not None:
- for cube_id in delete_mem_req.writable_cube_ids:
- naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids, user_name=cube_id)
+ naive_mem_cube.text_mem.delete_by_memory_ids(delete_mem_req.memory_ids)
if naive_mem_cube.pref_mem is not None:
naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
elif delete_mem_req.file_ids is not None:
@@ -213,13 +284,9 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube:
writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids
)
elif delete_mem_req.filter is not None:
- # TODO: Implement deletion by filter
- # Need to find memories matching filter and delete them
- logger.warning("Deletion by filter not implemented yet")
- return DeleteMemoryResponse(
- message="Deletion by filter not implemented yet",
- data={"status": "failure"},
- )
+ naive_mem_cube.text_mem.delete_by_filter(filter=delete_mem_req.filter)
+ if naive_mem_cube.pref_mem is not None:
+ naive_mem_cube.pref_mem.delete_by_filter(filter=delete_mem_req.filter)
except Exception as e:
logger.error(f"Failed to delete memories: {e}", exc_info=True)
return DeleteMemoryResponse(
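
A hedged example of exercising the new delete-by-filter path end to end; the base URL and route prefix are assumptions, while the filter operators follow the DSL documented for the graph-DB layer further below:

import json
import requests

payload = {
    "memory_ids": None,
    "file_ids": None,
    # delete everything a hypothetical user created before 2025
    "filter": {"and": [{"user_id": {"=": "u-123"}}, {"created_at": {"lt": "2025-01-01"}}]},
}
res = requests.post(
    "http://localhost:8000/delete_memory",  # assumed host/port and prefix
    headers={"Content-Type": "application/json"},
    data=json.dumps(payload),
)
print(res.json())
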
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index f7d6ee2c8..32a970b22 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -5,9 +5,17 @@
using dependency injection for better modularity and testability.
"""
+import time
+
+from typing import Any
+
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
+from memos.api.handlers.formatters_handler import rerank_knowledge_mem
from memos.api.product_models import APISearchRequest, SearchResponse
from memos.log import get_logger
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+ cosine_similarity_matrix,
+)
from memos.multi_mem_cube.composite_cube import CompositeCubeView
from memos.multi_mem_cube.single_cube import SingleCubeView
from memos.multi_mem_cube.views import MemCubeView
@@ -50,9 +58,31 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
"""
self.logger.info(f"[SearchHandler] Search Req is: {search_req}")
+ # Increase recall pool if deduplication is enabled to ensure diversity
+ original_top_k = search_req.top_k
+ if search_req.dedup == "sim":
+ search_req.top_k = original_top_k * 5
+
cube_view = self._build_cube_view(search_req)
results = cube_view.search_memories(search_req)
+ if search_req.dedup == "sim":
+ results = self._dedup_text_memories(results, original_top_k)
+ self._strip_embeddings(results)
+ # Restore original top_k for downstream logic or response metadata
+ search_req.top_k = original_top_k
+
+ start_time = time.time()
+ text_mem = results["text_mem"]
+ results["text_mem"] = rerank_knowledge_mem(
+ self.reranker,
+ query=search_req.query,
+ text_mem=text_mem,
+ top_k=original_top_k,
+ file_mem_proportion=0.5,
+ )
+ rerank_time = time.time() - start_time
+ self.logger.info(f"[Knowledge_replace_memory_time] Rerank time: {rerank_time} seconds")
self.logger.info(
f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -63,6 +93,93 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
data=results,
)
+ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
+ buckets = results.get("text_mem", [])
+ if not buckets:
+ return results
+
+ flat: list[tuple[int, dict[str, Any], float]] = []
+ for bucket_idx, bucket in enumerate(buckets):
+ for mem in bucket.get("memories", []):
+ score = mem.get("metadata", {}).get("relativity", 0.0)
+ flat.append((bucket_idx, mem, score))
+
+ if len(flat) <= 1:
+ return results
+
+ embeddings = self._extract_embeddings([mem for _, mem, _ in flat])
+ if embeddings is None:
+ documents = [mem.get("memory", "") for _, mem, _ in flat]
+ embeddings = self.searcher.embedder.embed(documents)
+
+ similarity_matrix = cosine_similarity_matrix(embeddings)
+
+ indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+ for flat_index, (bucket_idx, _, _) in enumerate(flat):
+ indices_by_bucket[bucket_idx].append(flat_index)
+
+ selected_global: list[int] = []
+ selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+
+ ordered_indices = sorted(range(len(flat)), key=lambda idx: flat[idx][2], reverse=True)
+ for idx in ordered_indices:
+ bucket_idx = flat[idx][0]
+ if len(selected_by_bucket[bucket_idx]) >= target_top_k:
+ continue
+ # Use 0.92 threshold strictly
+ if self._is_unrelated(idx, selected_global, similarity_matrix, 0.92):
+ selected_by_bucket[bucket_idx].append(idx)
+ selected_global.append(idx)
+
+ # No backfilling with near-duplicates: only items whose similarity to every
+ # selected item stays at or below the 0.92 threshold are kept, up to
+ # target_top_k per bucket.
+
+ for bucket_idx, bucket in enumerate(buckets):
+ selected_indices = selected_by_bucket.get(bucket_idx, [])
+ bucket["memories"] = [flat[i][1] for i in selected_indices]
+ return results
+
+ @staticmethod
+ def _is_unrelated(
+ index: int,
+ selected_indices: list[int],
+ similarity_matrix: list[list[float]],
+ similarity_threshold: float,
+ ) -> bool:
+ return all(similarity_matrix[index][j] <= similarity_threshold for j in selected_indices)
+
+ @staticmethod
+ def _max_similarity(
+ index: int, selected_indices: list[int], similarity_matrix: list[list[float]]
+ ) -> float:
+ if not selected_indices:
+ return 0.0
+ return max(similarity_matrix[index][j] for j in selected_indices)
+
+ @staticmethod
+ def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
+ embeddings: list[list[float]] = []
+ for mem in memories:
+ embedding = mem.get("metadata", {}).get("embedding")
+ if not embedding:
+ return None
+ embeddings.append(embedding)
+ return embeddings
+
+ @staticmethod
+ def _strip_embeddings(results: dict[str, Any]) -> None:
+ for bucket in results.get("text_mem", []):
+ for mem in bucket.get("memories", []):
+ metadata = mem.get("metadata", {})
+ if "embedding" in metadata:
+ metadata["embedding"] = []
+ for bucket in results.get("tool_mem", []):
+ for mem in bucket.get("memories", []):
+ metadata = mem.get("metadata", {})
+ if "embedding" in metadata:
+ metadata["embedding"] = []
+
def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
"""
Normalize target cube ids from search_req.
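
The similarity dedup above reduces to the following standalone sketch (NumPy is used here for brevity; the handler works on plain lists and additionally caps selections per bucket, which this omits): candidates are visited in descending relevance and kept only while their cosine similarity to every already-kept item stays at or below the threshold.

import numpy as np

def greedy_dedup(embeddings: np.ndarray, scores: list[float],
                 top_k: int, threshold: float = 0.92) -> list[int]:
    # Normalize rows so the dot product equals cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    sim = unit @ unit.T
    kept: list[int] = []
    for idx in sorted(range(len(scores)), key=scores.__getitem__, reverse=True):
        if len(kept) >= top_k:
            break
        if all(sim[idx][j] <= threshold for j in kept):
            kept.append(idx)
    return kept
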
diff --git a/src/memos/api/mcp_serve.py b/src/memos/api/mcp_serve.py
index 9eb1e59d0..838c2a76a 100644
--- a/src/memos/api/mcp_serve.py
+++ b/src/memos/api/mcp_serve.py
@@ -16,14 +16,88 @@
def load_default_config(user_id="default_user"):
+ """
+ Load MOS configuration from environment variables.
+
+ IMPORTANT for Neo4j Community Edition:
+ Community Edition does not support administrative commands like 'CREATE DATABASE'.
+ To avoid errors, ensure the following environment variables are set correctly:
+ - NEO4J_DB_NAME=neo4j (Must use the default database)
+ - NEO4J_AUTO_CREATE=false (Disable automatic database creation)
+ - NEO4J_USE_MULTI_DB=false (Disable multi-tenant database mode)
+ """
+ # Define mapping between environment variables and configuration parameters
+ # We support both clean names and MOS_ prefixed names for compatibility
+ env_mapping = {
+ "OPENAI_API_KEY": "openai_api_key",
+ "OPENAI_API_BASE": "openai_api_base",
+ "MOS_TEXT_MEM_TYPE": "text_mem_type",
+ "NEO4J_URI": "neo4j_uri",
+ "NEO4J_USER": "neo4j_user",
+ "NEO4J_PASSWORD": "neo4j_password",
+ "NEO4J_DB_NAME": "neo4j_db_name",
+ "NEO4J_AUTO_CREATE": "neo4j_auto_create",
+ "NEO4J_USE_MULTI_DB": "use_multi_db",
+ "MOS_NEO4J_SHARED_DB": "mos_shared_db", # Special handle later
+ "MODEL_NAME": "model_name",
+ "MOS_CHAT_MODEL": "model_name",
+ "EMBEDDER_MODEL": "embedder_model",
+ "MOS_EMBEDDER_MODEL": "embedder_model",
+ "CHUNK_SIZE": "chunk_size",
+ "CHUNK_OVERLAP": "chunk_overlap",
+ "ENABLE_MEM_SCHEDULER": "enable_mem_scheduler",
+ "MOS_ENABLE_SCHEDULER": "enable_mem_scheduler",
+ "ENABLE_ACTIVATION_MEMORY": "enable_activation_memory",
+ "TEMPERATURE": "temperature",
+ "MOS_CHAT_TEMPERATURE": "temperature",
+ "MAX_TOKENS": "max_tokens",
+ "MOS_MAX_TOKENS": "max_tokens",
+ "TOP_P": "top_p",
+ "MOS_TOP_P": "top_p",
+ "TOP_K": "top_k",
+ "MOS_TOP_K": "top_k",
+ "SCHEDULER_TOP_K": "scheduler_top_k",
+ "MOS_SCHEDULER_TOP_K": "scheduler_top_k",
+ "SCHEDULER_TOP_N": "scheduler_top_n",
+ }
+
+ kwargs = {"user_id": user_id}
+ for env_key, param_key in env_mapping.items():
+ val = os.getenv(env_key)
+ if val is not None:
+ # Strip quotes if they exist (sometimes happens with .env)
+ if (val.startswith('"') and val.endswith('"')) or (
+ val.startswith("'") and val.endswith("'")
+ ):
+ val = val[1:-1]
+
+ # Handle boolean conversions
+ if val.lower() in ("true", "false"):
+ kwargs[param_key] = val.lower() == "true"
+ else:
+ # Try numeric conversions (int first, then float)
+ try:
+ if "." in val:
+ kwargs[param_key] = float(val)
+ else:
+ kwargs[param_key] = int(val)
+ except ValueError:
+ kwargs[param_key] = val
+
+ # MOS_NEO4J_SHARED_DB is the inverse of use_multi_db: a shared DB disables multi-db mode
+ if "mos_shared_db" in kwargs:
+ kwargs["use_multi_db"] = not kwargs.pop("mos_shared_db")
+
+ # Extract mandatory or special params
+ openai_api_key = kwargs.pop("openai_api_key", os.getenv("OPENAI_API_KEY"))
+ openai_api_base = kwargs.pop("openai_api_base", "https://api.openai.com/v1")
+ text_mem_type = kwargs.pop("text_mem_type", "tree_text")
+
config, cube = get_default(
- openai_api_key=os.getenv("OPENAI_API_KEY"),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- text_mem_type=os.getenv("MOS_TEXT_MEM_TYPE"),
- user_id=user_id,
- neo4j_uri=os.getenv("NEO4J_URI"),
- neo4j_user=os.getenv("NEO4J_USER"),
- neo4j_password=os.getenv("NEO4J_PASSWORD"),
+ openai_api_key=openai_api_key,
+ openai_api_base=openai_api_base,
+ text_mem_type=text_mem_type,
+ **kwargs,
)
return config, cube
@@ -33,6 +107,7 @@ def __init__(self):
self.mcp = FastMCP("MOS Memory System")
config, cube = load_default_config()
self.mos_core = MOS(config=config)
+ self.mos_core.register_mem_cube(cube)
self._setup_tools()
def _setup_tools(self):
@@ -132,11 +207,14 @@ async def register_cube(
"""
try:
if not os.path.exists(cube_name_or_path):
- mos_config, cube_name_or_path = load_default_config(user_id=user_id)
+ _, cube = load_default_config(user_id=user_id)
+ cube_to_register = cube
+ else:
+ cube_to_register = cube_name_or_path
self.mos_core.register_mem_cube(
- cube_name_or_path, mem_cube_id=cube_id, user_id=user_id
+ cube_to_register, mem_cube_id=cube_id, user_id=user_id
)
- return f"Cube registered successfully: {cube_id or cube_name_or_path}"
+ return f"Cube registered successfully: {cube_id or cube_to_register}"
except Exception as e:
return f"Error registering cube: {e!s}"
@@ -489,14 +567,6 @@ def run(self, transport: str = "stdio", **kwargs):
args = parser.parse_args()
- # Set environment variables
- os.environ["OPENAI_API_BASE"] = os.getenv("OPENAI_API_BASE")
- os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
- os.environ["MOS_TEXT_MEM_TYPE"] = "tree_text" # "tree_text" need set neo4j
- os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI")
- os.environ["NEO4J_USER"] = os.getenv("NEO4J_USER")
- os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD")
-
# Create and run MCP server
server = MOSMCPStdioServer()
server.run(transport=args.transport, host=args.host, port=args.port)
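
The value coercion in load_default_config boils down to this self-contained sketch (same rules, reproduced for reference: strip matching quotes, then booleans, then int/float, else keep the raw string):

def coerce(val: str):
    # Strip symmetric quotes that sometimes leak in from .env files.
    if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")):
        val = val[1:-1]
    if val.lower() in ("true", "false"):
        return val.lower() == "true"
    try:
        return float(val) if "." in val else int(val)
    except ValueError:
        return val

assert coerce('"neo4j"') == "neo4j"
assert coerce("false") is False
assert coerce("0.7") == 0.7
assert coerce("8000") == 8000
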
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index adcb68a96..d5f301c9d 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -319,6 +319,15 @@ class APISearchRequest(BaseRequest):
description="Number of textual memories to retrieve (top-K). Default: 10.",
)
+ dedup: Literal["no", "sim"] | None = Field(
+ None,
+ description=(
+ "Optional dedup option for textual memories. "
+ "Use 'no' for no dedup, 'sim' for similarity dedup. "
+ "If None, default exact-text dedup is applied."
+ ),
+ )
+
pref_top_k: int = Field(
6,
ge=0,
@@ -763,12 +772,19 @@ class GetMemoryRequest(BaseRequest):
mem_cube_id: str = Field(..., description="Cube ID")
user_id: str | None = Field(None, description="User ID")
include_preference: bool = Field(True, description="Whether to handle preference memory")
+ page: int | None = Field(
+ None,
+ description="Page number (starts from 1). If None, exports all data without pagination.",
+ )
+ page_size: int | None = Field(
+ None, description="Number of items per page. If None, exports all data without pagination."
+ )
class DeleteMemoryRequest(BaseRequest):
"""Request model for deleting memories."""
- writable_cube_ids: list[str] = Field(..., description="Writable cube IDs")
+ writable_cube_ids: list[str] | None = Field(None, description="Writable cube IDs")
memory_ids: list[str] | None = Field(None, description="Memory IDs")
file_ids: list[str] | None = Field(None, description="File IDs")
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
@@ -856,8 +872,8 @@ class GetMemoryData(BaseModel):
memory_detail_list: list[MemoryDetail] = Field(
default_factory=list, alias="memory_detail_list", description="List of memory details"
)
- message_detail_list: list[MessageDetail] | None = Field(
- None, alias="message_detail_list", description="List of message details (usually None)"
+ preference_detail_list: list[MessageDetail] | None = Field(
+ None, alias="preference_detail_list", description="List of preference detail"
)
@@ -1009,7 +1025,7 @@ class MemOSGetMemoryResponse(BaseModel):
code: int = Field(..., description="Response status code")
message: str = Field(..., description="Response message")
- data: SearchMemoryData = Field(..., description="Get results data")
+ data: GetMemoryData = Field(..., description="Get results data")
@property
def memories(self) -> list[MemoryDetail]:
@@ -1017,15 +1033,10 @@ def memories(self) -> list[MemoryDetail]:
return self.data.memory_detail_list
@property
- def preferences(self) -> list[MemoryDetail]:
+ def preferences(self) -> list[MessageDetail] | None:
"""Convenient access to preference list."""
return self.data.preference_detail_list
- @property
- def tool_memories(self) -> list[MemoryDetail]:
- """Convenient access to tool_memory list."""
- return self.data.tool_memory_detail_list
-
class MemOSGetKnowledgebaseFileResponse(BaseModel):
"""Response model for get KnowledgebaseFile operation based on actual API."""
@@ -1168,3 +1179,26 @@ class AllStatusResponse(BaseResponse[AllStatusResponseData]):
"""Response model for full scheduler status operations."""
message: str = "Scheduler status summary retrieved successfully"
+
+
+# ─── Internal API Endpoints Models (for internal use) ───────────────────────────────────────────────────
+
+
+class GetUserNamesByMemoryIdsRequest(BaseRequest):
+ """Request model for getting user names by memory ids."""
+
+ memory_ids: list[str] = Field(..., description="Memory IDs")
+
+
+class GetUserNamesByMemoryIdsResponse(BaseResponse[dict[str, str | None]]):
+ """Response model for getting user names by memory ids."""
+
+
+class ExistMemCubeIdRequest(BaseRequest):
+ """Request model for checking if mem cube id exists."""
+
+ mem_cube_id: str = Field(..., description="Mem cube ID")
+
+
+class ExistMemCubeIdResponse(BaseResponse[dict[str, bool]]):
+ """Response model for checking if mem cube id exists."""
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index fcb70a64c..86b75d73e 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -15,7 +15,7 @@
import random as _random
import socket
-from fastapi import APIRouter, Query
+from fastapi import APIRouter, HTTPException, Query
from memos.api import handlers
from memos.api.handlers.add_handler import AddHandler
@@ -33,9 +33,13 @@
ChatRequest,
DeleteMemoryRequest,
DeleteMemoryResponse,
+ ExistMemCubeIdRequest,
+ ExistMemCubeIdResponse,
GetMemoryPlaygroundRequest,
GetMemoryRequest,
GetMemoryResponse,
+ GetUserNamesByMemoryIdsRequest,
+ GetUserNamesByMemoryIdsResponse,
MemoryResponse,
SearchResponse,
StatusResponse,
@@ -43,6 +47,7 @@
SuggestionResponse,
TaskQueueResponse,
)
+from memos.graph_dbs.polardb import PolarDBGraphDB
from memos.log import get_logger
from memos.mem_scheduler.base_scheduler import BaseScheduler
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
@@ -64,12 +69,16 @@
# Initialize all handlers with dependency injection
search_handler = SearchHandler(dependencies)
add_handler = AddHandler(dependencies)
-chat_handler = ChatHandler(
- dependencies,
- components["chat_llms"],
- search_handler,
- add_handler,
- online_bot=components.get("online_bot"),
+chat_handler = (
+ ChatHandler(
+ dependencies,
+ components["chat_llms"],
+ search_handler,
+ add_handler,
+ online_bot=components.get("online_bot"),
+ )
+ if os.getenv("ENABLE_CHAT_API", "false").lower() == "true"
+ else None
)
feedback_handler = FeedbackHandler(dependencies)
# Extract commonly used components for function-based handlers
@@ -79,6 +88,8 @@
naive_mem_cube = components["naive_mem_cube"]
redis_client = components["redis_client"]
status_tracker = TaskStatusTracker(redis_client=redis_client)
+graph_db = components["graph_db"]
+vector_db = components["vector_db"]
# =============================================================================
@@ -201,6 +212,10 @@ def chat_complete(chat_req: APIChatCompleteRequest):
This endpoint uses the class-based ChatHandler.
"""
+ if chat_handler is None:
+ raise HTTPException(
+ status_code=503, detail="Chat service is not available. Chat handler not initialized."
+ )
return chat_handler.handle_chat_complete(chat_req)
@@ -212,6 +227,10 @@ def chat_stream(chat_req: ChatRequest):
This endpoint uses the class-based ChatHandler which internally
composes SearchHandler and AddHandler for a clean architecture.
"""
+ if chat_handler is None:
+ raise HTTPException(
+ status_code=503, detail="Chat service is not available. Chat handler not initialized."
+ )
return chat_handler.handle_chat_stream(chat_req)
@@ -223,6 +242,10 @@ def chat_stream_playground(chat_req: ChatPlaygroundRequest):
This endpoint uses the class-based ChatHandler which internally
composes SearchHandler and AddHandler for a clean architecture.
"""
+ if chat_handler is None:
+ raise HTTPException(
+ status_code=503, detail="Chat service is not available. Chat handler not initialized."
+ )
return chat_handler.handle_chat_stream_playground(chat_req)
@@ -289,6 +312,14 @@ def get_memories(memory_req: GetMemoryRequest):
)
+@router.get("/get_memory/{memory_id}", summary="Get memory by id", response_model=GetMemoryResponse)
+def get_memory_by_id(memory_id: str):
+ return handlers.memory_handler.handle_get_memory(
+ memory_id=memory_id,
+ naive_mem_cube=naive_mem_cube,
+ )
+
+
@router.post(
"/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse
)
@@ -311,3 +342,53 @@ def feedback_memories(feedback_req: APIFeedbackRequest):
This endpoint uses the class-based FeedbackHandler for better code organization.
"""
return feedback_handler.handle_feedback_memories(feedback_req)
+
+
+# =============================================================================
+# Other API Endpoints (for internal use)
+# =============================================================================
+
+
+@router.post(
+ "/get_user_names_by_memory_ids",
+ summary="Get user names by memory ids",
+ response_model=GetUserNamesByMemoryIdsResponse,
+)
+def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest):
+ """Get user names by memory ids."""
+ if not isinstance(graph_db, PolarDBGraphDB):
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ "graph_db must be an instance of PolarDBGraphDB to use "
+ "get_user_names_by_memory_ids"
+ f"current graph_db is: {graph_db.__class__.__name__}"
+ ),
+ )
+ result = graph_db.get_user_names_by_memory_ids(memory_ids=request.memory_ids)
+ if vector_db:
+ prefs = []
+ for collection_name in ["explicit_preference", "implicit_preference"]:
+ prefs.extend(
+ vector_db.get_by_ids(collection_name=collection_name, ids=request.memory_ids)
+ )
+ result.update({pref.id: pref.payload.get("mem_cube_id", None) for pref in prefs})
+ return GetUserNamesByMemoryIdsResponse(
+ code=200,
+ message="Successfully",
+ data=result,
+ )
+
+
+@router.post(
+ "/exist_mem_cube_id",
+ summary="Check if mem cube id exists",
+ response_model=ExistMemCubeIdResponse,
+)
+def exist_mem_cube_id(request: ExistMemCubeIdRequest):
+ """Check if mem cube id exists."""
+ return ExistMemCubeIdResponse(
+ code=200,
+ message="Successfully",
+ data=graph_db.exist_user_name(user_name=request.mem_cube_id),
+ )
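
A usage sketch for the two internal endpoints; host and port are assumptions, while the field names come from the request/response models above:

import requests

base = "http://localhost:8000"  # assumption; adjust to your deployment

res = requests.post(
    f"{base}/get_user_names_by_memory_ids",
    json={"memory_ids": ["4918d700-6f01-4f4c-a076-75cc7b0e1a7c"]},
)
print(res.json())  # data maps each memory_id to its user_name, or null if not found

res = requests.post(f"{base}/exist_mem_cube_id", json={"mem_cube_id": "cube-001"})
print(res.json())  # data maps the cube id to a boolean
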
diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py
index de375a4dc..b7771ac35 100644
--- a/src/memos/chunkers/markdown_chunker.py
+++ b/src/memos/chunkers/markdown_chunker.py
@@ -57,6 +57,6 @@ def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]:
except Exception as e:
logger.warning(f"warning chunking document: {e}")
chunks.append(doc.page_content)
-
+ logger.info(f"Generated chunks: {chunks[:5]}")
logger.debug(f"Generated {len(chunks)} chunks from input text")
return chunks
diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py
index 080962482..f39dfb8e2 100644
--- a/src/memos/chunkers/sentence_chunker.py
+++ b/src/memos/chunkers/sentence_chunker.py
@@ -20,12 +20,25 @@ def __init__(self, config: SentenceChunkerConfig):
from chonkie import SentenceChunker as ChonkieSentenceChunker
self.config = config
- self.chunker = ChonkieSentenceChunker(
- tokenizer_or_token_counter=config.tokenizer_or_token_counter,
- chunk_size=config.chunk_size,
- chunk_overlap=config.chunk_overlap,
- min_sentences_per_chunk=config.min_sentences_per_chunk,
- )
+
+ # Try new API first (v1.4.0+)
+ try:
+ self.chunker = ChonkieSentenceChunker(
+ tokenizer=config.tokenizer_or_token_counter,
+ chunk_size=config.chunk_size,
+ chunk_overlap=config.chunk_overlap,
+ min_sentences_per_chunk=config.min_sentences_per_chunk,
+ )
+ except (TypeError, AttributeError) as e:
+ # Fallback to old API (<v1.4.0) that used tokenizer_or_token_counter
+ logger.warning(f"New-style SentenceChunker init failed, falling back to legacy API: {e}")
+ self.chunker = ChonkieSentenceChunker(
+ tokenizer_or_token_counter=config.tokenizer_or_token_counter,
+ chunk_size=config.chunk_size,
+ chunk_overlap=config.chunk_overlap,
+ min_sentences_per_chunk=config.min_sentences_per_chunk,
+ )

 def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]:
diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
index 89b58f417..428d6d09e 100644
--- a/src/memos/graph_dbs/nebular.py
+++ b/src/memos/graph_dbs/nebular.py
@@ -1207,7 +1207,7 @@ def clear(self, user_name: str | None = None) -> None:
@timed
def export_graph(
- self, include_embedding: bool = False, user_name: str | None = None
+ self, include_embedding: bool = False, user_name: str | None = None, **kwargs
) -> dict[str, Any]:
"""
Export all graph nodes and edges in a structured form.
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index a0a4c6a50..64aedc8f4 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -1132,41 +1132,95 @@ def clear(self, user_name: str | None = None) -> None:
logger.error(f"[ERROR] Failed to clear database '{self.db_name}': {e}")
raise
- def export_graph(self, **kwargs) -> dict[str, Any]:
+ def export_graph(
+ self,
+ page: int | None = None,
+ page_size: int | None = None,
+ **kwargs,
+ ) -> dict[str, Any]:
"""
Export all graph nodes and edges in a structured form.
+ Args:
+ page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
+ page_size (int, optional): Number of items per page. If None, exports all data without pagination.
+ **kwargs: Additional keyword arguments, including:
+ - user_name (str, optional): User name for filtering in non-multi-db mode
+
Returns:
{
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
- "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
+ "edges": [ { "source": ..., "target": ..., "type": ... }, ... ],
+ "total_nodes": int, # Total number of nodes matching the filter criteria
+ "total_edges": int, # Total number of edges matching the filter criteria
}
"""
user_name = kwargs.get("user_name") if kwargs.get("user_name") else self.config.user_name
+
+ # Initialize total counts
+ total_nodes = 0
+ total_edges = 0
+
+ # Determine if pagination is needed
+ use_pagination = page is not None and page_size is not None
+
+ # Validate pagination parameters if pagination is enabled
+ if use_pagination:
+ if page < 1:
+ page = 1
+ if page_size < 1:
+ page_size = 10
+ skip = (page - 1) * page_size
+
with self.driver.session(database=self.db_name) as session:
- # Export nodes
- node_query = "MATCH (n:Memory)"
- edge_query = "MATCH (a:Memory)-[r]->(b:Memory)"
+ # Build base queries
+ node_base_query = "MATCH (n:Memory)"
+ edge_base_query = "MATCH (a:Memory)-[r]->(b:Memory)"
params = {}
if not self.config.use_multi_db and (self.config.user_name or user_name):
- node_query += " WHERE n.user_name = $user_name"
- edge_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
+ node_base_query += " WHERE n.user_name = $user_name"
+ edge_base_query += " WHERE a.user_name = $user_name AND b.user_name = $user_name"
params["user_name"] = user_name
- node_result = session.run(f"{node_query} RETURN n", params)
+ # Get total count of nodes before pagination
+ count_node_query = node_base_query + " RETURN COUNT(n) AS count"
+ count_node_result = session.run(count_node_query, params)
+ total_nodes = count_node_result.single()["count"]
+
+ # Export nodes with ORDER BY created_at DESC
+ node_query = node_base_query + " RETURN n ORDER BY n.created_at DESC, n.id DESC"
+ if use_pagination:
+ node_query += f" SKIP {skip} LIMIT {page_size}"
+
+ node_result = session.run(node_query, params)
nodes = [self._parse_node(dict(record["n"])) for record in node_result]
- # Export edges
- edge_result = session.run(
- f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) AS type", params
+ # Get total count of edges before pagination
+ count_edge_query = edge_base_query + " RETURN COUNT(r) AS count"
+ count_edge_result = session.run(count_edge_query, params)
+ total_edges = count_edge_result.single()["count"]
+
+ # Export edges with ORDER BY created_at DESC
+ edge_query = (
+ edge_base_query
+ + " RETURN a.id AS source, b.id AS target, type(r) AS type ORDER BY a.created_at DESC, b.created_at DESC, a.id DESC, b.id DESC"
)
+ if use_pagination:
+ edge_query += f" SKIP {skip} LIMIT {page_size}"
+
+ edge_result = session.run(edge_query, params)
edges = [
{"source": record["source"], "target": record["target"], "type": record["type"]}
for record in edge_result
]
- return {"nodes": nodes, "edges": edges}
+ return {
+ "nodes": nodes,
+ "edges": edges,
+ "total_nodes": total_nodes,
+ "total_edges": total_edges,
+ }
def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
"""
@@ -1347,6 +1401,15 @@ def _ensure_database_exists(self):
with self.driver.session(database="system") as session:
session.run(f"CREATE DATABASE `{self.db_name}` IF NOT EXISTS")
except ClientError as e:
+ if "Unsupported administration command" in str(
+ e
+ ) or "Unsupported administration" in str(e):
+ logger.warning(
+ f"Could not create database '{self.db_name}' because this Neo4j instance "
+ "(likely Community Edition) does not support administrative commands. "
+ "Please ensure the database exists manually or use the default 'neo4j' database."
+ )
+ return
if "ExistingDatabaseFound" in str(e):
pass # Ignore, database already exists
else:
@@ -1637,130 +1700,216 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
def delete_node_by_prams(
self,
- writable_cube_ids: list[str],
+ writable_cube_ids: list[str] | None = None,
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
) -> int:
"""
Delete nodes by memory_ids, file_ids, or filter.
+ Supports three scenarios:
+ 1. Delete by memory_ids (standalone)
+ 2. Delete by writable_cube_ids + file_ids (combined)
+ 3. Delete by filter (standalone, no writable_cube_ids needed)
Args:
- writable_cube_ids (list[str]): List of cube IDs (user_name) to filter nodes. Required parameter.
+ writable_cube_ids (list[str], optional): List of cube IDs (user_name) to filter nodes.
+ Only used with file_ids scenario. If not provided, no user_name filter will be applied.
memory_ids (list[str], optional): List of memory node IDs to delete.
- file_ids (list[str], optional): List of file node IDs to delete.
- filter (dict, optional): Filter dictionary to query matching nodes for deletion.
+ file_ids (list[str], optional): List of file node IDs to delete. Must be used with writable_cube_ids.
+ filter (dict, optional): Filter dictionary for metadata filtering.
+ Filter conditions are directly used in DELETE WHERE clause without pre-querying.
+ Does not require writable_cube_ids.
Returns:
int: Number of nodes deleted.
"""
+ batch_start_time = time.time()
logger.info(
f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
)
- print(
- f"[delete_node_by_prams] memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}, writable_cube_ids: {writable_cube_ids}"
- )
-
- # Validate writable_cube_ids
- if not writable_cube_ids or len(writable_cube_ids) == 0:
- raise ValueError("writable_cube_ids is required and cannot be empty")
-
- # Build WHERE conditions separately for memory_ids and file_ids
- where_clauses = []
- params = {}
# Build user_name condition from writable_cube_ids (OR relationship - match any cube_id)
+ # Only add user_name filter if writable_cube_ids is provided (for file_ids scenario)
user_name_conditions = []
- for idx, cube_id in enumerate(writable_cube_ids):
- param_name = f"cube_id_{idx}"
- user_name_conditions.append(f"n.user_name = ${param_name}")
- params[param_name] = cube_id
-
- # Handle memory_ids: query n.id
- if memory_ids and len(memory_ids) > 0:
- where_clauses.append("n.id IN $memory_ids")
- params["memory_ids"] = memory_ids
-
- # Handle file_ids: query n.file_ids field
- # All file_ids must be present in the array field (AND relationship)
- if file_ids and len(file_ids) > 0:
- file_id_and_conditions = []
- for idx, file_id in enumerate(file_ids):
- param_name = f"file_id_{idx}"
- params[param_name] = file_id
- # Check if this file_id is in the file_ids array field
- file_id_and_conditions.append(f"${param_name} IN n.file_ids")
- if file_id_and_conditions:
- # Use AND to require all file_ids to be present
- where_clauses.append(f"({' OR '.join(file_id_and_conditions)})")
+ params = {}
+ if writable_cube_ids and len(writable_cube_ids) > 0:
+ for idx, cube_id in enumerate(writable_cube_ids):
+ param_name = f"cube_id_{idx}"
+ user_name_conditions.append(f"n.user_name = ${param_name}")
+ params[param_name] = cube_id
- # Query nodes by filter if provided
- filter_ids = []
+ # Build filter conditions using common method (no query, direct use in WHERE clause)
+ filter_conditions = []
+ filter_params = {}
if filter:
- # Use get_by_metadata with empty filters list and filter
- filter_ids = self.get_by_metadata(
- filters=[],
- user_name=None,
- filter=filter,
- knowledgebase_ids=writable_cube_ids,
+ filter_conditions, filter_params = self._build_filter_conditions_cypher(
+ filter, param_counter_start=0, node_alias="n"
)
+ logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}")
+ params.update(filter_params)
- # If filter returned IDs, add condition for them
- if filter_ids:
- where_clauses.append("n.id IN $filter_ids")
- params["filter_ids"] = filter_ids
-
- # If no conditions (except user_name), return 0
- if not where_clauses:
+ # If no conditions to delete, return 0
+ if not memory_ids and not file_ids and not filter_conditions:
logger.warning(
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
)
return 0
- # Build WHERE clause
- # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
- data_conditions = " OR ".join([f"({clause})" for clause in where_clauses])
+ # Build WHERE conditions list
+ where_clauses = []
- # Then, combine with user_name condition using AND (must match user_name AND one of the data conditions)
- user_name_where = " OR ".join(user_name_conditions)
- ids_where = f"({user_name_where}) AND ({data_conditions})"
+ # Scenario 1: memory_ids (standalone)
+ if memory_ids:
+ logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids")
+ where_clauses.append("n.id IN $memory_ids")
+ params["memory_ids"] = memory_ids
- logger.info(
- f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
- )
- print(
- f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
- )
+ # Scenario 2: file_ids + writable_cube_ids (combined)
+ if file_ids:
+ logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids")
+ file_id_conditions = []
+ for idx, file_id in enumerate(file_ids):
+ param_name = f"file_id_{idx}"
+ params[param_name] = file_id
+ # Check if this file_id is in the file_ids array field
+ file_id_conditions.append(f"${param_name} IN n.file_ids")
+ if file_id_conditions:
+ where_clauses.append(f"({' OR '.join(file_id_conditions)})")
+
+ # Scenario 3: filter (standalone, no writable_cube_ids needed)
+ if filter_conditions:
+ logger.info("[delete_node_by_prams] Processing filter conditions")
+ # Combine filter conditions with AND
+ filter_where = " AND ".join(filter_conditions)
+ where_clauses.append(f"({filter_where})")
+
+ # Build final WHERE clause
+ if not where_clauses:
+ logger.warning("[delete_node_by_prams] No WHERE conditions to delete")
+ return 0
+
+ # Combine all conditions with AND
+ data_conditions = " AND ".join([f"({clause})" for clause in where_clauses])
- # First count matching nodes to get accurate count
- count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count"
- logger.info(f"[delete_node_by_prams] count_query: {count_query}")
- print(f"[delete_node_by_prams] count_query: {count_query}")
+ # Add user_name filter if provided (for file_ids scenario)
+ if user_name_conditions:
+ user_name_where = " OR ".join(user_name_conditions)
+ final_where = f"({user_name_where}) AND ({data_conditions})"
+ else:
+ final_where = data_conditions
- # Then delete nodes
- delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
+ # Delete directly without pre-counting
+ delete_query = f"MATCH (n:Memory) WHERE {final_where} DETACH DELETE n"
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
- print(f"[delete_node_by_prams] delete_query: {delete_query}")
- print(f"[delete_node_by_prams] params: {params}")
deleted_count = 0
try:
with self.driver.session(database=self.db_name) as session:
- # Count nodes before deletion
- count_result = session.run(count_query, **params)
- count_record = count_result.single()
- expected_count = 0
- if count_record:
- expected_count = count_record["node_count"] or 0
-
- # Delete nodes
- session.run(delete_query, **params)
- # Use the count from before deletion as the actual deleted count
- deleted_count = expected_count
-
+ # Execute delete query
+ result = session.run(delete_query, **params)
+ # Consume the result to ensure deletion completes and get the summary
+ summary = result.consume()
+ # Get the count from the result summary
+ deleted_count = summary.counters.nodes_deleted if summary.counters else 0
+
+ elapsed_time = time.time() - batch_start_time
+ logger.info(
+ f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {deleted_count} nodes"
+ )
except Exception as e:
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
raise
logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
return deleted_count
+
+ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
+ """Get user names by memory ids.
+
+ Args:
+ memory_ids: List of memory node IDs to query.
+
+ Returns:
+ dict[str, str | None]: Dictionary mapping memory_id to user_name.
+ - Key: memory_id
+ - Value: user_name if exists, None if memory_id does not exist
+ Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
+ """
+ if not memory_ids:
+ return {}
+
+ logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}")
+
+ try:
+ with self.driver.session(database=self.db_name) as session:
+ # Query to get memory_id and user_name pairs
+ query = """
+ MATCH (n:Memory)
+ WHERE n.id IN $memory_ids
+ RETURN n.id AS memory_id, n.user_name AS user_name
+ """
+ logger.info(f"[get_user_names_by_memory_ids] query: {query}")
+
+ result = session.run(query, memory_ids=memory_ids)
+ result_dict = {}
+
+ # Build result dictionary from query results
+ for record in result:
+ memory_id = record["memory_id"]
+ user_name = record["user_name"]
+ result_dict[memory_id] = user_name if user_name else None
+
+ # Set None for memory_ids that were not found
+ for mid in memory_ids:
+ if mid not in result_dict:
+ result_dict[mid] = None
+
+ logger.info(
+ f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
+ f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
+ )
+
+ return result_dict
+ except Exception as e:
+ logger.error(
+ f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
+ )
+ raise
+
+ def exist_user_name(self, user_name: str) -> dict[str, bool]:
+ """Check if user name exists in the graph.
+
+ Args:
+ user_name: User name to check.
+
+ Returns:
+ dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence.
+ """
+ logger.info(f"[exist_user_name] Querying user_name {user_name}")
+ if not user_name:
+ return {user_name: False}
+
+ try:
+ with self.driver.session(database=self.db_name) as session:
+ # Query to check if user_name exists
+ query = """
+ MATCH (n:Memory)
+ WHERE n.user_name = $user_name
+ RETURN COUNT(n) AS count
+ """
+ logger.info(f"[exist_user_name] query: {query}")
+
+ result = session.run(query, user_name=user_name)
+ count = result.single()["count"]
+ result_dict = {user_name: count > 0}
+
+ logger.info(
+ f"[exist_user_name] user_name {user_name} exists: {result_dict[user_name]}"
+ )
+ return result_dict
+ except Exception as e:
+ logger.error(
+ f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
+ )
+ raise
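
The deletion path above now trusts the driver's result summary instead of a separate pre-count query; in isolation the pattern looks like this (driver setup assumed, IDs hypothetical):

with driver.session(database="neo4j") as session:
    result = session.run(
        "MATCH (n:Memory) WHERE n.id IN $ids DETACH DELETE n",
        ids=["m-1", "m-2"],
    )
    summary = result.consume()  # ensures the write has completed
    deleted = summary.counters.nodes_deleted if summary.counters else 0
    print(f"deleted {deleted} nodes in one round trip")
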
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index c81e46804..e67f866ac 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -2502,38 +2502,121 @@ def clear(self, user_name: str | None = None) -> None:
@timed
def export_graph(
- self, include_embedding: bool = False, user_name: str | None = None
+ self,
+ include_embedding: bool = False,
+ user_name: str | None = None,
+ user_id: str | None = None,
+ page: int | None = None,
+ page_size: int | None = None,
+ filter: dict | None = None,
+ **kwargs,
) -> dict[str, Any]:
"""
Export all graph nodes and edges in a structured form.
Args:
include_embedding (bool): Whether to include the large embedding field.
user_name (str, optional): User name for filtering in non-multi-db mode
+ user_id (str, optional): User ID for filtering
+ page (int, optional): Page number (starts from 1). If None, exports all data without pagination.
+ page_size (int, optional): Number of items per page. If None, exports all data without pagination.
+ filter (dict, optional): Filter dictionary for metadata filtering. Supports "and", "or" logic and operators:
+ - "=": equality
+ - "in": value in list
+ - "contains": array contains value
+ - "gt", "lt", "gte", "lte": comparison operators
+ - "like": fuzzy matching
+ Example: {"and": [{"created_at": {"gte": "2025-01-01"}}, {"tags": {"contains": "AI"}}]}
Returns:
{
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
- "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
+ "edges": [ { "source": ..., "target": ..., "type": ... }, ... ],
+ "total_nodes": int, # Total number of nodes matching the filter criteria
+ "total_edges": int, # Total number of edges matching the filter criteria
}
"""
- user_name = user_name if user_name else self._get_config_value("user_name")
+ logger.info(
+ f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}"
+ )
+ user_id = user_id if user_id else self._get_config_value("user_id")
+
+ # Initialize total counts
+ total_nodes = 0
+ total_edges = 0
+
+ # Determine if pagination is needed
+ use_pagination = page is not None and page_size is not None
+
+ # Validate pagination parameters if pagination is enabled
+ if use_pagination:
+ if page < 1:
+ page = 1
+ if page_size < 1:
+ page_size = 10
+ offset = (page - 1) * page_size
+ else:
+ offset = None
+
conn = None
try:
conn = self._get_connection()
+ # Build WHERE conditions
+ where_conditions = []
+ if user_name:
+ where_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
+ )
+ if user_id:
+ where_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype"
+ )
+
+ # Build filter conditions using common method
+ filter_conditions = self._build_filter_conditions_sql(filter)
+ logger.info(f"[export_graph] filter_conditions: {filter_conditions}")
+ if filter_conditions:
+ where_conditions.extend(filter_conditions)
+
+ where_clause = ""
+ if where_conditions:
+ where_clause = f"WHERE {' AND '.join(where_conditions)}"
+
+ # Get total count of nodes before pagination
+ count_node_query = f"""
+ SELECT COUNT(*)
+ FROM "{self.db_name}_graph"."Memory"
+ {where_clause}
+ """
+ logger.info(f"[export_graph nodes count] Query: {count_node_query}")
+ with conn.cursor() as cursor:
+ cursor.execute(count_node_query)
+ total_nodes = cursor.fetchone()[0]
+
# Export nodes
+ # Build pagination clause if needed
+ pagination_clause = ""
+ if use_pagination:
+ pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
+
if include_embedding:
node_query = f"""
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
- WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
+ {where_clause}
+ ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,
+ id DESC
+ {pagination_clause}
"""
else:
node_query = f"""
SELECT id, properties
FROM "{self.db_name}_graph"."Memory"
- WHERE ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = '\"{user_name}\"'::agtype
+ {where_clause}
+ ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,
+ id DESC
+ {pagination_clause}
"""
-
+ logger.info(f"[export_graph nodes] Query: {node_query}")
with conn.cursor() as cursor:
cursor.execute(node_query)
node_results = cursor.fetchall()
@@ -2541,9 +2624,11 @@ def export_graph(
for row in node_results:
if include_embedding:
- properties_json, embedding_json = row
+ """row is (id, properties, embedding)"""
+ _, properties_json, embedding_json = row
else:
- properties_json = row
+ """row is (id, properties)"""
+ _, properties_json = row
embedding_json = None
# Parse properties from JSONB if it's a string
@@ -2555,20 +2640,13 @@ def export_graph(
else:
properties = properties_json if properties_json else {}
- # # Build node data
-
- """
- # node_data = {
- # "id": properties.get("id", node_id),
- # "memory": properties.get("memory", ""),
- # "metadata": properties
- # }
- """
-
- if include_embedding and embedding_json is not None:
+ # Remove embedding field if include_embedding is False
+ if not include_embedding:
+ properties.pop("embedding", None)
+ elif include_embedding and embedding_json is not None:
properties["embedding"] = embedding_json
- nodes.append(self._parse_node(json.loads(properties[1])))
+ nodes.append(self._parse_node(properties))
except Exception as e:
logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True)
@@ -2579,15 +2657,72 @@ def export_graph(
conn = None
try:
conn = self._get_connection()
+ # Build Cypher WHERE conditions for edges
+ cypher_where_conditions = []
+ if user_name:
+ cypher_where_conditions.append(f"a.user_name = '{user_name}'")
+ cypher_where_conditions.append(f"b.user_name = '{user_name}'")
+ if user_id:
+ cypher_where_conditions.append(f"a.user_id = '{user_id}'")
+ cypher_where_conditions.append(f"b.user_id = '{user_id}'")
+
+ # Build filter conditions for edges (apply to both source and target nodes)
+ filter_where_clause = self._build_filter_conditions_cypher(filter)
+ logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}")
+ if filter_where_clause:
+ # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists
+ # Remove the leading " AND " and replace n. with a. for source node and b. for target node
+ filter_clause = filter_where_clause.strip()
+ if filter_clause.startswith("AND "):
+ filter_clause = filter_clause[4:].strip()
+ # Replace n. with a. for source node and create a copy for target node
+ source_filter = filter_clause.replace("n.", "a.")
+ target_filter = filter_clause.replace("n.", "b.")
+ # Combine source and target filters with AND
+ combined_filter = f"({source_filter}) AND ({target_filter})"
+ cypher_where_conditions.append(combined_filter)
+
+ cypher_where_clause = ""
+ if cypher_where_conditions:
+ cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}"
+
+ # Get total count of edges before pagination
+ count_edge_query = f"""
+ SELECT COUNT(*)
+ FROM (
+ SELECT * FROM cypher('{self.db_name}_graph', $$
+ MATCH (a:Memory)-[r]->(b:Memory)
+ {cypher_where_clause}
+ RETURN a.id AS source, b.id AS target, type(r) as edge
+ $$) AS (source agtype, target agtype, edge agtype)
+ ) AS edges
+ """
+ logger.info(f"[export_graph edges count] Query: {count_edge_query}")
+ with conn.cursor() as cursor:
+ cursor.execute(count_edge_query)
+ total_edges = cursor.fetchone()[0]
+
# Export edges using cypher query
+ # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
+ # Build pagination clause if needed
+ edge_pagination_clause = ""
+ if use_pagination:
+ edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
+
edge_query = f"""
- SELECT * FROM cypher('{self.db_name}_graph', $$
- MATCH (a:Memory)-[r]->(b:Memory)
- WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'
- RETURN a.id AS source, b.id AS target, type(r) as edge
- $$) AS (source agtype, target agtype, edge agtype)
+ SELECT source, target, edge FROM (
+ SELECT * FROM cypher('{self.db_name}_graph', $$
+ MATCH (a:Memory)-[r]->(b:Memory)
+ {cypher_where_clause}
+ RETURN a.id AS source, b.id AS target, type(r) as edge
+ ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC,
+ COALESCE(b.created_at, '1970-01-01T00:00:00') DESC,
+ a.id DESC, b.id DESC
+ $$) AS (source agtype, target agtype, edge agtype)
+ ) AS edges
+ {edge_pagination_clause}
"""
-
+ logger.info(f"[export_graph edges] Query: {edge_query}")
with conn.cursor() as cursor:
cursor.execute(edge_query)
edge_results = cursor.fetchall()
@@ -2653,7 +2788,12 @@ def export_graph(
finally:
self._return_connection(conn)
- return {"nodes": nodes, "edges": edges}
+ return {
+ "nodes": nodes,
+ "edges": edges,
+ "total_nodes": total_nodes,
+ "total_edges": total_edges,
+ }
@timed
def count_nodes(self, scope: str, user_name: str | None = None) -> int:
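
For reference, filters matching the DSL documented in export_graph above look like the following (values are placeholders); how each operator compiles to SQL or Cypher is backend-specific:

recent_ai_filter = {
    "and": [
        {"created_at": {"gte": "2025-01-01"}},  # comparison operator
        {"tags": {"contains": "AI"}},           # array containment
    ]
}
either_user_filter = {"or": [{"user_id": {"=": "u-1"}}, {"user_id": {"=": "u-2"}}]}

result = graph_db.export_graph(filter=recent_ai_filter, page=1, page_size=20)
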
@@ -3356,7 +3496,7 @@ def add_nodes_batch(
logger.warning("[add_nodes_batch] Empty nodes list, skipping")
return
- logger.info(f"[add_nodes_batch] Adding {len(nodes)} nodes")
+ logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})")
# user_name comes from parameter; fallback to config if missing
effective_user_name = user_name if user_name else self.config.user_name
@@ -3485,92 +3625,89 @@ def add_nodes_batch(
if graph_id:
node["properties"]["graph_id"] = str(graph_id)
- # Batch insert using VALUES with multiple rows
- # Use psycopg2.extras.execute_values for efficient batch insert
- from psycopg2.extras import execute_values
-
- if embedding_column and any(node["embedding_vector"] for node in nodes_group):
- # Prepare data tuples for batch insert with embedding
- data_tuples = []
- for node in nodes_group:
- # Each tuple: (id, properties_json, embedding_json)
- data_tuples.append(
- (
- node["id"],
- json.dumps(node["properties"]),
- json.dumps(node["embedding_vector"])
- if node["embedding_vector"]
- else None,
+ # Use PREPARE/EXECUTE for efficient batch insert
+ # Generate unique prepare statement name to avoid conflicts
+ prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}"
+
+ try:
+ if embedding_column and any(
+ node["embedding_vector"] for node in nodes_group
+ ):
+ # PREPARE statement for insert with embedding
+ prepare_query = f"""
+ PREPARE {prepare_name} AS
+ INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
+ VALUES (
+ ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring),
+ $2::text::agtype,
+ $3::vector
)
+ """
+ logger.info(
+ f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}"
+ )
+ logger.info(
+ f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}"
)
- # Build the INSERT query template
- insert_query = f"""
- INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column})
- VALUES %s
- """
+ cursor.execute(prepare_query)
- # Build the VALUES template for execute_values
- # Each row: (graph_id_function, agtype, vector)
- # Note: properties column is agtype, not jsonb
- template = f"""
- (
- ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
- %s::text::agtype,
- %s::vector
- )
- """
- # Execute batch insert
- execute_values(
- cursor,
- insert_query,
- data_tuples,
- template=template,
- page_size=100, # Insert in batches of 100
- )
- else:
- # Prepare data tuples for batch insert without embedding
- data_tuples = []
- for node in nodes_group:
- # Each tuple: (id, properties_json)
- data_tuples.append(
- (
- node["id"],
- json.dumps(node["properties"]),
+ # Execute prepared statement for each node
+ for node in nodes_group:
+ properties_json = json.dumps(node["properties"])
+ embedding_json = (
+ json.dumps(node["embedding_vector"])
+ if node["embedding_vector"]
+ else None
+ )
+
+ cursor.execute(
+ f"EXECUTE {prepare_name}(%s, %s, %s)",
+ (node["id"], properties_json, embedding_json),
)
+ else:
+ # PREPARE statement for insert without embedding
+ prepare_query = f"""
+ PREPARE {prepare_name} AS
+ INSERT INTO {self.db_name}_graph."Memory"(id, properties)
+ VALUES (
+ ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, $1::text::cstring),
+ $2::text::agtype
+ )
+ """
+ logger.info(
+ f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}"
+ )
+ logger.info(
+ f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}"
)
+ cursor.execute(prepare_query)
- # Build the INSERT query template
- insert_query = f"""
- INSERT INTO {self.db_name}_graph."Memory"(id, properties)
- VALUES %s
- """
+ # Execute prepared statement for each node
+ for node in nodes_group:
+ properties_json = json.dumps(node["properties"])
- # Build the VALUES template for execute_values
- # Note: properties column is agtype, not jsonb
- template = f"""
- (
- ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring),
- %s::text::agtype
+ cursor.execute(
+ f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json)
+ )
+ finally:
+ # DEALLOCATE prepared statement (always execute, even on error)
+ try:
+ cursor.execute(f"DEALLOCATE {prepare_name}")
+ logger.info(
+ f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}"
+ )
+ except Exception as dealloc_error:
+ logger.warning(
+ f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}"
)
- """
- logger.info(f"[add_nodes_batch] Inserting insert_query:{insert_query}")
- logger.info(f"[add_nodes_batch] Inserting data_tuples:{data_tuples}")
- # Execute batch insert
- execute_values(
- cursor,
- insert_query,
- data_tuples,
- template=template,
- page_size=100, # Insert in batches of 100
- )
logger.info(
f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}"
)
elapsed_time = time.time() - batch_start_time
logger.info(
- f"[add_nodes_batch] execute_values completed successfully in {elapsed_time:.2f}s"
+ f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s"
)
except Exception as e:
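
Outside of MemOS, the PREPARE/EXECUTE pattern adopted above reduces to this minimal psycopg2 sketch (the DSN, table, and columns are placeholders, not the actual AGE schema):

import json
import time

import psycopg2

conn = psycopg2.connect("dbname=test")  # placeholder DSN
rows = [{"id": "n1", "properties": {"memory": "hello"}}]
name = f"insert_mem_{int(time.time() * 1_000_000)}"  # unique name avoids conflicts

with conn.cursor() as cur:
    cur.execute(
        f"PREPARE {name} AS "
        "INSERT INTO memory_demo(id, properties) VALUES ($1, $2::jsonb)"
    )
    try:
        for row in rows:
            # psycopg2 interpolates %s client-side before the server sees EXECUTE
            cur.execute(f"EXECUTE {name}(%s, %s)", (row["id"], json.dumps(row["properties"])))
    finally:
        cur.execute(f"DEALLOCATE {name}")  # always release, even on error
conn.commit()
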
@@ -4196,15 +4333,29 @@ def build_cypher_filter_condition(condition_dict: dict) -> str:
cypher_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="}
cypher_op = cypher_op_map[op]
+ # Datetime fields follow the *_at naming convention (created_at, updated_at, ...)
+ is_datetime = key.endswith("_at")
+
# Check if key starts with "info." prefix (for nested fields like info.A, info.B)
if key.startswith("info."):
# Nested field access: n.info.field_name
info_field = key[5:] # Remove "info." prefix
+ is_info_datetime = info_field in (
+ "created_at",
+ "updated_at",
+ ) or info_field.endswith("_at")
if isinstance(op_value, str):
escaped_value = escape_cypher_string(op_value)
- condition_parts.append(
- f"n.info.{info_field} {cypher_op} '{escaped_value}'"
- )
+ if is_info_datetime:
+ condition_parts.append(
+ f"n.info.{info_field}::timestamp {cypher_op} '{escaped_value}'::timestamp"
+ )
+ else:
+ condition_parts.append(
+ f"n.info.{info_field} {cypher_op} '{escaped_value}'"
+ )
else:
condition_parts.append(
f"n.info.{info_field} {cypher_op} {op_value}"
@@ -4213,9 +4364,14 @@ def build_cypher_filter_condition(condition_dict: dict) -> str:
# Direct property access (e.g., "created_at" is directly in n, not in n.info)
if isinstance(op_value, str):
escaped_value = escape_cypher_string(op_value)
- condition_parts.append(
- f"n.{key} {cypher_op} '{escaped_value}'"
- )
+ if is_datetime:
+ condition_parts.append(
+ f"n.{key}::timestamp {cypher_op} '{escaped_value}'::timestamp"
+ )
+ else:
+ condition_parts.append(
+ f"n.{key} {cypher_op} '{escaped_value}'"
+ )
else:
condition_parts.append(f"n.{key} {cypher_op} {op_value}")
elif op == "=":
@@ -4309,70 +4465,133 @@ def build_cypher_filter_condition(condition_dict: dict) -> str:
elif op == "in":
# Handle in operator (for checking if field value is in a list)
# Supports array format: {"field": {"in": ["value1", "value2"]}}
- # Generates: n.field IN ['value1', 'value2'] or (n.field = 'value1' OR n.field = 'value2')
+ # For array fields (like file_ids, tags, sources), uses CONTAINS logic
+ # For scalar fields, uses equality or IN clause
if not isinstance(op_value, list):
raise ValueError(
f"in operator only supports array format. "
f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}"
)
+ # Check if key is an array field
+ is_array_field = key in ("file_ids", "tags", "sources")
+
# Check if key starts with "info." prefix
if key.startswith("info."):
info_field = key[5:] # Remove "info." prefix
- # Build OR conditions for nested properties (Apache AGE compatibility)
+ # Check if info field is an array field
+ is_info_array = info_field in ("tags", "sources", "file_ids")
+
if len(op_value) == 0:
# Empty list means no match
condition_parts.append("false")
elif len(op_value) == 1:
- # Single value, use equality
+ # Single value
item = op_value[0]
- if isinstance(item, str):
- escaped_value = escape_cypher_string(item)
- condition_parts.append(
- f"n.info.{info_field} = '{escaped_value}'"
- )
+ if is_info_array:
+ # For array fields, use CONTAINS (value IN array_field)
+ if isinstance(item, str):
+ escaped_value = escape_cypher_string(item)
+ condition_parts.append(
+ f"'{escaped_value}' IN n.info.{info_field}"
+ )
+ else:
+ condition_parts.append(
+ f"{item} IN n.info.{info_field}"
+ )
else:
- condition_parts.append(f"n.info.{info_field} = {item}")
- else:
- # Multiple values, use OR conditions instead of IN (Apache AGE compatibility)
- or_conditions = []
- for item in op_value:
+ # For scalar fields, use equality
if isinstance(item, str):
escaped_value = escape_cypher_string(item)
- or_conditions.append(
+ condition_parts.append(
f"n.info.{info_field} = '{escaped_value}'"
)
else:
- or_conditions.append(
+ condition_parts.append(
f"n.info.{info_field} = {item}"
)
+ else:
+ # Multiple values, use OR conditions
+ or_conditions = []
+ for item in op_value:
+ if is_info_array:
+ # For array fields, use CONTAINS (value IN array_field)
+ if isinstance(item, str):
+ escaped_value = escape_cypher_string(item)
+ or_conditions.append(
+ f"'{escaped_value}' IN n.info.{info_field}"
+ )
+ else:
+ or_conditions.append(
+ f"{item} IN n.info.{info_field}"
+ )
+ else:
+ # For scalar fields, use equality
+ if isinstance(item, str):
+ escaped_value = escape_cypher_string(item)
+ or_conditions.append(
+ f"n.info.{info_field} = '{escaped_value}'"
+ )
+ else:
+ or_conditions.append(
+ f"n.info.{info_field} = {item}"
+ )
if or_conditions:
condition_parts.append(
f"({' OR '.join(or_conditions)})"
)
else:
# Direct property access
- # Build array for IN clause or OR conditions
if len(op_value) == 0:
# Empty list means no match
condition_parts.append("false")
elif len(op_value) == 1:
- # Single value, use equality
+ # Single value
item = op_value[0]
- if isinstance(item, str):
- escaped_value = escape_cypher_string(item)
- condition_parts.append(f"n.{key} = '{escaped_value}'")
+ if is_array_field:
+ # For array fields, use CONTAINS (value IN array_field)
+ if isinstance(item, str):
+ escaped_value = escape_cypher_string(item)
+ condition_parts.append(
+ f"'{escaped_value}' IN n.{key}"
+ )
+ else:
+ condition_parts.append(f"{item} IN n.{key}")
else:
- condition_parts.append(f"n.{key} = {item}")
+ # For scalar fields, use equality
+ if isinstance(item, str):
+ escaped_value = escape_cypher_string(item)
+ condition_parts.append(
+ f"n.{key} = '{escaped_value}'"
+ )
+ else:
+ condition_parts.append(f"n.{key} = {item}")
else:
- # Multiple values, use IN clause
- escaped_items = [
- f"'{escape_cypher_string(str(item))}'"
- if isinstance(item, str)
- else str(item)
- for item in op_value
- ]
- array_str = "[" + ", ".join(escaped_items) + "]"
- condition_parts.append(f"n.{key} IN {array_str}")
+ # Multiple values
+ if is_array_field:
+ # For array fields, use OR conditions with CONTAINS
+ or_conditions = []
+ for item in op_value:
+ if isinstance(item, str):
+ escaped_value = escape_cypher_string(item)
+ or_conditions.append(
+ f"'{escaped_value}' IN n.{key}"
+ )
+ else:
+ or_conditions.append(f"{item} IN n.{key}")
+ if or_conditions:
+ condition_parts.append(
+ f"({' OR '.join(or_conditions)})"
+ )
+ else:
+ # For scalar fields, use IN clause
+ escaped_items = [
+ f"'{escape_cypher_string(str(item))}'"
+ if isinstance(item, str)
+ else str(item)
+ for item in op_value
+ ]
+ array_str = "[" + ", ".join(escaped_items) + "]"
+ condition_parts.append(f"n.{key} IN {array_str}")
elif op == "like":
# Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%')
# Check if key starts with "info." prefix
@@ -4476,29 +4695,52 @@ def build_filter_condition(condition_dict: dict) -> str:
sql_op_map = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="}
sql_op = sql_op_map[op]
+ # Check if key is a datetime field
+ is_datetime = key in ("created_at", "updated_at") or key.endswith(
+ "_at"
+ )
+
# Check if key starts with "info." prefix (for nested fields like info.A, info.B)
if key.startswith("info."):
# Nested field access: properties->'info'->'field_name'
info_field = key[5:] # Remove "info." prefix
+ is_info_datetime = info_field in (
+ "created_at",
+ "updated_at",
+ ) or info_field.endswith("_at")
if isinstance(op_value, str):
escaped_value = escape_sql_string(op_value)
- condition_parts.append(
- f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} '\"{escaped_value}\"'::agtype"
- )
+ if is_info_datetime:
+ condition_parts.append(
+ f"TRIM(BOTH '\"' FROM ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype)::text)::timestamp {sql_op} '{escaped_value}'::timestamp"
+ )
+ else:
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} '\"{escaped_value}\"'::agtype"
+ )
else:
+ # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
+ value_json = json.dumps(op_value)
condition_parts.append(
- f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} {op_value}::agtype"
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) {sql_op} ag_catalog.agtype_in('{value_json}')"
)
else:
# Direct property access (e.g., "created_at" is directly in properties, not in properties.info)
if isinstance(op_value, str):
escaped_value = escape_sql_string(op_value)
- condition_parts.append(
- f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} '\"{escaped_value}\"'::agtype"
- )
+ if is_datetime:
+ condition_parts.append(
+ f"TRIM(BOTH '\"' FROM ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype)::text)::timestamp {sql_op} '{escaped_value}'::timestamp"
+ )
+ else:
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} '\"{escaped_value}\"'::agtype"
+ )
else:
+ # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
+ value_json = json.dumps(op_value)
condition_parts.append(
- f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} {op_value}::agtype"
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) {sql_op} ag_catalog.agtype_in('{value_json}')"
)
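+ # e.g. a hypothetical filter {"created_at": {"lt": "2024-06-01"}} yields
+ #   TRIM(BOTH '"' FROM agtype_access_operator(properties, '"created_at"'::agtype)::text)::timestamp < '2024-06-01'::timestamp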
elif op == "=":
# Handle equality operator
@@ -4539,8 +4781,10 @@ def build_filter_condition(condition_dict: dict) -> str:
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '[{op_value}]'::agtype"
)
else:
+ # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
+ value_json = json.dumps(op_value)
condition_parts.append(
- f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {op_value}::agtype"
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = ag_catalog.agtype_in('{value_json}')"
)
else:
# Direct property access
@@ -4567,8 +4811,10 @@ def build_filter_condition(condition_dict: dict) -> str:
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '{json_array}'::agtype"
)
else:
+ # For non-string list values, convert to JSON string and then to agtype
+ value_json = json.dumps(op_value)
condition_parts.append(
- f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {op_value}::agtype"
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')"
)
else:
if key in ("tags", "sources"):
@@ -4576,32 +4822,149 @@ def build_filter_condition(condition_dict: dict) -> str:
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '[{op_value}]'::agtype"
)
else:
+ # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
+ value_json = json.dumps(op_value)
condition_parts.append(
- f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {op_value}::agtype"
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')"
)
elif op == "contains":
- # Handle contains operator (for string fields only)
- # Check if agtype contains value (using @> operator)
- if not isinstance(op_value, str):
- raise ValueError(
- f"contains operator only supports string format. "
- f"Use {{'{key}': {{'contains': '{op_value}'}}}} instead of {{'{key}': {{'contains': {op_value}}}}}"
- )
+ # Handle contains operator via the agtype @> (contains) operator.
+ # The probe value is wrapped as a one-element array, so this tests whether an
+ # array field contains the value as an element (it is not a substring match).
# Check if key starts with "info." prefix
if key.startswith("info."):
info_field = key[5:] # Remove "info." prefix
- # String contains: use @> operator for agtype contains
- escaped_value = escape_sql_string(op_value)
+ escaped_value = escape_sql_string(str(op_value))
+ # Wrap the probe as '["value"]'::agtype so @> tests array membership
condition_parts.append(
- f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '\"{escaped_value}\"'::agtype"
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype"
)
else:
# Direct property access
- # String contains: use @> operator for agtype contains
- escaped_value = escape_sql_string(op_value)
+ escaped_value = escape_sql_string(str(op_value))
+ # For array fields, use @> with array format
condition_parts.append(
- f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '\"{escaped_value}\"'::agtype"
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype"
)
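+ # e.g. (hypothetical) {"tags": {"contains": "work"}} emits
+ #   ... @> '["work"]'::agtype, true when "work" is an element of the tags array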
+ elif op == "in":
+ # Handle in operator (for checking if field value is in a list)
+ # Supports array format: {"field": {"in": ["value1", "value2"]}}
+ # For array fields (like file_ids, tags, sources), uses @> operator (contains)
+ # For scalar fields, uses = operator (equality)
+ if not isinstance(op_value, list):
+ raise ValueError(
+ f"in operator only supports array format. "
+ f"Use {{'{key}': {{'in': ['{op_value}']}}}} instead of {{'{key}': {{'in': '{op_value}'}}}}"
+ )
+ # Check if key is an array field
+ is_array_field = key in ("file_ids", "tags", "sources")
+
+ # Check if key starts with "info." prefix
+ if key.startswith("info."):
+ info_field = key[5:] # Remove "info." prefix
+ # Check if info field is an array field
+ is_info_array = info_field in ("tags", "sources", "file_ids")
+
+ if len(op_value) == 0:
+ # Empty list means no match
+ condition_parts.append("false")
+ elif len(op_value) == 1:
+ # Single value
+ item = op_value[0]
+ if is_info_array:
+ # For array fields, use @> operator (contains)
+ escaped_value = escape_sql_string(str(item))
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype"
+ )
+ else:
+ # For scalar fields, use equality
+ if isinstance(item, str):
+ escaped_value = escape_sql_string(item)
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype"
+ )
+ else:
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype"
+ )
+ else:
+ # Multiple values, use OR conditions
+ or_conditions = []
+ for item in op_value:
+ if is_info_array:
+ # For array fields, use @> operator (contains) to check if array contains the value
+ escaped_value = escape_sql_string(str(item))
+ or_conditions.append(
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '[\"{escaped_value}\"]'::agtype"
+ )
+ else:
+ # For scalar fields, use equality
+ if isinstance(item, str):
+ escaped_value = escape_sql_string(item)
+ or_conditions.append(
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype"
+ )
+ else:
+ or_conditions.append(
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = {item}::agtype"
+ )
+ if or_conditions:
+ condition_parts.append(
+ f"({' OR '.join(or_conditions)})"
+ )
+ else:
+ # Direct property access
+ if len(op_value) == 0:
+ # Empty list means no match
+ condition_parts.append("false")
+ elif len(op_value) == 1:
+ # Single value
+ item = op_value[0]
+ if is_array_field:
+ # For array fields, use @> operator (contains)
+ escaped_value = escape_sql_string(str(item))
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype"
+ )
+ else:
+ # For scalar fields, use equality
+ if isinstance(item, str):
+ escaped_value = escape_sql_string(item)
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype"
+ )
+ else:
+ condition_parts.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype"
+ )
+ else:
+ # Multiple values, use OR conditions
+ or_conditions = []
+ for item in op_value:
+ if is_array_field:
+ # For array fields, use @> operator (contains) to check if array contains the value
+ escaped_value = escape_sql_string(str(item))
+ or_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) @> '[\"{escaped_value}\"]'::agtype"
+ )
+ else:
+ # For scalar fields, use equality
+ if isinstance(item, str):
+ escaped_value = escape_sql_string(item)
+ or_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype"
+ )
+ else:
+ or_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {item}::agtype"
+ )
+ if or_conditions:
+ condition_parts.append(
+ f"({' OR '.join(or_conditions)})"
+ )
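+ # e.g. (hypothetical) {"file_ids": {"in": ["f1", "f2"]}} ORs two @> containment
+ # checks, while {"user_name": {"in": ["alice"]}} reduces to one agtype equality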
elif op == "like":
# Handle like operator (for fuzzy matching, similar to SQL LIKE '%value%')
# Check if key starts with "info." prefix
@@ -4647,8 +5010,10 @@ def build_filter_condition(condition_dict: dict) -> str:
f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{escaped_value}\"'::agtype"
)
else:
+ # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
+ value_json = json.dumps(value)
condition_parts.append(
- f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = '\"{value}\"'::agtype"
+ f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) = ag_catalog.agtype_in('{value_json}')"
)
else:
# Direct property access (simple equality)
@@ -4658,8 +5023,10 @@ def build_filter_condition(condition_dict: dict) -> str:
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{escaped_value}\"'::agtype"
)
else:
+ # For non-string values (numbers, booleans, etc.), convert to JSON string and then to agtype
+ value_json = json.dumps(value)
condition_parts.append(
- f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
+ f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = ag_catalog.agtype_in('{value_json}')"
)
return " AND ".join(condition_parts)
@@ -4776,7 +5143,8 @@ def delete_node_by_prams(
If not provided, no user_name filter will be applied.
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
- filter (dict, optional): Filter dictionary to query matching nodes for deletion.
+ filter (dict, optional): Filter dictionary for metadata filtering.
+ Filter conditions are used directly in the DELETE WHERE clause; matching nodes are not pre-queried.
Returns:
int: Number of nodes deleted.
@@ -4796,122 +5164,80 @@ def delete_node_by_prams(
f"agtype_access_operator(VARIADIC ARRAY[properties, '\"user_name\"'::agtype]) = '\"{cube_id}\"'::agtype"
)
- # Build WHERE conditions separately for memory_ids and file_ids
- where_conditions = []
-
- # Handle memory_ids: query properties.id
- if memory_ids and len(memory_ids) > 0:
- memory_id_conditions = []
- for node_id in memory_ids:
- memory_id_conditions.append(
- f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
- )
- if memory_id_conditions:
- where_conditions.append(f"({' OR '.join(memory_id_conditions)})")
-
- # Check if any file_id is in the file_ids array field (OR relationship)
- if file_ids and len(file_ids) > 0:
- file_id_conditions = []
- for file_id in file_ids:
- # Format: agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '"file_ids"'::agtype]), '"file_id"'::agtype)
- file_id_conditions.append(
- f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
- )
- if file_id_conditions:
- # Use OR to match any file_id in the array
- where_conditions.append(f"({' OR '.join(file_id_conditions)})")
-
- # Query nodes by filter if provided
- filter_ids = set()
+ # Build filter conditions using common method (no query, direct use in WHERE clause)
+ filter_conditions = []
if filter:
- # Parse filter to validate and transform field names (e.g., add "info." prefix if needed)
- parsed_filter = self.parse_filter(filter)
- if parsed_filter:
- # Use get_by_metadata with empty filters list and parsed filter
- filter_ids = set(
- self.get_by_metadata(
- filters=[],
- user_name=None,
- filter=parsed_filter,
- knowledgebase_ids=writable_cube_ids,
- )
- )
- else:
- logger.warning(
- "[delete_node_by_prams] Filter parsed to None, skipping filter query"
- )
-
- # If filter returned IDs, add condition for them
- if filter_ids:
- filter_id_conditions = []
- for node_id in filter_ids:
- filter_id_conditions.append(
- f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
- )
- if filter_id_conditions:
- where_conditions.append(f"({' OR '.join(filter_id_conditions)})")
+ filter_conditions = self._build_filter_conditions_sql(filter)
+ logger.info(f"[delete_node_by_prams] filter_conditions: {filter_conditions}")
- # If no conditions (except user_name), return 0
- if not where_conditions:
+ # If no conditions to delete, return 0
+ if not memory_ids and not file_ids and not filter_conditions:
logger.warning(
"[delete_node_by_prams] No nodes to delete (no memory_ids, file_ids, or filter provided)"
)
return 0
- # Build WHERE clause
- # First, combine memory_ids, file_ids, and filter conditions with OR (any condition can match)
- data_conditions = " OR ".join([f"({cond})" for cond in where_conditions])
-
- # Build final WHERE clause
- # If user_name_conditions exist, combine with data_conditions using AND
- # Otherwise, use only data_conditions
- if user_name_conditions:
- user_name_where = " OR ".join(user_name_conditions)
- where_clause = f"({user_name_where}) AND ({data_conditions})"
- else:
- where_clause = f"({data_conditions})"
-
- # Use SQL DELETE query for better performance
- # First count matching nodes to get accurate count
- count_query = f"""
- SELECT COUNT(*)
- FROM "{self.db_name}_graph"."Memory"
- WHERE {where_clause}
- """
- logger.info(f"[delete_node_by_prams] count_query: {count_query}")
-
- # Then delete nodes
- delete_query = f"""
- DELETE FROM "{self.db_name}_graph"."Memory"
- WHERE {where_clause}
- """
-
- logger.info(
- f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
- )
- logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
-
conn = None
- deleted_count = 0
+ total_deleted_count = 0
try:
conn = self._get_connection()
with conn.cursor() as cursor:
- # Count nodes before deletion
- cursor.execute(count_query)
- count_result = cursor.fetchone()
- expected_count = count_result[0] if count_result else 0
+ # Build WHERE conditions list
+ where_conditions = []
+
+ # Add memory_ids conditions
+ if memory_ids:
+ logger.info(f"[delete_node_by_prams] Processing {len(memory_ids)} memory_ids")
+ id_conditions = []
+ for node_id in memory_ids:
+ id_conditions.append(
+ f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{node_id}\"'::agtype"
+ )
+ where_conditions.append(f"({' OR '.join(id_conditions)})")
+
+ # Add file_ids conditions
+ if file_ids:
+ logger.info(f"[delete_node_by_prams] Processing {len(file_ids)} file_ids")
+ file_id_conditions = []
+ for file_id in file_ids:
+ file_id_conditions.append(
+ f"agtype_in_operator(agtype_access_operator(VARIADIC ARRAY[properties, '\"file_ids\"'::agtype]), '\"{file_id}\"'::agtype)"
+ )
+ where_conditions.append(f"({' OR '.join(file_id_conditions)})")
- logger.info(
- f"[delete_node_by_prams] Found {expected_count} nodes matching the criteria"
- )
+ # Add filter conditions
+ if filter_conditions:
+ logger.info("[delete_node_by_prams] Processing filter conditions")
+ where_conditions.extend(filter_conditions)
+
+ # Add user_name filter if provided
+ if user_name_conditions:
+ user_name_where = " OR ".join(user_name_conditions)
+ where_conditions.append(f"({user_name_where})")
+
+ # Build final WHERE clause
+ if not where_conditions:
+ logger.warning("[delete_node_by_prams] No WHERE conditions to delete")
+ return 0
+
+ where_clause = " AND ".join(where_conditions)
+
+ # Delete directly without counting
+ delete_query = f"""
+ DELETE FROM "{self.db_name}_graph"."Memory"
+ WHERE {where_clause}
+ """
+ logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
- # Delete nodes
cursor.execute(delete_query)
- # Use rowcount to get actual deleted count
deleted_count = cursor.rowcount
+ total_deleted_count = deleted_count
+
+ logger.info(f"[delete_node_by_prams] Deleted {deleted_count} nodes")
+
elapsed_time = time.time() - batch_start_time
logger.info(
- f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, deleted {deleted_count} nodes"
+ f"[delete_node_by_prams] Deletion completed successfully in {elapsed_time:.2f}s, total deleted {total_deleted_count} nodes"
)
except Exception as e:
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
@@ -4919,89 +5245,108 @@ def delete_node_by_prams(
finally:
self._return_connection(conn)
- logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
- return deleted_count
+ logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes")
+ return total_deleted_count
@timed
- def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[str]]:
+ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, str | None]:
"""Get user names by memory ids.
Args:
memory_ids: List of memory node IDs to query.
Returns:
- dict[str, list[str]]: Dictionary with one key:
- - 'no_exist_memory_ids': List of memory_ids that do not exist (if any are missing)
- - 'exist_user_names': List of distinct user names (if all memory_ids exist)
+ dict[str, str | None]: Dictionary mapping memory_id to user_name.
+ - Key: memory_id
+ - Value: user_name if exists, None if memory_id does not exist
+ Example: {"4918d700-6f01-4f4c-a076-75cc7b0e1a7c": "zhangsan", "2222222": None}
"""
+ logger.info(f"[get_user_names_by_memory_ids] Querying memory_ids {memory_ids}")
if not memory_ids:
- return {"exist_user_names": []}
+ return {}
+
+ # Validate and normalize memory_ids
+ # Ensure all items are strings
+ normalized_memory_ids = []
+ for mid in memory_ids:
+ if not isinstance(mid, str):
+ mid = str(mid)
+ # Remove any whitespace
+ mid = mid.strip()
+ if mid:
+ normalized_memory_ids.append(mid)
+
+ if not normalized_memory_ids:
+ return {}
+
+ # Escape special characters for JSON string format in agtype
+ def escape_memory_id(mid: str) -> str:
+ """Escape special characters in memory_id for JSON string format."""
+ # Escape backslashes first, then double quotes
+ mid_str = mid.replace("\\", "\\\\")
+ mid_str = mid_str.replace('"', '\\"')
+ return mid_str
# Build OR conditions for each memory_id
id_conditions = []
- for mid in memory_ids:
+ for mid in normalized_memory_ids:
+ # Escape special characters
+ escaped_mid = escape_memory_id(mid)
id_conditions.append(
- f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{mid}\"'::agtype"
+ f"ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype) = '\"{escaped_mid}\"'::agtype"
)
where_clause = f"({' OR '.join(id_conditions)})"
- # Query to check which memory_ids exist
- check_query = f"""
- SELECT ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text
+ # Query to get memory_id and user_name pairs
+ query = f"""
+ SELECT
+ ag_catalog.agtype_access_operator(properties, '\"id\"'::agtype)::text AS memory_id,
+ ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text AS user_name
FROM "{self.db_name}_graph"."Memory"
WHERE {where_clause}
"""
- logger.info(f"[get_user_names_by_memory_ids] check_query: {check_query}")
+ logger.info(f"[get_user_names_by_memory_ids] query: {query}")
conn = None
+ result_dict = {}
try:
conn = self._get_connection()
with conn.cursor() as cursor:
- # Check which memory_ids exist
- cursor.execute(check_query)
- check_results = cursor.fetchall()
- existing_ids = set()
- for row in check_results:
- node_id = row[0]
- # Remove quotes if present
- if isinstance(node_id, str):
- node_id = node_id.strip('"').strip("'")
- existing_ids.add(node_id)
+ cursor.execute(query)
+ results = cursor.fetchall()
- # Check if any memory_ids are missing
- no_exist_list = [mid for mid in memory_ids if mid not in existing_ids]
+ # Build result dictionary from query results
+ for row in results:
+ memory_id_raw = row[0]
+ user_name_raw = row[1]
- # If any memory_ids are missing, return no_exist_memory_ids
- if no_exist_list:
- logger.info(
- f"[get_user_names_by_memory_ids] Found {len(no_exist_list)} non-existing memory_ids: {no_exist_list}"
- )
- return {"no_exist_memory_ids": no_exist_list}
+ # Remove quotes if present
+ if isinstance(memory_id_raw, str):
+ memory_id = memory_id_raw.strip('"').strip("'")
+ else:
+ memory_id = str(memory_id_raw).strip('"').strip("'")
- # All memory_ids exist, query user_names
- user_names_query = f"""
- SELECT DISTINCT ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype)::text
- FROM "{self.db_name}_graph"."Memory"
- WHERE {where_clause}
- """
- logger.info(f"[get_user_names_by_memory_ids] user_names_query: {user_names_query}")
+ if isinstance(user_name_raw, str):
+ user_name = user_name_raw.strip('"').strip("'")
+ else:
+ user_name = (
+ str(user_name_raw).strip('"').strip("'") if user_name_raw else None
+ )
- cursor.execute(user_names_query)
- results = cursor.fetchall()
- user_names = []
- for row in results:
- user_name = row[0]
- # Remove quotes if present
- if isinstance(user_name, str):
- user_name = user_name.strip('"').strip("'")
- user_names.append(user_name)
+ result_dict[memory_id] = user_name if user_name else None
+
+ # Set None for memory_ids that were not found
+ for mid in normalized_memory_ids:
+ if mid not in result_dict:
+ result_dict[mid] = None
logger.info(
- f"[get_user_names_by_memory_ids] All memory_ids exist, found {len(user_names)} distinct user_names"
+ f"[get_user_names_by_memory_ids] Found {len([v for v in result_dict.values() if v is not None])} memory_ids with user_names, "
+ f"{len([v for v in result_dict.values() if v is None])} memory_ids without user_names"
)
- return {"exist_user_names": user_names}
+ return result_dict
except Exception as e:
logger.error(
f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True
@@ -5009,3 +5354,52 @@ def get_user_names_by_memory_ids(self, memory_ids: list[str]) -> dict[str, list[
raise
finally:
self._return_connection(conn)
+
+ def exist_user_name(self, user_name: str) -> dict[str, bool]:
+ """Check if user name exists in the graph.
+
+ Args:
+ user_name: User name to check.
+
+ Returns:
+ dict[str, bool]: Dictionary with user_name as key and bool as value indicating existence.
+ """
+ logger.info(f"[exist_user_name] Querying user_name {user_name}")
+ if not user_name:
+ return {user_name: False}
+
+ # Escape special characters for JSON string format in agtype
+ def escape_user_name(un: str) -> str:
+ """Escape special characters in user_name for JSON string format."""
+ # Escape backslashes first, then double quotes
+ un_str = un.replace("\\", "\\\\")
+ un_str = un_str.replace('"', '\\"')
+ return un_str
+
+ # Escape special characters
+ escaped_un = escape_user_name(user_name)
+
+ # Query to check if user_name exists
+ query = f"""
+ SELECT COUNT(*)
+ FROM "{self.db_name}_graph"."Memory"
+ WHERE ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{escaped_un}\"'::agtype
+ """
+ logger.info(f"[exist_user_name] query: {query}")
+ result_dict = {}
+ conn = None
+ try:
+ conn = self._get_connection()
+ with conn.cursor() as cursor:
+ cursor.execute(query)
+ count = cursor.fetchone()[0]
+ result = count > 0
+ result_dict[user_name] = result
+ return result_dict
+ except Exception as e:
+ logger.error(
+ f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True
+ )
+ raise
+ finally:
+ self._return_connection(conn)
diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index d46db7c9e..b5fc4ba13 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -2,13 +2,7 @@
from typing import Any
from transformers import (
- AutoModelForCausalLM,
- AutoTokenizer,
DynamicCache,
- LogitsProcessorList,
- TemperatureLogitsWarper,
- TopKLogitsWarper,
- TopPLogitsWarper,
)
from memos.configs.llm import HFLLMConfig
@@ -30,6 +24,17 @@ def __init__(self, config: HFLLMConfig):
"""
Initialize the HFLLM model and tokenizer, and set up logits processors for sampling.
"""
+ import torch
+
+ from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ LogitsProcessorList,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ )
+
self.config = config
# Default model if not specified
@@ -37,9 +42,14 @@ def __init__(self, config: HFLLMConfig):
self.config.model_name_or_path = "Qwen/Qwen3-1.7B"
# Initialize hf model
- self.model = AutoModelForCausalLM.from_pretrained(
- self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
- )
+ if torch.backends.mps.is_available():
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.config.model_name_or_path, torch_dtype="auto"
+ ).to("mps")
+ else:
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
+ )
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.model_name_or_path, use_fast=True
)
@@ -355,6 +365,7 @@ def build_kv_cache(self, messages) -> DynamicCache:
DynamicCache: The constructed KV cache object.
"""
import torch
+ import transformers
# Accept multiple input types and convert to standard chat messages
if isinstance(messages, str):
@@ -391,7 +402,7 @@ def build_kv_cache(self, messages) -> DynamicCache:
# Convert from legacy tuple format to DynamicCache if needed
if isinstance(kv, tuple):
- kv = DynamicCache.from_legacy_cache(kv)
+ kv = transformers.DynamicCache.from_legacy_cache(kv)
# Handle compatibility between old and new transformers versions
# In newer versions, DynamicCache uses 'layers' attribute
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index 563b8723e..f49f1d7d1 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -1,4 +1,5 @@
import json
+import time
from collections.abc import Generator
@@ -46,9 +47,16 @@ def generate(self, messages: MessageList, **kwargs) -> str:
"extra_body": kwargs.get("extra_body", self.config.extra_body),
"tools": kwargs.get("tools", NOT_GIVEN),
}
+ start_time = time.perf_counter()
logger.info(f"OpenAI LLM Request body: {request_body}")
+
response = self.client.chat.completions.create(**request_body)
- logger.info(f"Response from OpenAI: {response.model_dump_json()}")
+
+ cost_time = time.perf_counter() - start_time
+ logger.info(
+ f"Request body: {request_body}, Response from OpenAI: {response.model_dump_json()}, Cost time: {cost_time}"
+ )
+
tool_calls = getattr(response.choices[0].message, "tool_calls", None)
if isinstance(tool_calls, list) and len(tool_calls) > 0:
return self.tool_call_parser(tool_calls)
@@ -59,8 +67,8 @@ def generate(self, messages: MessageList, **kwargs) -> str:
if self.config.remove_think_prefix:
return remove_thinking_tags(response_content)
if reasoning_content:
- return reasoning_content + response_content
- return response_content
+ return reasoning_content + (response_content or "")
+ return response_content or ""
@timed_with_status(
log_prefix="OpenAI LLM Stream",
@@ -151,7 +159,7 @@ def generate(self, messages: MessageList, **kwargs) -> str:
if self.config.remove_think_prefix:
return remove_thinking_tags(response_content)
else:
- return response_content
+ return response_content or ""
def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
"""Stream response from Azure OpenAI LLM with optional reasoning support."""
diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py
index 0b3fc3846..15d7c336a 100644
--- a/src/memos/mem_feedback/feedback.py
+++ b/src/memos/mem_feedback/feedback.py
@@ -2,9 +2,10 @@
import difflib
import json
import re
+import uuid
from datetime import datetime
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -17,6 +18,8 @@
from memos.log import get_logger
from memos.mem_feedback.base import BaseMemFeedback
from memos.mem_feedback.utils import (
+ extract_bracket_content,
+ extract_square_brackets_content,
general_split_into_chunks,
make_mem_item,
should_keep_update,
@@ -33,6 +36,7 @@
if TYPE_CHECKING:
+ from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.templates.mem_feedback_prompts import (
FEEDBACK_ANSWER_PROMPT,
@@ -90,6 +94,7 @@ def __init__(self, config: MemFeedbackConfig):
self.stopword_manager = StopwordManager
self.searcher: Searcher = None
self.reranker = None
+ self.pref_mem: SimplePreferenceTextMemory = None
self.DB_IDX_READY = False
@require_python_package(
@@ -115,7 +120,7 @@ def _retry_db_operation(self, operation):
return operation()
except Exception as e:
logger.error(
- f"[1223 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True
+ f"[0107 Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True
)
raise
@@ -129,7 +134,7 @@ def _batch_embed(self, texts: list[str], embed_bs: int = 5):
results.extend(self._embed_once(batch))
except Exception as e:
logger.error(
- f"[1223 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}"
+ f"[0107 Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}"
)
results.extend([[0.0] * dim for _ in range(len(batch))])
return results
@@ -145,7 +150,7 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i
lambda: self.memory_manager.add(to_add_memories, user_name=user_name, use_batch=False)
)
logger.info(
- f"[1223 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}."
+ f"[0107 Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}."
)
return {
"record": {
@@ -177,12 +182,12 @@ def _keyword_replace_judgement(self, feedback_content: str) -> dict | None:
user_feedback=feedback_content,
)
- judge_res = self._get_llm_response(prompt)
+ judge_res = self._get_llm_response(prompt, load_type="bracket")
if judge_res:
return judge_res
else:
logger.warning(
- "[1223 Feedback Core: _feedback_judgement] feedback judgement failed, return []"
+ "[0107 Feedback Core: _feedback_judgement] feedback judgement failed, return []"
)
return {}
@@ -202,12 +207,12 @@ def _feedback_judgement(
feedback_time=feedback_time,
)
- judge_res = self._get_llm_response(prompt)
+ judge_res = self._get_llm_response(prompt, load_type="square_bracket")
if judge_res:
return judge_res
else:
logger.warning(
- "[1223 Feedback Core: _feedback_judgement] feedback judgement failed, return []"
+ "[0107 Feedback Core: _feedback_judgement] feedback judgement failed, return []"
)
return []
@@ -271,6 +276,14 @@ def _single_update_operation(
"""
Individual update operations
"""
+ if "preference" in old_memory_item.metadata.__dict__:
+ logger.info(
+ f"[0107 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}"
+ )
+ return self._single_update_pref(
+ old_memory_item, new_memory_item, user_id, user_name, operation
+ )
+
memory_type = old_memory_item.metadata.memory_type
source_doc_id = (
old_memory_item.metadata.file_ids[0]
@@ -281,6 +294,7 @@ def _single_update_operation(
)
if operation and "text" in operation and operation["text"]:
new_memory_item.memory = operation["text"]
+ new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0]
if memory_type == "WorkingMemory":
fields = {
@@ -317,6 +331,68 @@ def _single_update_operation(
"origin_memory": old_memory_item.memory,
}
+ def _single_update_pref(
+ self,
+ old_memory_item: TextualMemoryItem,
+ new_memory_item: TextualMemoryItem,
+ user_id: str,
+ user_name: str,
+ operation: dict,
+ ):
+ """update preference memory"""
+
+ feedback_context = new_memory_item.memory
+ if operation and "text" in operation and operation["text"]:
+ new_memory_item.memory = operation["text"]
+ new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0]
+
+ to_add_memory = old_memory_item.model_copy(deep=True)
+ to_add_memory.metadata.key = new_memory_item.metadata.key
+ to_add_memory.metadata.tags = new_memory_item.metadata.tags
+ to_add_memory.memory = new_memory_item.memory
+ to_add_memory.metadata.preference = new_memory_item.memory
+ to_add_memory.metadata.embedding = new_memory_item.metadata.embedding
+
+ to_add_memory.metadata.user_id = new_memory_item.metadata.user_id
+ to_add_memory.metadata.original_text = old_memory_item.memory
+ to_add_memory.metadata.covered_history = old_memory_item.id
+
+ to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = (
+ datetime.now().isoformat()
+ )
+ to_add_memory.metadata.context_summary = (
+ old_memory_item.metadata.context_summary + " \n" + feedback_context
+ )
+
+ # add new memory
+ to_add_memory.id = str(uuid.uuid4())
+ added_ids = self._retry_db_operation(lambda: self.pref_mem.add([to_add_memory]))
+ # delete
+ deleted_id = old_memory_item.id
+ collection_name = old_memory_item.metadata.preference_type
+ self._retry_db_operation(
+ lambda: self.pref_mem.delete_with_collection_name(collection_name, [deleted_id])
+ )
+ # add archived
+ old_memory_item.metadata.status = "archived"
+ old_memory_item.metadata.original_text = "archived"
+ old_memory_item.metadata.embedding = [0.0] * 1024
+
+ archived_ids = self._retry_db_operation(lambda: self.pref_mem.add([old_memory_item]))
+
+ logger.info(
+ f"[Memory Feedback UPDATE Pref] New Add:{added_ids!s} | Set archived:{archived_ids!s}"
+ )
+
+ return {
+ "id": to_add_memory.id,
+ "text": new_memory_item.memory,
+ "source_doc_id": "",
+ "archived_id": old_memory_item.id,
+ "origin_memory": old_memory_item.memory,
+ "type": "preference",
+ }
+
def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]:
"""Delete working memory bindings"""
bindings_to_delete = extract_working_binding_ids(mem_items)
@@ -334,11 +410,11 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) ->
self.graph_store.delete_node(mid, user_name=user_name)
logger.info(
- f"[1223 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}"
+ f"[0107 Feedback Core:_del_working_binding] Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}"
)
except Exception as e:
logger.warning(
- f"[1223 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}"
+ f"[0107 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}"
)
def semantics_feedback(
@@ -355,13 +431,12 @@ def semantics_feedback(
lang = detect_lang("".join(memory_item.memory))
template = FEEDBACK_PROMPT_DICT["compare"][lang]
if current_memories == []:
- # retrieve feedback
- feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name)
-
- # retrieve question
+ # retrieve
last_user_index = max(i for i, d in enumerate(chat_history_list) if d["role"] == "user")
last_qa = " ".join([item["content"] for item in chat_history_list[last_user_index:]])
supplementary_retrieved = self._retrieve(last_qa, info=info, user_name=user_name)
+ feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name)
+
ids = []
for item in feedback_retrieved + supplementary_retrieved:
if item.id not in ids:
@@ -385,9 +460,14 @@ def semantics_feedback(
with ContextThreadPoolExecutor(max_workers=10) as executor:
future_to_chunk_idx = {}
for chunk in memory_chunks:
- current_memories_str = "\n".join(
- [f"{item.id}: {item.memory}" for item in chunk]
- )
+ chunk_list = []
+ for item in chunk:
+ if "preference" in item.metadata.__dict__:
+ chunk_list.append(f"{item.id}: {item.metadata.preference}")
+ else:
+ chunk_list.append(f"{item.id}: {item.memory}")
+ current_memories_str = "\n".join(chunk_list)
+
prompt = template.format(
now_time=now_time,
current_memories=current_memories_str,
@@ -395,7 +475,7 @@ def semantics_feedback(
chat_history=history_str,
)
- future = executor.submit(self._get_llm_response, prompt)
+ future = executor.submit(self._get_llm_response, prompt, load_type="bracket")
future_to_chunk_idx[future] = chunk
for future in concurrent.futures.as_completed(future_to_chunk_idx):
try:
@@ -408,7 +488,7 @@ def semantics_feedback(
all_operations.extend(chunk_operations["operations"])
except Exception as e:
logger.error(
- f"[1223 Feedback Core: semantics_feedback] Operation failed: {e}"
+ f"[0107 Feedback Core: semantics_feedback] Operation failed: {e}"
)
standard_operations = self.standard_operations(all_operations, current_memories)
@@ -458,7 +538,7 @@ def semantics_feedback(
update_results.append(result)
except Exception as e:
logger.error(
- f"[1223 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}",
+ f"[0107 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}",
exc_info=True,
)
if update_results:
@@ -486,7 +566,7 @@ def _feedback_memory(
]
if filterd_ids:
logger.warning(
- f"[1223 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}."
+ f"[0107 Feedback Core: _feedback_memory] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}."
)
current_memories = [
@@ -518,7 +598,7 @@ def _feedback_memory(
results[i] = node
except Exception as e:
logger.error(
- f"[1223 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}",
+ f"[0107 Feedback Core: _feedback_memory] Error processing memory index {i}: {e}",
exc_info=True,
)
mem_res = [r for r in results if r]
@@ -542,13 +622,18 @@ def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys:
record.append(info_v == mem_v)
return all(record)
- def _retrieve(self, query: str, info=None, top_k=100, user_name=None):
+ def _retrieve(self, query: str, info=None, top_k=20, user_name=None):
"""Retrieve memory items"""
retrieved_mems = self.searcher.search(
query, info=info, user_name=user_name, top_k=top_k, full_recall=True
)
retrieved_mems = [item[0] for item in retrieved_mems if float(item[1]) > 0.01]
- return retrieved_mems
+
+ pref_info = {}
+ # info defaults to None, so guard it before the membership test
+ if info and "user_id" in info:
+ pref_info = {"user_id": info["user_id"]}
+ retrieved_prefs = self.pref_mem.search(query, top_k, pref_info)
+ return retrieved_mems + retrieved_prefs
def _vec_query(self, new_memories_embedding: list[float], user_name=None):
"""Vector retrieval query"""
@@ -577,7 +662,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None):
if not retrieved_ids:
logger.info(
- f"[1223 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}."
+ f"[0107 Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}."
)
filterd_ids = [
@@ -585,7 +670,7 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None):
]
if filterd_ids:
logger.warning(
- f"[1223 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}."
+ f"[0107 Feedback Core: _vec_query] Since the tags mode is fast, no modifications are made to the following memory {filterd_ids}."
)
return [
TextualMemoryItem(**item)
@@ -593,22 +678,41 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None):
if "mode:fast" not in item["metadata"]["tags"]
]
- def _get_llm_response(self, prompt: str, dsl: bool = True) -> dict:
+ def _get_llm_response(
+ self,
+ prompt: str,
+ dsl: bool = True,
+ load_type: Literal["bracket", "square_bracket"] | None = None,
+ ) -> dict:
messages = [{"role": "user", "content": prompt}]
+ response_text = ""
try:
response_text = self.llm.generate(messages, temperature=0.3, timeout=60)
- if dsl:
+ if not dsl:
+ return response_text
+ try:
response_text = response_text.replace("```", "").replace("json", "")
cleaned_text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", "", response_text)
response_json = json.loads(cleaned_text)
- else:
- return response_text
+ return response_json
+ except (json.JSONDecodeError, ValueError) as e:
+ if load_type == "bracket":
+ response_json = extract_bracket_content(response_text)
+ return response_json
+ elif load_type == "square_bracket":
+ response_json = extract_square_brackets_content(response_text)
+ return response_json
+ else:
+ logger.error(
+ f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}"
+ )
+ return None
+
except Exception as e:
logger.error(
f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}"
)
- response_json = None
- return response_json
+ return None
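+ # e.g. a hypothetical LLM reply 'Sure! {"operations": []}' fails the plain
+ # json.loads above, but load_type="bracket" recovers {"operations": []}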
def filter_fault_update(self, operations: list[dict]):
"""To address the randomness of large model outputs, it is necessary to conduct validity evaluation on the texts used for memory override operations."""
@@ -627,7 +731,7 @@ def filter_fault_update(self, operations: list[dict]):
raw_operations_str = {"operations": chunk}
prompt = template.format(raw_operations=str(raw_operations_str))
- future = executor.submit(self._get_llm_response, prompt)
+ future = executor.submit(self._get_llm_response, prompt, load_type="bracket")
future_to_chunk_idx[future] = chunk
for future in concurrent.futures.as_completed(future_to_chunk_idx):
try:
@@ -639,9 +743,9 @@ def filter_fault_update(self, operations: list[dict]):
):
all_judge.extend(judge_res["operations_judgement"])
except Exception as e:
- logger.error(f"[1223 Feedback Core: filter_fault_update] Judgement failed: {e}")
+ logger.error(f"[0107 Feedback Core: filter_fault_update] Judgement failed: {e}")
- logger.info(f"[1223 Feedback Core: filter_fault_update] LLM judgement: {all_judge}")
+ logger.info(f"[0107 Feedback Core: filter_fault_update] LLM judgement: {all_judge}")
id2op = {item["id"]: item for item in updated_operations}
valid_updates = []
for judge in all_judge:
@@ -652,7 +756,7 @@ def filter_fault_update(self, operations: list[dict]):
valid_updates.append(valid_update)
logger.info(
- f"[1223 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}"
+ f"[0107 Feedback Core: filter_fault_update] {len(updated_operations)} -> {len(valid_updates)}"
)
return valid_updates + [item for item in operations if item["operation"] != "UPDATE"]
@@ -680,11 +784,11 @@ def correct_item(data):
and "text" in data
and "old_memory" in data
and data["operation"].lower() == "update"
- )
+ ), "Invalid operation item"
if not should_keep_update(data["text"], data["old_memory"]):
logger.warning(
- f"[1223 Feedback Core: semantics_feedback] Due to the excessive proportion of changes, skip update: {data}"
+ f"[0107 Feedback Core: correct_item] Due to the excessive proportion of changes, skip update: {data}"
)
return None
@@ -704,14 +808,14 @@ def correct_item(data):
return data
except Exception:
logger.error(
- f"[1223 Feedback Core: standard_operations] Error processing operation item: {data}",
+ f"[0107 Feedback Core: standard_operations] Error processing operation item: {data}",
exc_info=True,
)
return None
dehallu_res = [correct_item(item) for item in operations]
dehalluded_operations = [item for item in dehallu_res if item]
- logger.info(f"[1223 Feedback Core: dehalluded_operations] {dehalluded_operations}")
+ logger.info(f"[0107 Feedback Core: dehalluded_operations] {dehalluded_operations}")
# c add objects
add_texts = []
@@ -725,7 +829,7 @@ def correct_item(data):
elif item["operation"].lower() == "update":
llm_operations.append(item)
logger.info(
- f"[1223 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories"
+ f"[0107 Feedback Core: deduplicate add] {len(dehalluded_operations)} -> {len(llm_operations)} memories"
)
# Update takes precedence over add
@@ -739,7 +843,7 @@ def correct_item(data):
]
if filtered_items:
logger.info(
- f"[1223 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}"
+ f"[0107 Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}"
)
return update_items
else:
@@ -787,7 +891,7 @@ def _doc_filter(self, doc_scope: str, memories: list[TextualMemoryItem]):
memid for inscope_file in inscope_docs for memid in filename2_memid[inscope_file]
]
logger.info(
- f"[1223 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}"
+ f"[0107 Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, relared memids: {inscope_ids}"
)
filter_memories = [mem for mem in memories if mem.id in inscope_ids]
return filter_memories
@@ -841,7 +945,7 @@ def process_keyword_replace(
retrieved_memories = self._doc_filter(doc_scope, retrieved_memories)
logger.info(
- f"[1223 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories."
+ f"[0107 Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories."
)
if not retrieved_memories:
@@ -926,7 +1030,7 @@ def check_validity(item):
info.update({"user_id": user_id, "user_name": user_name, "session_id": session_id})
logger.info(
- f"[1223 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}"
+ f"[0107 Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}"
)
# feedback keywords update
kwp_judge = self._keyword_replace_judgement(feedback_content)
@@ -959,7 +1063,7 @@ def check_validity(item):
if not valid_feedback:
logger.warning(
- f"[1223 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}."
+ f"[0107 Feedback Core: process_feedback_core] No valid judgements for user {user_name}: {raw_judge}."
)
return {"record": {"add": [], "update": []}}
@@ -1007,13 +1111,13 @@ def check_validity(item):
add_memories = mem_record["record"]["add"]
update_memories = mem_record["record"]["update"]
logger.info(
- f"[1223 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}."
+ f"[0107 Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}."
)
return mem_record
except Exception as e:
logger.error(
- f"[1223 Feedback Core: process_feedback_core] Error for user {user_name}: {e}"
+ f"[0107 Feedback Core: process_feedback_core] Error for user {user_name}: {e}"
)
return {"record": {"add": [], "update": []}}
diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py
index 429c2ea20..e32f939c7 100644
--- a/src/memos/mem_feedback/simple_feedback.py
+++ b/src/memos/mem_feedback/simple_feedback.py
@@ -4,6 +4,7 @@
from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
from memos.mem_feedback.feedback import MemFeedback
from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
@@ -23,6 +24,7 @@ def __init__(
mem_reader: SimpleStructMemReader,
searcher: Searcher,
reranker: BaseReranker,
+ pref_mem: SimplePreferenceTextMemory,
):
self.llm = llm
self.embedder = embedder
@@ -31,5 +33,6 @@ def __init__(
self.mem_reader = mem_reader
self.searcher = searcher
self.stopword_manager = StopwordManager
+ self.pref_mem = pref_mem
self.reranker = reranker
self.DB_IDX_READY = False
diff --git a/src/memos/mem_feedback/utils.py b/src/memos/mem_feedback/utils.py
index c32c12328..8e3b2f34c 100644
--- a/src/memos/mem_feedback/utils.py
+++ b/src/memos/mem_feedback/utils.py
@@ -1,3 +1,6 @@
+import json
+import re
+
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
@@ -48,8 +51,11 @@ def calculate_similarity(text1: str, text2: str) -> float:
similarity = calculate_similarity(old_text, new_text)
change_ratio = 1 - similarity
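+ # A change ratio of exactly 0 means old and new text are identical; reject as a no-op.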
+ if change_ratio == 0.0:
+ return False
+
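+ # Short texts tolerate larger relative edits (up to 70% change); long texts only minor ones (20%).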
if old_len < 200:
- return change_ratio < 0.5
+ return change_ratio < 0.7
else:
return change_ratio < 0.2
@@ -144,3 +150,81 @@ def make_mem_item(text: str, **kwargs) -> TextualMemoryItem:
info=info_,
),
)
+
+
+def extract_bracket_content(text):
+ """
+ Extract and parse JSON content enclosed in curly braces {} from text.
+ """
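+ # Illustrative: extract_bracket_content('reply: {"keep": true}') -> {'keep': True}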
+ # Strategy 1: Greedy match to capture the outermost complete brace pair
+ greedy_match = re.search(r"\{.*\}", text, re.DOTALL)
+ if greedy_match is None:
+ error_msg = f"No curly brace content found in text: {text}"
+ raise ValueError(error_msg)
+
+ greedy_content = greedy_match.group(0)
+
+ # Strategy 2: Non-greedy match to find all brace pairs, use the last one
+ non_greedy_matches = re.findall(r"\{.*?\}", text, re.DOTALL)
+ if not non_greedy_matches:
+ error_msg = f"No curly brace content found in text: {text}"
+ raise ValueError(error_msg)
+
+ non_greedy_content = non_greedy_matches[-1]
+
+ for content in [greedy_content, non_greedy_content]:
+ try:
+ parsed_data = json.loads(content)
+ return parsed_data
+ except json.JSONDecodeError:
+ continue
+
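+ # Doubled braces (str.format template escaping) break json.loads; normalize and retry.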
+ for content in [greedy_content, non_greedy_content]:
+ try:
+ fixed_content = content.replace("{{", "{").replace("}}", "}")
+ parsed_data = json.loads(fixed_content)
+ return parsed_data
+ except json.JSONDecodeError:
+ continue
+
+ error_msg = f"Failed to parse JSON content from curly braces. Text preview: {text}"
+ raise ValueError(error_msg)
+
+
+def extract_square_brackets_content(text):
+ """
+ Extract and parse JSON content enclosed in square brackets [] from text.
+ """
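+ # Illustrative: extract_square_brackets_content("ids: [1, 2, 3]") -> [1, 2, 3]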
+ # Strategy 1: Greedy match to capture the outermost complete bracket pair
+ greedy_match = re.search(r"\[.*\]", text, re.DOTALL)
+ if greedy_match is None:
+ error_msg = f"No square bracket content found in text: {text}"
+ raise ValueError(error_msg)
+
+ greedy_content = greedy_match.group(0)
+
+ # Strategy 2: Non-greedy match to find all bracket pairs, use the last one
+ non_greedy_matches = re.findall(r"\[.*?\]", text, re.DOTALL)
+ if not non_greedy_matches:
+ error_msg = f"No square bracket content found in text: {text}"
+ raise ValueError(error_msg)
+
+ non_greedy_content = non_greedy_matches[-1]
+
+ for content in [greedy_content, non_greedy_content]:
+ try:
+ parsed_data = json.loads(content)
+ return parsed_data
+ except json.JSONDecodeError:
+ continue
+
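+ # Retry still normalizes doubled *braces* (template escaping inside the array), not brackets.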
+ for content in [greedy_content, non_greedy_content]:
+ try:
+ fixed_content = content.replace("{{", "{").replace("}}", "}")
+ parsed_data = json.loads(fixed_content)
+ return parsed_data
+ except json.JSONDecodeError:
+ continue
+
+ error_msg = f"Failed to parse JSON content from square brackets. Text preview: {text}"
+ raise ValueError(error_msg)
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index 1a88fa831..22cd0e9cb 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -2,7 +2,7 @@
import os
import time
-from datetime import datetime
+from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from typing import Any, Literal
@@ -192,7 +192,7 @@ def _register_chat_history(
self.chat_history_manager[user_id] = ChatHistory(
user_id=user_id if user_id is not None else self.user_id,
session_id=session_id if session_id is not None else self.session_id,
- created_at=datetime.utcnow(),
+ created_at=datetime.now(timezone.utc),
total_messages=0,
chat_history=[],
)
@@ -311,7 +311,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
past_key_values = None
if self.config.enable_activation_memory:
- if self.config.chat_model.backend != "huggingface":
+ if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]:
logger.error(
"Activation memory only used for huggingface backend. Skipping activation memory."
)
@@ -498,7 +498,9 @@ def register_mem_cube(
existing_cube = self.user_manager.get_cube(mem_cube_id)
# check whether the embedder is consistent with MOSConfig
- if self.config.mem_reader.config.embedder != (
+ if hasattr(
+ self.mem_cubes[mem_cube_id].text_mem.config, "embedder"
+ ) and self.config.mem_reader.config.embedder != (
cube_embedder := self.mem_cubes[mem_cube_id].text_mem.config.embedder
):
logger.warning(
diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py
index 0114fc0da..0dc6ab209 100644
--- a/src/memos/mem_os/main.py
+++ b/src/memos/mem_os/main.py
@@ -310,7 +310,7 @@ def _generate_enhanced_response_with_context(
# Handle activation memory if enabled (same as core method)
past_key_values = None
if self.config.enable_activation_memory:
- if self.config.chat_model.backend != "huggingface":
+ if self.config.chat_model.backend not in ["huggingface", "huggingface_singleton"]:
logger.error(
"Activation memory only used for huggingface backend. Skipping activation memory."
)
diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py
index bf9f847d0..edb7875d4 100644
--- a/src/memos/mem_os/utils/default_config.py
+++ b/src/memos/mem_os/utils/default_config.py
@@ -181,15 +181,22 @@ def get_default_cube_config(
# Configure text memory based on type
if text_mem_type == "tree_text":
# Tree text memory requires Neo4j configuration
+ # NOTE: Neo4j Community Edition does NOT support multiple databases.
+ # It only has one default database named 'neo4j'.
+ # If you are using Community Edition:
+ # 1. Set 'use_multi_db' to False (default)
+ # 2. Set 'db_name' to 'neo4j' (default)
+ # 3. Set 'auto_create' to False to avoid 'CREATE DATABASE' permission errors.
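+ # Illustrative call under those settings (kwargs as read below):
+ #   get_default_cube_config(..., use_multi_db=False, neo4j_db_name="neo4j", neo4j_auto_create=False)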
db_name = f"memos{user_id.replace('-', '').replace('_', '')}"
if not kwargs.get("use_multi_db", False):
- db_name = kwargs.get("neo4j_db_name", "defaultdb")
+ db_name = kwargs.get("neo4j_db_name", "neo4j")
+
neo4j_config = {
"uri": kwargs.get("neo4j_uri", "bolt://localhost:7687"),
"user": kwargs.get("neo4j_user", "neo4j"),
"db_name": db_name,
"password": kwargs.get("neo4j_password", "12345678"),
- "auto_create": True,
+ "auto_create": kwargs.get("neo4j_auto_create", True),
"use_multi_db": kwargs.get("use_multi_db", False),
"embedding_dimension": kwargs.get("embedding_dimension", 3072),
}
diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py
index 5fdb59058..f6e33bb31 100644
--- a/src/memos/mem_os/utils/format_utils.py
+++ b/src/memos/mem_os/utils/format_utils.py
@@ -1087,38 +1087,64 @@ def convert_activation_memory_to_serializable(
serializable_items = []
for item in act_mem_items:
+ key_layers = 0
+ val_layers = 0
+ device = "unknown"
+ dtype = "unknown"
+ key_shapes = []
+ value_shapes = []
+
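+ # Support both cache layouts: newer Caches expose per-layer objects via
+ # .layers; older DynamicCache keeps parallel key_cache/value_cache lists.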
+ if item.memory:
+ if hasattr(item.memory, "layers"):
+ key_layers = len(item.memory.layers)
+ val_layers = len(item.memory.layers)
+ if key_layers > 0:
+ l0 = item.memory.layers[0]
+ k0 = getattr(l0, "key_cache", getattr(l0, "keys", None))
+ if k0 is not None:
+ device = str(k0.device)
+ dtype = str(k0.dtype)
+
+ for i, layer in enumerate(item.memory.layers):
+ k = getattr(layer, "key_cache", getattr(layer, "keys", None))
+ v = getattr(layer, "value_cache", getattr(layer, "values", None))
+ if k is not None:
+ key_shapes.append({"layer": i, "shape": list(k.shape)})
+ if v is not None:
+ value_shapes.append({"layer": i, "shape": list(v.shape)})
+
+ elif hasattr(item.memory, "key_cache"):
+ key_layers = len(item.memory.key_cache)
+ val_layers = len(item.memory.value_cache)
+ if key_layers > 0 and item.memory.key_cache[0] is not None:
+ device = str(item.memory.key_cache[0].device)
+ dtype = str(item.memory.key_cache[0].dtype)
+
+ for i, key_tensor in enumerate(item.memory.key_cache):
+ if key_tensor is not None:
+ key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
+
+ for i, val_tensor in enumerate(item.memory.value_cache):
+ if val_tensor is not None:
+ value_shapes.append({"layer": i, "shape": list(val_tensor.shape)})
+
# Extract basic information that can be serialized
serializable_item = {
"id": item.id,
"metadata": item.metadata,
"memory_info": {
"type": "DynamicCache",
- "key_cache_layers": len(item.memory.key_cache) if item.memory else 0,
- "value_cache_layers": len(item.memory.value_cache) if item.memory else 0,
- "device": str(item.memory.key_cache[0].device)
- if item.memory and item.memory.key_cache
- else "unknown",
- "dtype": str(item.memory.key_cache[0].dtype)
- if item.memory and item.memory.key_cache
- else "unknown",
+ "key_cache_layers": key_layers,
+ "value_cache_layers": val_layers,
+ "device": device,
+ "dtype": dtype,
},
}
# Add tensor shape information if available
- if item.memory and item.memory.key_cache:
- key_shapes = []
- value_shapes = []
-
- for i, key_tensor in enumerate(item.memory.key_cache):
- if key_tensor is not None:
- key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
-
- if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None:
- value_shapes.append(
- {"layer": i, "shape": list(item.memory.value_cache[i].shape)}
- )
-
+ if key_shapes:
serializable_item["memory_info"]["key_shapes"] = key_shapes
+ if value_shapes:
serializable_item["memory_info"]["value_shapes"] = value_shapes
serializable_items.append(serializable_item)
@@ -1144,7 +1170,19 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[
total_parameters = 0
for item in act_mem_items:
- if item.memory and item.memory.key_cache:
+ if not item.memory:
+ continue
+
+ if hasattr(item.memory, "layers"):
+ total_layers += len(item.memory.layers)
+ for layer in item.memory.layers:
+ k = getattr(layer, "key_cache", getattr(layer, "keys", None))
+ v = getattr(layer, "value_cache", getattr(layer, "values", None))
+ if k is not None:
+ total_parameters += k.numel()
+ if v is not None:
+ total_parameters += v.numel()
+ elif hasattr(item.memory, "key_cache"):
total_layers += len(item.memory.key_cache)
# Calculate approximate parameter count
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 48be9b72c..2ed1af53e 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -10,6 +10,7 @@
from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
from memos.mem_reader.read_multi_modal.base import _derive_key
from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
+from memos.mem_reader.utils import parse_json_result
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
from memos.types import MessagesType
@@ -377,7 +378,7 @@ def _get_llm_response(
messages = [{"role": "user", "content": prompt}]
try:
response_text = self.llm.generate(messages)
- response_json = self.parse_json_result(response_text)
+ response_json = parse_json_result(response_text)
except Exception as e:
logger.error(f"[LLM] Exception during chat generation: {e}")
response_json = {
diff --git a/src/memos/mem_reader/read_multi_modal/base.py b/src/memos/mem_reader/read_multi_modal/base.py
index 7664f4d7f..1a756c5d0 100644
--- a/src/memos/mem_reader/read_multi_modal/base.py
+++ b/src/memos/mem_reader/read_multi_modal/base.py
@@ -258,7 +258,7 @@ def _split_text(self, text: str, is_markdown: bool = False) -> list[str]:
if not text or not text.strip():
return []
- splitter = get_text_splitter()
+ splitter = get_text_splitter(is_markdown=is_markdown)
if not splitter:
# If text splitter is not available, return text as single chunk
return [text] if text.strip() else []
diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
index 8fa0f2454..fbc704d0b 100644
--- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
@@ -94,6 +94,7 @@ def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None, boo
response = requests.get(url_str, timeout=30)
response.raise_for_status()
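+ # Force UTF-8 decoding; some servers omit the charset in Content-Type.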
+ response.encoding = "utf-8"
if not filename:
filename = os.path.basename(parsed_url.path) or "downloaded_file"
diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py
index cba8ddeda..d3d97b4e6 100644
--- a/src/memos/mem_reader/read_multi_modal/utils.py
+++ b/src/memos/mem_reader/read_multi_modal/utils.py
@@ -107,7 +107,7 @@ def _cheap_close(t: str) -> str:
"config": {},
}
-DEFAULT_CHUNK_SIZE = int(os.getenv("FILE_PARSER_CHUNK_SIZE", "1000"))
+DEFAULT_CHUNK_SIZE = int(os.getenv("FILE_PARSER_CHUNK_SIZE", "1280"))
DEFAULT_CHUNK_OVERLAP = int(os.getenv("FILE_PARSER_CHUNK_OVERLAP", "200"))
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index b870bf70a..fa72bd063 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -2,7 +2,6 @@
import copy
import json
import os
-import re
import traceback
from abc import ABC
@@ -18,6 +17,13 @@
from memos.llms.factory import LLMFactory
from memos.mem_reader.base import BaseMemReader
from memos.mem_reader.read_multi_modal import coerce_scene_data, detect_lang
+from memos.mem_reader.utils import (
+ count_tokens_text,
+ derive_key,
+ parse_json_result,
+ parse_keep_filter_response,
+ parse_rewritten_response,
+)
from memos.memories.textual.item import (
SourceMessage,
TextualMemoryItem,
@@ -89,27 +95,6 @@ def from_config(_config):
}
-try:
- import tiktoken
-
- try:
- _ENC = tiktoken.encoding_for_model("gpt-4o-mini")
- except Exception:
- _ENC = tiktoken.get_encoding("cl100k_base")
-
- def _count_tokens_text(s: str) -> int:
- return len(_ENC.encode(s or "", disallowed_special=()))
-except Exception:
- # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars
- def _count_tokens_text(s: str) -> int:
- if not s:
- return 0
- zh_chars = re.findall(r"[\u4e00-\u9fff]", s)
- zh = len(zh_chars)
- rest = len(s) - zh
- return zh + max(1, rest // 4)
-
-
def _build_node(idx, message, info, source_info, llm, parse_json_result, embedder):
# generate
try:
@@ -172,14 +157,6 @@ def _build_node(idx, message, info, source_info, llm, parse_json_result, embedde
return None
-def _derive_key(text: str, max_len: int = 80) -> str:
- """default key when without LLM: first max_len words"""
- if not text:
- return ""
- sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0]
- return (sent[:max_len]).strip()
-
-
class SimpleStructMemReader(BaseMemReader, ABC):
"""Naive implementation of MemReader."""
@@ -197,7 +174,8 @@ def __init__(self, config: SimpleStructMemReaderConfig):
self.memory_max_length = 8000
# Use token-based windowing; default to ~5000 tokens if not configured
self.chat_window_max_tokens = getattr(self.config, "chat_window_max_tokens", 1024)
- self._count_tokens = _count_tokens_text
+ self._count_tokens = count_tokens_text
+ self.searcher = None
def _make_memory_item(
self,
@@ -224,7 +202,7 @@ def _make_memory_item(
memory_type=memory_type,
status="activated",
tags=tags or [],
- key=key if key is not None else _derive_key(value),
+ key=key if key is not None else derive_key(value),
embedding=self.embedder.embed([value])[0],
usage=[],
sources=sources or [],
@@ -254,7 +232,7 @@ def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict
messages = [{"role": "user", "content": prompt}]
try:
response_text = self.llm.generate(messages)
- response_json = self.parse_json_result(response_text)
+ response_json = parse_json_result(response_text)
except Exception as e:
logger.error(f"[LLM] Exception during chat generation: {e}")
response_json = {
@@ -456,47 +434,73 @@ def get_memory(
standard_scene_data = coerce_scene_data(scene_data, type)
return self._read_memory(standard_scene_data, type, info, mode)
- @staticmethod
- def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
- """Parse index-keyed JSON from hallucination filter response.
- Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... }
- Returns (success, parsed_dict) with int keys.
- """
+ def rewrite_memories(
+ self, messages: list[dict], memory_list: list[TextualMemoryItem], user_only: bool = True
+ ) -> list[TextualMemoryItem]:
+ # Build input objects with memory text and metadata (timestamps, sources, etc.)
+ if user_only:
+ template = PROMPT_MAPPING["rewrite_user_only"]
+ filtered_messages = [m for m in messages if m.get("role") != "assistant"]
+ if len(filtered_messages) < 1:
+ return memory_list
+ else:
+ template = PROMPT_MAPPING["rewrite"]
+ filtered_messages = messages
+ if len(filtered_messages) < 2:
+ return memory_list
+
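+ # Render the prompt with the dialogue and an index-keyed JSON map of
+ # memory texts, so the LLM can reference each memory by position.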
+ prompt_args = {
+ "messages_inline": "\n".join(
+ [f"- [{message['role']}]: {message['content']}" for message in filtered_messages]
+ ),
+ "memories_inline": json.dumps(
+ {idx: mem.memory for idx, mem in enumerate(memory_list)},
+ ensure_ascii=False,
+ indent=2,
+ ),
+ }
+ prompt = template.format(**prompt_args)
+
+ # Optionally run filter and parse the output
try:
- data = json.loads(text)
- except Exception:
- return False, {}
+ raw = self.llm.generate([{"role": "user", "content": prompt}])
+ success, parsed = parse_rewritten_response(raw)
+ logger.info(
+ f"[rewrite_memories] Rewrite filter parsed successfully: {success}; prompt: {prompt}"
+ )
+ if success:
+ logger.info(f"Rewrite filter result: {parsed}")
- if not isinstance(data, dict):
- return False, {}
+ new_memory_list = []
+ for mem_idx, content in parsed.items():
+ if mem_idx < 0 or mem_idx >= len(memory_list):
+ logger.warning(
+ f"[rewrite_memories] Invalid memory index {mem_idx} for memory_list {len(memory_list)}, skipping."
+ )
+ continue
- result: dict[int, dict] = {}
- for k, v in data.items():
- try:
- idx = int(k)
- except Exception:
- # allow integer keys as-is
- if isinstance(k, int):
- idx = k
- else:
- continue
- if not isinstance(v, dict):
- continue
- need_rewrite = v.get("need_rewrite")
- rewritten = v.get("rewritten", "")
- reason = v.get("reason", "")
- if (
- isinstance(need_rewrite, bool)
- and isinstance(rewritten, str)
- and isinstance(reason, str)
- ):
- result[idx] = {
- "need_rewrite": need_rewrite,
- "rewritten": rewritten,
- "reason": reason,
- }
+ need_rewrite = content.get("need_rewrite", False)
+ rewritten_text = content.get("rewritten", "")
+ reason = content.get("reason", "")
+ original_text = memory_list[mem_idx].memory
- return (len(result) > 0), result
+ # Replace memory text with rewritten content when rewrite is needed
+ if need_rewrite and isinstance(rewritten_text, str):
+ logger.info(
+ f"[rewrite_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'"
+ )
+ if len(rewritten_text.strip()) != 0:
+ memory_list[mem_idx].memory = rewritten_text
+ new_memory_list.append(memory_list[mem_idx])
+ else:
+ new_memory_list.append(memory_list[mem_idx])
+ return new_memory_list
+ else:
+ logger.warning("Rewrite filter parsing failed or returned empty result.")
+ except Exception as e:
+ logger.error(f"Rewrite filter execution error: {e}", stack_info=True)
+
+ return memory_list
def filter_hallucination_in_memories(
self, messages: list[dict], memory_list: list[TextualMemoryItem]
@@ -520,32 +524,32 @@ def filter_hallucination_in_memories(
# Optionally run filter and parse the output
try:
raw = self.llm.generate([{"role": "user", "content": prompt}])
- success, parsed = self._parse_hallucination_filter_response(raw)
+ success, parsed = parse_keep_filter_response(raw)
logger.info(
f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}"
)
if success:
logger.info(f"Hallucination filter result: {parsed}")
- assert len(parsed) == len(memory_list)
- for mem_idx, content in parsed.items():
- need_rewrite = content.get("need_rewrite", False)
- rewritten_text = content.get("rewritten", "")
- reason = content.get("reason", "")
- # Replace memory text with rewritten content when rewrite is needed
- if (
- need_rewrite
- and isinstance(rewritten_text, str)
- and len(rewritten_text.strip()) > 0
- ):
- original_text = memory_list[mem_idx].memory
+ filtered_list = []
+ for mem_idx, mem in enumerate(memory_list):
+ content = parsed.get(mem_idx)
+ if not content:
+ logger.warning(f"No verdict for memory {mem_idx}, keeping it.")
+ filtered_list.append(mem)
+ continue
+
+ keep = content.get("keep", True)
+ reason = content.get("reason", "")
+ if keep:
+ filtered_list.append(mem)
+ else:
logger.info(
- f"[filter_hallucination_in_memories] index={mem_idx}, need_rewrite={need_rewrite}, rewritten='{rewritten_text}', reason='{reason}', original memory='{original_text}', action='replace_text'"
+ f"[filter_hallucination_in_memories] Dropping memory index={mem_idx}, reason='{reason}', memory='{mem.memory}'"
)
- memory_list[mem_idx].memory = rewritten_text
- return memory_list
+ return filtered_list
else:
logger.warning("Hallucination filter parsing failed or returned empty result.")
except Exception as e:
@@ -606,29 +610,25 @@ def _read_memory(
for group_id in range(len(memory_list)):
try:
+ original_memory_group = copy.deepcopy(memory_list[group_id])
+ serialized_origin_memories = json.dumps(
+ [one.memory for one in original_memory_group], indent=2
+ )
revised_memory_list = self.filter_hallucination_in_memories(
messages=combined_messages,
- memory_list=memory_list[group_id],
+ memory_list=original_memory_group,
)
- if len(revised_memory_list) != len(memory_list[group_id]):
- original_serialized = [
- one.memory if hasattr(one, "memory") else str(one)
- for one in memory_list[group_id]
- ]
- filtered_serialized = [
- one.memory if hasattr(one, "memory") else str(one)
- for one in revised_memory_list
- ]
- logger.error(
- f"Length mismatch after hallucination filtering for group_id={group_id}: "
- f"original={len(memory_list[group_id])}, filtered={len(revised_memory_list)}"
- f"\noriginal_memory_list(serialized): {original_serialized}"
- f"\nfiltered_memory_list(serialized): {filtered_serialized}"
- f"\nmessages: {combined_messages}"
- f"\nSkipping update and keeping original memory."
+ serialized_revised_memories = json.dumps(
+ [one.memory for one in revised_memory_list], indent=2
+ )
+ if serialized_origin_memories != serialized_revised_memories:
+ memory_list[group_id] = revised_memory_list
+ logger.info(
+ f"[SIMPLE_STRUCT_ADD_FILTER] Modified the list for group_id={group_id}: "
+ f"\noriginal={serialized_origin_memories},"
+ f"\nrevised={serialized_revised_memories}"
)
- continue
- memory_list[group_id] = revised_memory_list
+
except Exception as e:
group_serialized = [
one.memory if hasattr(one, "memory") else str(one)
@@ -847,7 +847,7 @@ def _process_doc_data(self, scene_data_info, info, **kwargs):
info,
source_info_list,
self.llm,
- self.parse_json_result,
+ parse_json_result,
self.embedder,
): idx
for idx, msg in enumerate(messages)
@@ -870,44 +870,3 @@ def _process_transfer_doc_data(
self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None
):
raise NotImplementedError
-
- def parse_json_result(self, response_text: str) -> dict:
- s = (response_text or "").strip()
-
- m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I)
- s = (m.group(1) if m else s.replace("```", "")).strip()
-
- i = s.find("{")
- if i == -1:
- return {}
- s = s[i:].strip()
-
- try:
- return json.loads(s)
- except json.JSONDecodeError:
- pass
-
- j = max(s.rfind("}"), s.rfind("]"))
- if j != -1:
- try:
- return json.loads(s[: j + 1])
- except json.JSONDecodeError:
- pass
-
- def _cheap_close(t: str) -> str:
- t += "}" * max(0, t.count("{") - t.count("}"))
- t += "]" * max(0, t.count("[") - t.count("]"))
- return t
-
- t = _cheap_close(s)
- try:
- return json.loads(t)
- except json.JSONDecodeError as e:
- if "Invalid \\escape" in str(e):
- s = s.replace("\\", "\\\\")
- return json.loads(s)
- logger.error(
- f"[JSONParse] Failed to decode JSON: {e}\nTail: Raw {response_text} \
- json: {s}"
- )
- return {}
diff --git a/src/memos/mem_reader/utils.py b/src/memos/mem_reader/utils.py
new file mode 100644
index 000000000..4e5a78af2
--- /dev/null
+++ b/src/memos/mem_reader/utils.py
@@ -0,0 +1,157 @@
+import json
+import re
+
+from memos import log
+
+
+logger = log.get_logger(__name__)
+
+try:
+ import tiktoken
+
+ try:
+ _ENC = tiktoken.encoding_for_model("gpt-4o-mini")
+ except Exception:
+ _ENC = tiktoken.get_encoding("cl100k_base")
+
+ def count_tokens_text(s: str) -> int:
+ return len(_ENC.encode(s or "", disallowed_special=()))
+except Exception:
+ # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars
+ def count_tokens_text(s: str) -> int:
+ if not s:
+ return 0
+ zh_chars = re.findall(r"[\u4e00-\u9fff]", s)
+ zh = len(zh_chars)
+ rest = len(s) - zh
+ return zh + max(1, rest // 4)
+
+
+def derive_key(text: str, max_len: int = 80) -> str:
+ """Default key when no LLM is available: the first sentence, truncated to max_len characters."""
+ if not text:
+ return ""
+ sent = re.split(r"[。!?!?]\s*|\n", text.strip())[0]
+ return (sent[:max_len]).strip()
+
+
+def parse_json_result(response_text: str) -> dict:
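+ """Best-effort extraction of a JSON object from an LLM reply.
+
+ Strips code fences, trims to the first '{', then tries progressively
+ more aggressive repairs: tail truncation, auto-closing brackets, and
+ escaping stray backslashes. Returns {} when nothing parses.
+ """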
+ s = (response_text or "").strip()
+
+ m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I)
+ s = (m.group(1) if m else s.replace("```", "")).strip()
+
+ i = s.find("{")
+ if i == -1:
+ return {}
+ s = s[i:].strip()
+
+ try:
+ return json.loads(s)
+ except json.JSONDecodeError:
+ pass
+
+ j = max(s.rfind("}"), s.rfind("]"))
+ if j != -1:
+ try:
+ return json.loads(s[: j + 1])
+ except json.JSONDecodeError:
+ pass
+
+ def _cheap_close(t: str) -> str:
+ t += "}" * max(0, t.count("{") - t.count("}"))
+ t += "]" * max(0, t.count("[") - t.count("]"))
+ return t
+
+ t = _cheap_close(s)
+ try:
+ return json.loads(t)
+ except json.JSONDecodeError as e:
+ if "Invalid \\escape" in str(e):
+ s = s.replace("\\", "\\\\")
+ return json.loads(s)
+ logger.error(
+ f"[JSONParse] Failed to decode JSON: {e}\nRaw: {response_text}\nCleaned: {s}"
+ )
+ return {}
+
+
+def parse_rewritten_response(text: str) -> tuple[bool, dict[int, dict]]:
+ """Parse index-keyed JSON from the rewrite filter response.
+ Expected shape: { "0": {"need_rewrite": bool, "rewritten": str, "reason": str}, ... }
+ Returns (success, parsed_dict) with int keys.
+ """
+ try:
+ m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I)
+ s = (m.group(1) if m else text).strip()
+ data = json.loads(s)
+ except Exception:
+ return False, {}
+
+ if not isinstance(data, dict):
+ return False, {}
+
+ result: dict[int, dict] = {}
+ for k, v in data.items():
+ try:
+ idx = int(k)
+ except Exception:
+ # allow integer keys as-is
+ if isinstance(k, int):
+ idx = k
+ else:
+ continue
+ if not isinstance(v, dict):
+ continue
+ need_rewrite = v.get("need_rewrite")
+ rewritten = v.get("rewritten", "")
+ reason = v.get("reason", "")
+ if (
+ isinstance(need_rewrite, bool)
+ and isinstance(rewritten, str)
+ and isinstance(reason, str)
+ ):
+ result[idx] = {
+ "need_rewrite": need_rewrite,
+ "rewritten": rewritten,
+ "reason": reason,
+ }
+
+ return (len(result) > 0), result
+
+
+def parse_keep_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
+ """Parse index-keyed JSON from keep filter response.
+ Expected shape: { "0": {"keep": bool, "reason": str}, ... }
+ Returns (success, parsed_dict) with int keys.
+ """
+ try:
+ m = re.search(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.I)
+ s = (m.group(1) if m else text).strip()
+ data = json.loads(s)
+ except Exception:
+ return False, {}
+
+ if not isinstance(data, dict):
+ return False, {}
+
+ result: dict[int, dict] = {}
+ for k, v in data.items():
+ try:
+ idx = int(k)
+ except Exception:
+ if isinstance(k, int):
+ idx = k
+ else:
+ continue
+ if not isinstance(v, dict):
+ continue
+ keep = v.get("keep")
+ reason = v.get("reason", "")
+ if isinstance(keep, bool):
+ result[idx] = {
+ "keep": keep,
+ "reason": reason,
+ }
+ return (len(result) > 0), result
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 1e0ecaadb..4c9310cbb 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -74,6 +74,7 @@
from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
from memos.memories.activation.kv import KVCacheMemory
from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
+from memos.memories.textual.naive import NaiveTextMemory
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
@@ -198,13 +199,16 @@ def init_mem_cube(
logger.error("mem_cube is None, cannot initialize", stack_info=True)
self.mem_cube = mem_cube
self.text_mem: TreeTextMemory = self.mem_cube.text_mem
- self.reranker: HTTPBGEReranker = self.text_mem.reranker
+ self.reranker: HTTPBGEReranker = getattr(self.text_mem, "reranker", None)
if searcher is None:
- self.searcher: Searcher = self.text_mem.get_searcher(
- manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
- moscube=False,
- process_llm=self.process_llm,
- )
+ if hasattr(self.text_mem, "get_searcher"):
+ self.searcher: Searcher = self.text_mem.get_searcher(
+ manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
+ moscube=False,
+ process_llm=self.process_llm,
+ )
+ else:
+ self.searcher = None
else:
self.searcher = searcher
self.feedback_server = feedback_server
@@ -221,7 +225,7 @@ def initialize_modules(
process_llm = chat_llm
try:
- if redis_client:
+ if redis_client and self.use_redis_queue:
self.status_tracker = TaskStatusTracker(redis_client)
if self.dispatcher:
self.dispatcher.status_tracker = self.status_tracker
@@ -301,7 +305,7 @@ def status_tracker(self) -> TaskStatusTracker | None:
available via RedisSchedulerModule. This mirrors the lazy pattern used
by `mem_cube` so downstream modules can safely access the tracker.
"""
- if self._status_tracker is None:
+ if self._status_tracker is None and self.use_redis_queue:
try:
self._status_tracker = TaskStatusTracker(self.redis)
# Propagate to submodules when created lazily
@@ -310,7 +314,8 @@ def status_tracker(self) -> TaskStatusTracker | None:
if self.memos_message_queue:
self.memos_message_queue.set_status_tracker(self._status_tracker)
except Exception as e:
- logger.warning(f"Failed to lazily initialize status_tracker: {e}", exc_info=True)
+ logger.warning(f"Failed to lazy-initialize status_tracker: {e}", exc_info=True)
+
return self._status_tracker
@status_tracker.setter
@@ -540,6 +545,29 @@ def replace_working_memory(
mem_cube=mem_cube,
log_func_callback=self._submit_web_logs,
)
+ elif isinstance(text_mem_base, NaiveTextMemory):
+ # For NaiveTextMemory, we populate the monitors with the new candidates so activation memory can pick them up
+ logger.info(
+ f"NaiveTextMemory: Updating working memory monitors with {len(new_memory)} candidates."
+ )
+
+ # Use query keywords if available, otherwise just basic monitoring
+ query_db_manager = self.monitor.query_monitors[user_id][mem_cube_id]
+ query_db_manager.sync_with_orm()
+ query_keywords = query_db_manager.obj.get_keywords_collections()
+
+ new_working_memory_monitors = self.transform_working_memories_to_monitors(
+ query_keywords=query_keywords,
+ memories=new_memory,
+ )
+
+ self.monitor.update_working_memory_monitors(
+ new_working_memory_monitors=new_working_memory_monitors,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ )
+ memories_with_new_order = new_memory
else:
logger.error("memory_base is not supported")
memories_with_new_order = new_memory
@@ -842,6 +870,8 @@ def _submit_web_logs(
messages = [messages] # transform single message to list
for message in messages:
+ if self.rabbitmq_config is None:
+ return
try:
# Always call publish; the publisher now caches when offline and flushes after reconnect
logger.info(
@@ -1008,15 +1038,28 @@ def _monitor_loop(self):
try:
q_sizes = self.memos_message_queue.qsize()
+ if not isinstance(q_sizes, dict):
+ continue
+
for stream_key, queue_length in q_sizes.items():
- # Expected format: "memos:stream:{user_id}:{mem_cube_id}" or "{user_id}"
+ # Skip aggregate keys like 'total_size'
+ if stream_key == "total_size":
+ continue
+
+ # Key format: ...:{user_id}:{mem_cube_id}:{task_label}
+ # We want to extract user_id, which is the 3rd component from the end.
parts = stream_key.split(":")
if len(parts) >= 3:
- user_id = parts[2]
- self.metrics.update_queue_length(queue_length, user_id)
- elif not self.use_redis_queue: # local queue
- user_id = stream_key
+ user_id = parts[-3]
self.metrics.update_queue_length(queue_length, user_id)
+ else:
+ # Fallback for unexpected key formats (e.g. legacy or testing):
+ # a key without colons is assumed to be a bare user_id.
+ if ":" not in stream_key:
+ self.metrics.update_queue_length(queue_length, stream_key)
except Exception as e:
logger.error(f"Error in metrics monitor loop: {e}", exc_info=True)
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index ba7b558fd..8fd60153d 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -418,6 +418,7 @@ def init_components() -> dict[str, Any]:
mem_reader=mem_reader,
searcher=searcher,
reranker=feedback_reranker,
+ pref_mem=pref_mem,
)
# Return all components as a dictionary for easy access and extension
return {"naive_mem_cube": naive_mem_cube, "feedback_server": feedback_server}
diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
index 57d78676f..fd83ec86f 100644
--- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py
+++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
@@ -55,7 +55,11 @@ def create_autofilled_log_item(
"mem_cube is None — this should not happen in production!", stack_info=True
)
text_mem_base: TreeTextMemory = mem_cube.text_mem
- current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id)
+
+ current_memory_sizes = {}
+ if hasattr(text_mem_base, "get_current_memory_size"):
+ current_memory_sizes = text_mem_base.get_current_memory_size(user_name=mem_cube_id)
+
current_memory_sizes = {
"long_term_memory_size": current_memory_sizes.get("LongTermMemory", 0),
"user_memory_size": current_memory_sizes.get("UserMemory", 0),
@@ -63,14 +67,32 @@ def create_autofilled_log_item(
"transformed_act_memory_size": NOT_INITIALIZED,
"parameter_memory_size": NOT_INITIALIZED,
}
+
memory_capacities = {
- "long_term_memory_capacity": text_mem_base.memory_manager.memory_size["LongTermMemory"],
- "user_memory_capacity": text_mem_base.memory_manager.memory_size["UserMemory"],
- "working_memory_capacity": text_mem_base.memory_manager.memory_size["WorkingMemory"],
+ "long_term_memory_capacity": 0,
+ "user_memory_capacity": 0,
+ "working_memory_capacity": 0,
"transformed_act_memory_capacity": NOT_INITIALIZED,
"parameter_memory_capacity": NOT_INITIALIZED,
}
+ if hasattr(text_mem_base, "memory_manager") and hasattr(
+ text_mem_base.memory_manager, "memory_size"
+ ):
+ memory_capacities.update(
+ {
+ "long_term_memory_capacity": text_mem_base.memory_manager.memory_size.get(
+ "LongTermMemory", 0
+ ),
+ "user_memory_capacity": text_mem_base.memory_manager.memory_size.get(
+ "UserMemory", 0
+ ),
+ "working_memory_capacity": text_mem_base.memory_manager.memory_size.get(
+ "WorkingMemory", 0
+ ),
+ }
+ )
+
if hasattr(self, "monitor"):
if (
user_id in self.monitor.activation_memory_monitors
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 86066f346..9b19e9ecb 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -34,6 +34,7 @@
is_cloud_env,
)
from memos.memories.textual.item import TextualMemoryItem
+from memos.memories.textual.naive import NaiveTextMemory
from memos.memories.textual.preference import PreferenceTextMemory
from memos.memories.textual.tree import TreeTextMemory
from memos.types import (
@@ -846,7 +847,9 @@ def _process_memories_with_reader(
memory_item = text_mem.get(mem_id, user_name=user_name)
memory_items.append(memory_item)
except Exception as e:
- logger.warning(f"Failed to get memory {mem_id}: {e}")
+ logger.warning(
+ f"[_process_memories_with_reader] Failed to get memory {mem_id}: {e}"
+ )
continue
if not memory_items:
@@ -1364,22 +1367,31 @@ def process_session_turn(
text_mem_base = mem_cube.text_mem
if not isinstance(text_mem_base, TreeTextMemory):
- logger.error(
- f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
- f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
- f"text_mem_base value: {text_mem_base}",
- exc_info=True,
+ if isinstance(text_mem_base, NaiveTextMemory):
+ logger.debug(
+ f"NaiveTextMemory used for mem_cube_id={mem_cube_id}, processing session turn with simple search."
+ )
+ # NaiveTextMemory has no distinct working memory; start empty and let
+ # the retrieval below supply candidates for activation memory.
+ cur_working_memory = []
+ else:
+ logger.warning(
+ f"Not implemented! Expected TreeTextMemory but got {type(text_mem_base).__name__} "
+ f"for mem_cube_id={mem_cube_id}, user_id={user_id}. "
+ f"text_mem_base value: {text_mem_base}"
+ )
+ return [], []
+ else:
+ cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
+ user_name=mem_cube_id
)
- return
+ cur_working_memory = cur_working_memory[:top_k]
logger.info(
f"[process_session_turn] Processing {len(queries)} queries for user_id={user_id}, mem_cube_id={mem_cube_id}"
)
- cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory(
- user_name=mem_cube_id
- )
- cur_working_memory = cur_working_memory[:top_k]
text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
intent_result = self.monitor.detect_intent(
q_list=queries, text_working_memory=text_working_memory
@@ -1419,15 +1431,28 @@ def process_session_turn(
)
search_args = {}
- results: list[TextualMemoryItem] = self.retriever.search(
- query=item,
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
- top_k=k_per_evidence,
- method=self.search_method,
- search_args=search_args,
- )
+ if isinstance(text_mem_base, NaiveTextMemory):
+ # NaiveTextMemory.search only accepts query and top_k, so bypass the
+ # SchedulerRetriever dispatch and call the memory directly.
+ try:
+ results = text_mem_base.search(query=item, top_k=k_per_evidence)
+ except Exception as e:
+ logger.warning(f"NaiveTextMemory search failed: {e}")
+ results = []
+ else:
+ results: list[TextualMemoryItem] = self.retriever.search(
+ query=item,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=mem_cube,
+ top_k=k_per_evidence,
+ method=self.search_method,
+ search_args=search_args,
+ )
logger.info(
f"[process_session_turn] Search results for missing evidence '{item}': "
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index b097b1e2d..d75d6ee75 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -200,15 +200,19 @@ def update_working_memory_monitors(
mem_cube_id: str,
mem_cube: GeneralMemCube,
):
- text_mem_base: TreeTextMemory = mem_cube.text_mem
- assert isinstance(text_mem_base, TreeTextMemory)
- self.working_mem_monitor_capacity = min(
- DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT,
- (
- int(text_mem_base.memory_manager.memory_size["WorkingMemory"])
- + self.partial_retention_number
- ),
- )
+ text_mem_base = mem_cube.text_mem
+
+ if isinstance(text_mem_base, TreeTextMemory):
+ self.working_mem_monitor_capacity = min(
+ DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT,
+ (
+ int(text_mem_base.memory_manager.memory_size["WorkingMemory"])
+ + self.partial_retention_number
+ ),
+ )
+ else:
+ # Fallback for NaiveTextMemory and others
+ self.working_mem_monitor_capacity = DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT
# register monitors
self.register_memory_manager_if_not_exists(
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index c3f5891ae..7007f8418 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -186,10 +186,14 @@ def mix_search_memories(
info=info,
search_tool_memory=search_req.search_tool_memory,
tool_mem_top_k=search_req.tool_mem_top_k,
+ dedup=search_req.dedup,
)
memories = merged_memories[: search_req.top_k]
- formatted_memories = [format_textual_memory_item(item) for item in memories]
+ formatted_memories = [
+ format_textual_memory_item(item, include_embedding=search_req.dedup == "sim")
+ for item in memories
+ ]
self.submit_memory_history_async_task(
search_req=search_req,
user_context=user_context,
@@ -233,7 +237,10 @@ def update_search_memories_to_redis(
mem_cube=self.mem_cube,
mode=SearchMode.FAST,
)
- formatted_memories = [format_textual_memory_item(data) for data in memories]
+ formatted_memories = [
+ format_textual_memory_item(data, include_embedding=search_req.dedup == "sim")
+ for data in memories
+ ]
else:
memories = [
TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]
diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index cdd491183..2099da5a1 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -108,8 +108,6 @@ def __init__(
)
self.metrics = metrics
- self._status_tracker: TaskStatusTracker | None = None
- # Use setter to allow propagation and keep a single source of truth
self.status_tracker = status_tracker
self.submit_web_logs = submit_web_logs # ADDED
@@ -118,35 +116,6 @@ def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None:
return
# This is handled in BaseScheduler now
- @property
- def status_tracker(self) -> TaskStatusTracker | None:
- """Lazy-initialized status tracker for the dispatcher.
-
- If the tracker is None, attempt to initialize from the Redis-backed
- components available to the dispatcher (queue or orchestrator).
- """
- if self._status_tracker is None:
- try:
- self._status_tracker = TaskStatusTracker(self.redis)
- # Propagate to submodules when created lazily
- if self.memos_message_queue:
- self.memos_message_queue.set_status_tracker(self._status_tracker)
- except Exception as e:
- logger.warning(f"Failed to lazily initialize status_tracker: {e}", exc_info=True)
- return self._status_tracker
-
- @status_tracker.setter
- def status_tracker(self, value: TaskStatusTracker | None) -> None:
- self._status_tracker = value
- # Propagate to the queue if possible
- try:
- if self.memos_message_queue and hasattr(self.memos_message_queue, "status_tracker"):
- self.memos_message_queue.status_tracker = value
- except Exception as e:
- logger.warning(
- f"Failed to propagate dispatcher status_tracker to queue: {e}", exc_info=True
- )
-
def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem):
"""
Create a wrapper around the handler to track task execution and capture results.
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index 78b38aa80..1c9683542 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -81,6 +81,7 @@ def __init__(
# Consumer state
self._is_listening = False
self._message_handler: Callable[[ScheduleMessageItem], None] | None = None
+ self.supports_xautoclaim = False
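+ # Flipped to True by _check_xautoclaim_support() once the server reports Redis >= 6.2.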
# Connection state
self._is_connected = False
@@ -105,6 +106,7 @@ def __init__(
# Auto-initialize Redis connection
if self.auto_initialize_redis():
self._is_connected = True
+ self._check_xautoclaim_support()
self.seen_streams = set()
@@ -143,6 +145,33 @@ def __init__(
logger.debug(f"Initial stream keys refresh failed: {e}")
self._start_stream_keys_refresh_thread()
+ def _check_xautoclaim_support(self):
+ """Check if the Redis server supports xautoclaim (v6.2+)."""
+ if not self._redis_conn:
+ return
+
+ try:
+ info = self._redis_conn.info("server")
+ version_str = info.get("redis_version", "0.0.0")
+ # Simple version parsing
+ parts = [int(p) for p in version_str.split(".") if p.isdigit()]
+ while len(parts) < 3:
+ parts.append(0)
+
+ major, minor, _ = parts[:3]
+ if major > 6 or (major == 6 and minor >= 2):
+ self.supports_xautoclaim = True
+ else:
+ self.supports_xautoclaim = False
+
+ logger.info(
+ f"[REDIS_QUEUE] Redis version {version_str}. "
+ f"Supports xautoclaim: {self.supports_xautoclaim}"
+ )
+ except Exception as e:
+ logger.warning(f"Failed to check Redis version: {e}")
+ self.supports_xautoclaim = False
+
def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str:
stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}"
return stream_key
@@ -623,41 +652,67 @@ def _compute_pending_need(
need_pending = max(0, batch_size - new_count)
return need_pending if need_pending > 0 else 0
+ def _parse_pending_entry(self, entry) -> tuple[str, int]:
+ """Extract message_id and idle_time from a pending entry (dict, tuple, or object)."""
+ if isinstance(entry, dict):
+ return entry.get("message_id"), entry.get("time_since_delivered")
+ elif isinstance(entry, tuple | list):
+ return entry[0], entry[2]
+ else:
+ # Assume object (redis-py 5.x+ PendingMessage)
+ return getattr(entry, "message_id", None), getattr(entry, "time_since_delivered", 0)
+
+ def _manual_xautoclaim(
+ self, stream_key: str, min_idle_time: int, count: int
+ ) -> tuple[str, list[tuple[str, dict]], list[str]]:
+ """
+ Simulate xautoclaim using xpending and xclaim for compatibility with older Redis versions.
+ """
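+ # Return shape mirrors native xautoclaim: (next_start_id, claimed_messages, deleted_ids).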
+ # 1. Get pending entries (fetch slightly more to increase chance of finding idle ones)
+ fetch_count = count * 3
+ pending_entries = self._redis_conn.xpending_range(
+ stream_key, self.consumer_group, "-", "+", fetch_count
+ )
+
+ if not pending_entries:
+ return "0-0", [], []
+
+ claim_ids = []
+ for entry in pending_entries:
+ # Entry shape varies with redis-py version; _parse_pending_entry normalizes it.
+ msg_id, idle_time = self._parse_pending_entry(entry)
+
+ if idle_time >= min_idle_time:
+ claim_ids.append(msg_id)
+ if len(claim_ids) >= count:
+ break
+
+ if not claim_ids:
+ return "0-0", [], []
+
+ # 2. Claim messages
+ claimed_messages = self._redis_conn.xclaim(
+ stream_key, self.consumer_group, self.consumer_name, min_idle_time, claim_ids
+ )
+
+ return "0-0", claimed_messages, []
+
def _claim_pending_messages(
self, stream_key: str, need_pending_count: int, task_label: str
) -> list[tuple[str, list[tuple[str, dict]]]]:
"""Claim pending messages exceeding idle threshold, with group existence handling."""
- try:
- claimed_result = self._redis_conn.xautoclaim(
- name=stream_key,
- groupname=self.consumer_group,
- consumername=self.consumer_name,
- min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label),
- start_id="0-0",
- count=need_pending_count,
- justid=False,
- )
- if len(claimed_result) == 2:
- next_id, claimed = claimed_result
- deleted_ids = []
- elif len(claimed_result) == 3:
- next_id, claimed, deleted_ids = claimed_result
- else:
- raise ValueError(f"Unexpected xautoclaim response length: {len(claimed_result)}")
+ min_idle = self.orchestrator.get_task_idle_min(task_label=task_label)
- return [(stream_key, claimed)] if claimed else []
- except Exception as read_err:
- err_msg = str(read_err).lower()
- if "nogroup" in err_msg or "no such key" in err_msg:
- logger.warning(
- f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)."
- )
- self._ensure_consumer_group(stream_key=stream_key)
+ # Use native xautoclaim if supported (Redis 6.2+)
+ if self.supports_xautoclaim:
+ try:
claimed_result = self._redis_conn.xautoclaim(
name=stream_key,
groupname=self.consumer_group,
consumername=self.consumer_name,
- min_idle_time=self.orchestrator.get_task_idle_min(task_label=task_label),
+ min_idle_time=min_idle,
start_id="0-0",
count=need_pending_count,
justid=False,
@@ -670,23 +725,64 @@ def _claim_pending_messages(
else:
raise ValueError(
f"Unexpected xautoclaim response length: {len(claimed_result)}"
- ) from read_err
+ )
return [(stream_key, claimed)] if claimed else []
+ except Exception as read_err:
+ err_msg = str(read_err).lower()
+ if "nogroup" in err_msg or "no such key" in err_msg:
+ logger.warning(
+ f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (xautoclaim)."
+ )
+ self._ensure_consumer_group(stream_key=stream_key)
+ claimed_result = self._redis_conn.xautoclaim(
+ name=stream_key,
+ groupname=self.consumer_group,
+ consumername=self.consumer_name,
+ min_idle_time=min_idle,
+ start_id="0-0",
+ count=need_pending_count,
+ justid=False,
+ )
+ if len(claimed_result) == 2:
+ next_id, claimed = claimed_result
+ deleted_ids = []
+ elif len(claimed_result) == 3:
+ next_id, claimed, deleted_ids = claimed_result
+ else:
+ raise ValueError(
+ f"Unexpected xautoclaim response length: {len(claimed_result)}"
+ ) from read_err
+
+ return [(stream_key, claimed)] if claimed else []
+ return []
+
+ # Fallback to manual xautoclaim for older Redis versions
+ try:
+ _next, claimed, _deleted = self._manual_xautoclaim(
+ stream_key, min_idle, need_pending_count
+ )
+ return [(stream_key, claimed)] if claimed else []
+ except Exception as read_err:
+ err_msg = str(read_err).lower()
+ if "nogroup" in err_msg or "no such key" in err_msg:
+ logger.warning(
+ f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry (manual xautoclaim)."
+ )
+ self._ensure_consumer_group(stream_key=stream_key)
+ try:
+ _next, claimed, _deleted = self._manual_xautoclaim(
+ stream_key, min_idle, need_pending_count
+ )
+ return [(stream_key, claimed)] if claimed else []
+ except Exception:
+ return []
return []
- def _batch_claim_pending_messages(
+ def _batch_claim_native(
self, claims_spec: list[tuple[str, int, str]]
) -> list[tuple[str, list[tuple[str, dict]]]]:
- """Batch-claim pending messages across multiple streams.
-
- Args:
- claims_spec: List of tuples (stream_key, need_pending_count, task_label)
-
- Returns:
- A list of (stream_key, claimed_entries) pairs for all successful claims.
- """
-
+ """Batch-claim pending messages using Redis xautoclaim pipeline (Redis 6.2+)."""
pipe = self._redis_conn.pipeline(transaction=False)
for stream_key, need_count, label in claims_spec:
pipe.xautoclaim(
@@ -700,14 +796,11 @@ def _batch_claim_pending_messages(
)
try:
- # Execute with raise_on_error=False so we get exceptions in the results list
- # instead of aborting the whole batch.
results = pipe.execute(raise_on_error=False)
except Exception as e:
logger.error(f"Pipeline execution critical failure: {e}")
results = [e] * len(claims_spec)
- # Handle individual failures (e.g. NOGROUP) by retrying just that stream
final_results = []
for i, res in enumerate(results):
if isinstance(res, Exception):
@@ -734,12 +827,8 @@ def _batch_claim_pending_messages(
else:
final_results.append(res)
- results = final_results
-
- claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = []
- for (stream_key, _need_count, _label), claimed_result in zip(
- claims_spec, results, strict=False
- ):
+ claimed_pairs = []
+ for (stream_key, _, _), claimed_result in zip(claims_spec, final_results, strict=False):
try:
if not claimed_result:
continue
@@ -758,6 +847,98 @@ def _batch_claim_pending_messages(
return claimed_pairs
+ def _batch_claim_manual(
+ self, claims_spec: list[tuple[str, int, str]]
+ ) -> list[tuple[str, list[tuple[str, dict]]]]:
+ """Batch-claim pending messages using 2-phase pipeline (Redis < 6.2)."""
+ # Phase 1: Fetch pending messages for all streams
+ pending_pipe = self._redis_conn.pipeline(transaction=False)
+ for stream_key, need_count, _label in claims_spec:
+ fetch_count = need_count * 3
+ pending_pipe.xpending_range(stream_key, self.consumer_group, "-", "+", fetch_count)
+
+ try:
+ pending_results = pending_pipe.execute(raise_on_error=False)
+ except Exception as e:
+ logger.error(f"Pending fetch pipeline failed: {e}")
+ return []
+
+ # Phase 2: Filter and prepare claim pipeline
+ claim_pipe = self._redis_conn.pipeline(transaction=False)
+ streams_to_claim_indices = []
+ claimed_pairs: list[tuple[str, list[tuple[str, dict]]]] = []
+
+ for i, (stream_key, need_count, label) in enumerate(claims_spec):
+ pending_res = pending_results[i]
+ min_idle = self.orchestrator.get_task_idle_min(task_label=label)
+
+ if isinstance(pending_res, Exception):
+ err_msg = str(pending_res).lower()
+ if "nogroup" in err_msg or "no such key" in err_msg:
+ try:
+ self._ensure_consumer_group(stream_key)
+ _next, claimed, _ = self._manual_xautoclaim(
+ stream_key, min_idle, need_count
+ )
+ if claimed:
+ claimed_pairs.append((stream_key, claimed))
+ except Exception as retry_err:
+ logger.warning(f"Retry manual claim failed for {stream_key}: {retry_err}")
+ continue
+
+ if not pending_res:
+ continue
+
+ claim_ids = []
+ for entry in pending_res:
+ msg_id, idle_time = self._parse_pending_entry(entry)
+ if idle_time >= min_idle:
+ claim_ids.append(msg_id)
+ if len(claim_ids) >= need_count:
+ break
+
+ if claim_ids:
+ claim_pipe.xclaim(
+ stream_key,
+ self.consumer_group,
+ self.consumer_name,
+ min_idle,
+ claim_ids,
+ )
+ streams_to_claim_indices.append(i)
+
+ if streams_to_claim_indices:
+ try:
+ claim_results = claim_pipe.execute(raise_on_error=False)
+ for idx_in_results, original_idx in enumerate(streams_to_claim_indices):
+ res = claim_results[idx_in_results]
+ stream_key = claims_spec[original_idx][0]
+ if isinstance(res, list) and res:
+ claimed_pairs.append((stream_key, res))
+ except Exception as e:
+ logger.error(f"Claim pipeline failed: {e}")
+
+ return claimed_pairs
+
+ def _batch_claim_pending_messages(
+ self, claims_spec: list[tuple[str, int, str]]
+ ) -> list[tuple[str, list[tuple[str, dict]]]]:
+ """Batch-claim pending messages across multiple streams.
+
+ Args:
+ claims_spec: List of tuples (stream_key, need_pending_count, task_label)
+
+ Returns:
+ A list of (stream_key, claimed_entries) pairs for all successful claims.
+ """
+ if not self._redis_conn or not claims_spec:
+ return []
+
+ if self.supports_xautoclaim:
+ return self._batch_claim_native(claims_spec)
+
+ return self._batch_claim_manual(claims_spec)
+
def _convert_messages(
self, messages: list[tuple[str, list[tuple[str, dict]]]]
) -> list[ScheduleMessageItem]:
@@ -992,6 +1173,7 @@ def connect(self) -> None:
# Test the connection
self._redis_conn.ping()
self._is_connected = True
+ self._check_xautoclaim_support()
logger.debug("Redis connection established successfully")
# Start stream keys refresher when connected
self._start_stream_keys_refresh_thread()
diff --git a/src/memos/mem_scheduler/utils/api_utils.py b/src/memos/mem_scheduler/utils/api_utils.py
index c8d096517..3833b5926 100644
--- a/src/memos/mem_scheduler/utils/api_utils.py
+++ b/src/memos/mem_scheduler/utils/api_utils.py
@@ -6,14 +6,15 @@
from memos.memories.textual.tree import TextualMemoryItem
-def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_textual_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
"""Format a single memory item for API response."""
memory = memory_data.model_dump()
memory_id = memory["id"]
ref_id = f"[{memory_id.split('-')[0]}]"
memory["ref_id"] = ref_id
- memory["metadata"]["embedding"] = []
+ if not include_embedding:
+ memory["metadata"]["embedding"] = []
memory["metadata"]["sources"] = []
memory["metadata"]["ref_id"] = ref_id
memory["metadata"]["id"] = memory_id
diff --git a/src/memos/mem_scheduler/utils/status_tracker.py b/src/memos/mem_scheduler/utils/status_tracker.py
index c42ef0d0f..4977cfc3c 100644
--- a/src/memos/mem_scheduler/utils/status_tracker.py
+++ b/src/memos/mem_scheduler/utils/status_tracker.py
@@ -44,6 +44,9 @@ def task_submitted(
mem_cube_id: Memory cube identifier
business_task_id: Optional business-level task ID (one task_id can have multiple item_ids)
"""
+ if not self.redis:
+ return
+
key = self._get_key(user_id)
payload = {
"status": "waiting",
@@ -100,6 +103,9 @@ def task_completed(self, task_id: str, user_id: str):
self.redis.expire(key, timedelta(days=7))
def task_failed(self, task_id: str, user_id: str, error_message: str):
+ if not self.redis:
+ return
+
key = self._get_key(user_id)
existing_data_json = self.redis.hget(key, task_id)
if not existing_data_json:
@@ -147,6 +153,9 @@ def get_task_status_by_business_id(self, business_task_id: str, user_id: str) ->
- If any item is 'failed' → 'failed'
Returns None if task_id not found.
"""
+ if not self.redis:
+ return None
+
# Get all item_ids for this task_id
task_items_key = self._get_task_items_key(user_id, business_task_id)
item_ids = self.redis.smembers(task_items_key)
diff --git a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
index 46b2ad3d1..a07934b8e 100644
--- a/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
+++ b/src/memos/mem_scheduler/webservice_modules/rabbitmq_service.py
@@ -14,7 +14,6 @@
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue
from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE
-from memos.mem_scheduler.utils.misc_utils import is_cloud_env
logger = get_logger(__name__)
@@ -31,6 +30,7 @@ def __init__(self):
Initialize RabbitMQ connection settings.
"""
super().__init__()
+ self.auth_config = None
# RabbitMQ settings
self.rabbitmq_config: RabbitMQConfig | None = None
@@ -100,22 +100,35 @@ def initialize_rabbitmq(
)
return
+ if self.is_rabbitmq_connected():
+ logger.warning("RabbitMQ is already connected. Skipping initialization.")
+ return
+
from pika.adapters.select_connection import SelectConnection
- if config is None:
- if config_path is None and AuthConfig.default_config_exists():
- auth_config = AuthConfig.from_local_config()
- elif Path(config_path).exists():
- auth_config = AuthConfig.from_local_config(config_path=config_path)
+ if config is not None:
+ if isinstance(config, RabbitMQConfig):
+ self.rabbitmq_config = config
+ elif isinstance(config, dict):
+ self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq
else:
- auth_config = AuthConfig.from_local_env()
- self.rabbitmq_config = auth_config.rabbitmq
- elif isinstance(config, RabbitMQConfig):
- self.rabbitmq_config = config
- elif isinstance(config, dict):
- self.rabbitmq_config = AuthConfig.from_dict(config).rabbitmq
+ logger.error(f"Unsupported config type: {type(config)}")
+ return
+
else:
- logger.error("Not implemented")
+ if config_path is not None and Path(config_path).exists():
+ self.auth_config = AuthConfig.from_local_config(config_path=config_path)
+ elif AuthConfig.default_config_exists():
+ self.auth_config = AuthConfig.from_local_config()
+ else:
+ self.auth_config = AuthConfig.from_local_env()
+ self.rabbitmq_config = self.auth_config.rabbitmq
+
+ if self.rabbitmq_config is None:
+ logger.error(
+ "Failed to load RabbitMQ configuration. Please check your config file or environment variables."
+ )
+ return
# Load exchange configuration from config
if self.rabbitmq_config:
@@ -132,7 +145,16 @@ def initialize_rabbitmq(
self.rabbitmq_exchange_type = self.rabbitmq_config.exchange_type
logger.info(f"Using configured exchange type: {self.rabbitmq_exchange_type}")
- # Start connection process
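+ # Environment variables take precedence over exchange settings from the config file.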
+ env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
+ env_exchange_type = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_TYPE")
+ if env_exchange_name:
+ self.rabbitmq_exchange_name = env_exchange_name
+ logger.info(f"Using env exchange name override: {self.rabbitmq_exchange_name}")
+ if env_exchange_type:
+ self.rabbitmq_exchange_type = env_exchange_type
+ logger.info(f"Using env exchange type override: {self.rabbitmq_exchange_type}")
+
+ # Start connection process
parameters = self.get_rabbitmq_connection_param()
self.rabbitmq_connection = SelectConnection(
parameters,
@@ -148,7 +170,7 @@ def initialize_rabbitmq(
self._io_loop_thread.start()
logger.info("RabbitMQ connection process started")
except Exception:
- logger.error("Fail to initialize auth_config", exc_info=True)
+ logger.error("Failed to initialize RabbitMQ connection", exc_info=True)
finally:
with self._rabbitmq_lock:
self._rabbitmq_initializing = False
@@ -313,15 +335,16 @@ def rabbitmq_publish_message(self, message: dict):
if label == "knowledgeBaseUpdate":
routing_key = ""
- # Cloud environment override: applies to specific message types if MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set
+ # Env override: apply to all message types when MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME is set
env_exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME")
- if is_cloud_env() and env_exchange_name and label in ["taskStatus", "knowledgeBaseUpdate"]:
+ env_routing_key = os.getenv("MEMSCHEDULER_RABBITMQ_ROUTING_KEY")
+ if env_exchange_name:
exchange_name = env_exchange_name
- routing_key = "" # Routing key is always empty in cloud environment for these types
-
- # Specific diagnostic logging for messages affected by cloud environment settings
+ routing_key = env_routing_key or ""
logger.info(
- f"[DIAGNOSTIC] Publishing {label} message in Cloud Env. "
+ f"[DIAGNOSTIC] Publishing {label} message with env exchange override. "
f"Exchange: {exchange_name}, Routing Key: '{routing_key}'."
)
logger.info(f" - Message Content: {json.dumps(message, indent=2, ensure_ascii=False)}")
diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py
index 98d611dbf..1981b958f 100644
--- a/src/memos/memories/activation/kv.py
+++ b/src/memos/memories/activation/kv.py
@@ -2,9 +2,7 @@
import pickle
from datetime import datetime
-from importlib.metadata import version
-from packaging.version import Version
from transformers import DynamicCache
from memos.configs.memory import KVCacheMemoryConfig
@@ -211,10 +209,24 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
return caches[0]
merged = DynamicCache()
- num_layers = len(caches[0].key_cache)
- if Version(version("transformers")) >= Version("4.54.0"):
- merged.append_new_layers(num_layers - 1)
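+ # transformers >= 4.54 stores the cache as per-layer objects under `.layers`,
+ # while older releases keep parallel key_cache/value_cache lists, so detect
+ # the structure directly instead of pinning a library version.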
+ # Check for new structure (layers)
+ if hasattr(caches[0], "layers"):
+ num_layers = len(caches[0].layers)
+
+ # Ensure merged has layers attribute and populate it
+ if not hasattr(merged, "layers"):
+ merged.layers = []
+
+ if num_layers > 0:
+ # Get the class of the layer from the first cache
+ # We assume all caches use the same layer class
+ layer_cls = type(caches[0].layers[0])
+
+ # Populate merged.layers
+ while len(merged.layers) < num_layers:
+ merged.layers.append(layer_cls())
+
for layer in range(num_layers):
# gather all K and V for this layer
keys = [c.layers[layer].keys for c in caches]
@@ -223,7 +235,10 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
merged.layers[layer].keys = torch.cat(keys, dim=-2)
merged.layers[layer].values = torch.cat(vals, dim=-2)
- else:
+ # Check for old structure (key_cache)
+ elif hasattr(caches[0], "key_cache"):
+ num_layers = len(caches[0].key_cache)
+
for layer in range(num_layers):
# gather all K and V for this layer
keys = [c.key_cache[layer] for c in caches]
@@ -232,6 +247,11 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
merged.key_cache.append(torch.cat(keys, dim=-2))
merged.value_cache.append(torch.cat(vals, dim=-2))
+ else:
+ raise AttributeError(
+ "DynamicCache object has neither 'layers' nor 'key_cache' attributes"
+ )
+
return merged
diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py
index 144bfad7f..0c6e5339d 100644
--- a/src/memos/memories/textual/prefer_text_memory/extractor.py
+++ b/src/memos/memories/textual/prefer_text_memory/extractor.py
@@ -69,6 +69,11 @@ def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
try:
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
+ if not response:
+ logger.error(
+ f"[prefer_extractor]: (Error) LLM response content is {response} when extracting explicit preference"
+ )
+ return None
response = response.strip().replace("```json", "").replace("```", "").strip()
result = json.loads(response)
for d in result:
@@ -92,6 +97,11 @@ def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, A
try:
response = self.llm_provider.generate([{"role": "user", "content": prompt}])
+ if not response:
+ logger.error(
+ f"[prefer_extractor]: (Error) LLM response content is {response} when extracting implicit preference"
+ )
+ return None
response = response.strip().replace("```json", "").replace("```", "").strip()
result = json.loads(response)
for d in result:
diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py
index e1bc0e72b..78f4d6e28 100644
--- a/src/memos/memories/textual/preference.py
+++ b/src/memos/memories/textual/preference.py
@@ -1,6 +1,7 @@
import json
import os
+from datetime import datetime
from typing import Any
from memos.configs.memory import PreferenceTextMemoryConfig
@@ -87,6 +88,9 @@ def search(
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
+ if not isinstance(search_filter, dict):
+ search_filter = {}
+ search_filter.update({"status": "activated"})
logger.info(f"search_filter for preference memory: {search_filter}")
return self.retriever.retrieve(query, top_k, info, search_filter)
@@ -244,7 +248,7 @@ def get_all(self) -> list[TextualMemoryItem]:
Returns:
list[TextualMemoryItem]: List of all memories.
"""
- all_collections = self.vector_db.list_collections()
+ all_collections = ["explicit_preference", "implicit_preference"]
all_memories = {}
for collection_name in all_collections:
items = self.vector_db.get_all(collection_name)
@@ -258,7 +262,12 @@ def get_all(self) -> list[TextualMemoryItem]:
]
return all_memories
- def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[TextualMemoryItem]:
+ def get_memory_by_filter(
+ self,
+ filter: dict[str, Any] | None = None,
+ page: int | None = None,
+ page_size: int | None = None,
+ ) -> tuple[list[TextualMemoryItem], int]:
"""Get memories by filter.
Args:
filter (dict[str, Any]): Filter criteria.
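+ page (int | None): Optional 1-based page number.
+ page_size (int | None): Optional number of items per page.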
@@ -266,19 +275,35 @@ def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[Tex
list[TextualMemoryItem]: List of memories that match the filter.
"""
collection_list = self.vector_db.config.collection_name
- all_db_items = []
+
+ memories = []
for collection_name in collection_list:
db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter)
- all_db_items.extend(db_items)
- memories = [
- TextualMemoryItem(
- id=memo.id,
- memory=memo.memory,
- metadata=PreferenceTextualMemoryMetadata(**memo.payload),
- )
- for memo in all_db_items
- ]
- return memories
+ db_items_memory = [
+ TextualMemoryItem(
+ id=memo.id,
+ memory=memo.memory,
+ metadata=PreferenceTextualMemoryMetadata(**memo.payload),
+ )
+ for memo in db_items
+ ]
+ memories.extend(db_items_memory)
+
+ # sort
+ sorted_memories = sorted(
+ memories,
+ key=lambda item: datetime.fromisoformat(item.metadata.created_at),
+ reverse=True,
+ )
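+ # Optional 1-based pagination; invalid values are clamped (page -> 1, page_size -> 10).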
+ if page and page_size:
+ if page < 1:
+ page = 1
+ if page_size < 1:
+ page_size = 10
+ pick_memories = sorted_memories[(page - 1) * page_size : page * page_size]
+ return pick_memories, len(sorted_memories)
+
+ return sorted_memories, len(sorted_memories)
def delete(self, memory_ids: list[str]) -> None:
"""Delete memories.
@@ -289,6 +314,15 @@ def delete(self, memory_ids: list[str]) -> None:
for collection_name in collection_list:
self.vector_db.delete(collection_name, memory_ids)
+ def delete_by_filter(self, filter: dict[str, Any]) -> None:
+ """Delete memories by filter.
+ Args:
+ filter (dict[str, Any]): Filter criteria.
+ """
+ collection_list = self.vector_db.config.collection_name
+ for collection_name in collection_list:
+ self.vector_db.delete_by_filter(collection_name=collection_name, filter=filter)
+
def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None:
"""Delete memories by their IDs and collection name.
Args:
diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py
index 1f02132bb..cc1781f06 100644
--- a/src/memos/memories/textual/simple_preference.py
+++ b/src/memos/memories/textual/simple_preference.py
@@ -61,6 +61,9 @@ def search(
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
+ if not isinstance(search_filter, dict):
+ search_filter = {}
+ search_filter.update({"status": "activated"})
return self.retriever.retrieve(query, top_k, info, search_filter)
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
@@ -87,7 +90,7 @@ def get_with_collection_name(
return None
return TextualMemoryItem(
id=res.id,
- memory=res.payload.get("dialog_str", ""),
+ memory=res.memory,
metadata=PreferenceTextualMemoryMetadata(**res.payload),
)
except Exception as e:
@@ -113,7 +116,7 @@ def get_by_ids_with_collection_name(
return [
TextualMemoryItem(
id=memo.id,
- memory=memo.payload.get("dialog_str", ""),
+ memory=memo.memory,
metadata=PreferenceTextualMemoryMetadata(**memo.payload),
)
for memo in res
@@ -129,14 +132,14 @@ def get_all(self) -> list[TextualMemoryItem]:
Returns:
list[TextualMemoryItem]: List of all memories.
"""
- all_collections = self.vector_db.list_collections()
+ all_collections = ["explicit_preference", "implicit_preference"]
all_memories = {}
for collection_name in all_collections:
items = self.vector_db.get_all(collection_name)
all_memories[collection_name] = [
TextualMemoryItem(
id=memo.id,
- memory=memo.payload.get("dialog_str", ""),
+ memory=memo.memory,
metadata=PreferenceTextualMemoryMetadata(**memo.payload),
)
for memo in items
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index 22545496a..c486e6cf6 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -161,6 +161,7 @@ def search(
user_name: str | None = None,
search_tool_memory: bool = False,
tool_mem_top_k: int = 6,
+ dedup: str | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
@@ -207,6 +208,7 @@ def search(
user_name=user_name,
search_tool_memory=search_tool_memory,
tool_mem_top_k=tool_mem_top_k,
+ dedup=dedup,
**kwargs,
)
@@ -319,13 +321,21 @@ def get_by_ids(
) -> list[TextualMemoryItem]:
raise NotImplementedError
- def get_all(self, user_name: str | None = None) -> dict:
+ def get_all(
+ self,
+ user_name: str,
+ user_id: str | None = None,
+ page: int | None = None,
+ page_size: int | None = None,
+ ) -> dict:
"""Get all memories.
Returns:
list[TextualMemoryItem]: List of all memories.
"""
- all_items = self.graph_store.export_graph(user_name=user_name)
- return all_items
+ graph_output = self.graph_store.export_graph(
+ user_name=user_name, user_id=user_id, page=page, page_size=page_size
+ )
+ return graph_output
def delete(self, memory_ids: list[str], user_name: str | None = None) -> None:
"""Hard delete: permanently remove nodes and their edges from the graph."""
@@ -337,6 +347,13 @@ def delete(self, memory_ids: list[str], user_name: str | None = None) -> None:
except Exception as e:
logger.warning(f"TreeTextMemory.delete_hard: failed to delete {mid}: {e}")
+ def delete_by_memory_ids(self, memory_ids: list[str]) -> None:
+ """Delete memories by memory_ids."""
+ try:
+ self.graph_store.delete_node_by_prams(memory_ids=memory_ids)
+ except Exception as e:
+ logger.error(f"An error occurred while deleting memories by memory_ids: {e}")
+
def delete_all(self) -> None:
"""Delete all memories and their relationships from the graph store."""
try:
@@ -348,7 +365,7 @@ def delete_all(self) -> None:
def delete_by_filter(
self,
- writable_cube_ids: list[str],
+ writable_cube_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
) -> None:
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
index d9398a22c..5a82883c8 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py
@@ -4,7 +4,6 @@
from pathlib import Path
from typing import Any
-
import numpy as np
from memos.dependency import require_python_package
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index cad9ab64b..3612d37eb 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -119,9 +119,13 @@ def post_retrieve(
info=None,
search_tool_memory: bool = False,
tool_mem_top_k: int = 6,
+ dedup: str | None = None,
plugin=False,
):
- deduped = self._deduplicate_results(retrieved_results)
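+ # dedup="no" returns retrieved results as-is; any other value keeps the default deduplication.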
+ if dedup == "no":
+ deduped = retrieved_results
+ else:
+ deduped = self._deduplicate_results(retrieved_results)
final_results = self._sort_and_trim(
deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
)
@@ -141,6 +145,7 @@ def search(
user_name: str | None = None,
search_tool_memory: bool = False,
tool_mem_top_k: int = 6,
+ dedup: str | None = None,
**kwargs,
) -> list[TextualMemoryItem]:
"""
@@ -202,6 +207,7 @@ def search(
plugin=kwargs.get("plugin", False),
search_tool_memory=search_tool_memory,
tool_mem_top_k=tool_mem_top_k,
+ dedup=dedup,
)
logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -713,6 +719,8 @@ def _retrieve_simple(
)
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
documents = [getattr(item, "memory", "") for item in items]
+ if not documents:
+ return []
documents_embeddings = self.embedder.embed(documents)
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
index e1ce859bf..f4d6c4847 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
@@ -5,7 +5,10 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
-from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+ FastTokenizer,
+ parse_json_result,
+)
from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT
@@ -111,8 +114,10 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
for attempt_times in range(attempts):
try:
context = kwargs.get("context", "")
- response = response.replace("```", "").replace("json", "").strip()
- response_json = eval(response)
+ response_json = parse_json_result(response)
+ if not response_json:
+ raise ValueError("Parsed JSON is empty")
+
return ParsedTaskGoal(
memories=response_json.get("memories", []),
keys=response_json.get("keys", []),
@@ -123,6 +128,8 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
context=context,
)
except Exception as e:
- raise ValueError(
- f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}"
- ) from e
+ if attempt_times == attempts - 1:
+ raise ValueError(
+ f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts}"
+ ) from e
+ continue
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py
index bcd47b078..54caa20f7 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/utils.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/utils.py
@@ -27,7 +27,7 @@
"tags": [...],
"goal_type": "retrieval | qa | generation",
"rephrased_instruction": "...", # return an empty string if the original instruction is easy enough to understand
- "internet_search": True/False,
+ "internet_search": true/false,
"memories": ["...", "...", ...]
}
"""
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index 57f2cdba1..6c3cc0cc7 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -15,6 +15,7 @@
)
from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
+from memos.mem_reader.utils import parse_keep_filter_response
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import (
ADD_TASK_LABEL,
@@ -23,6 +24,7 @@
PREF_ADD_TASK_LABEL,
)
from memos.multi_mem_cube.views import MemCubeView
+from memos.templates.mem_reader_prompts import PROMPT_MAPPING
from memos.types.general_types import (
FINE_STRATEGY,
FineStrategy,
@@ -41,6 +43,7 @@
from memos.mem_cube.navie import NaiveMemCube
from memos.mem_reader.simple_struct import SimpleStructMemReader
from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
+ from memos.memories.textual.item import TextualMemoryItem
@dataclass
@@ -54,6 +57,7 @@ class SingleCubeView(MemCubeView):
feedback_server: Any | None = None
deepsearch_agent: Any | None = None
+ @timed
def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]:
"""
This is basically your current handle_add_memories logic,
@@ -100,6 +104,7 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]:
return all_memories
+ @timed
def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
# Create UserContext object
user_context = UserContext(
@@ -147,6 +152,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
self.logger.info(f"Search {len(memories_result)} memories.")
return memories_result
+ @timed
def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]:
target_session_id = feedback_req.session_id or "default_session"
if feedback_req.async_mode == "async":
@@ -258,7 +264,10 @@ def _deep_search(
search_filter=search_filter,
info=info,
)
- formatted_memories = [format_memory_item(data) for data in enhanced_memories]
+ formatted_memories = [
+ format_memory_item(data, include_embedding=search_req.dedup == "sim")
+ for data in enhanced_memories
+ ]
return formatted_memories
def _agentic_search(
@@ -267,7 +276,10 @@ def _agentic_search(
deepsearch_results = self.deepsearch_agent.run(
search_req.query, user_id=user_context.mem_cube_id
)
- formatted_memories = [format_memory_item(data) for data in deepsearch_results]
+ formatted_memories = [
+ format_memory_item(data, include_embedding=search_req.dedup == "sim")
+ for data in deepsearch_results
+ ]
return formatted_memories
def _fine_search(
@@ -322,6 +334,7 @@ def _fine_search(
top_k=search_req.top_k,
user_name=user_context.mem_cube_id,
info=info,
+ dedup=search_req.dedup,
)
# Enhance with query
@@ -372,8 +385,13 @@ def _dedup_by_content(memories: list) -> list:
unique_memories.append(mem)
return unique_memories
- deduped_memories = _dedup_by_content(enhanced_memories)
- formatted_memories = [format_memory_item(data) for data in deduped_memories]
+ deduped_memories = (
+ enhanced_memories if search_req.dedup == "no" else _dedup_by_content(enhanced_memories)
+ )
+ formatted_memories = [
+ format_memory_item(data, include_embedding=search_req.dedup == "sim")
+ for data in deduped_memories
+ ]
logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
@@ -457,9 +475,13 @@ def _fast_search(
plugin=plugin,
search_tool_memory=search_req.search_tool_memory,
tool_mem_top_k=search_req.tool_mem_top_k,
+ dedup=search_req.dedup,
)
- formatted_memories = [format_memory_item(data) for data in search_results]
+ formatted_memories = [
+ format_memory_item(data, include_embedding=search_req.dedup == "sim")
+ for data in search_results
+ ]
return formatted_memories
@@ -551,6 +573,7 @@ def _schedule_memory_tasks(
)
self.mem_scheduler.submit_messages(messages=[message_item_add])
+ @timed
def _process_pref_mem(
self,
add_req: APIADDRequest,
@@ -631,6 +654,105 @@ def _process_pref_mem(
for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
]
+ def add_before_search(
+ self,
+ messages: list[dict],
+ memory_list: list[TextualMemoryItem],
+ user_name: str,
+ info: dict[str, Any],
+ ) -> list[TextualMemoryItem]:
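+ """Pre-add filter: search for memories related to each candidate and ask the
+ LLM to drop candidates that are redundant or already covered. On any failure,
+ the original list is returned unchanged."""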
+ # Prompt template that asks the LLM to keep or drop each candidate memory.
+ template = PROMPT_MAPPING["add_before_search"]
+
+ if not self.searcher:
+ self.logger.warning("[add_before_search] Searcher is not initialized, skipping check.")
+ return memory_list
+
+ # 1. Gather candidates and search for related memories
+ candidates_data = []
+ for idx, mem in enumerate(memory_list):
+ try:
+ related_memories = self.searcher.search(
+ query=mem.memory, top_k=3, mode="fast", user_name=user_name, info=info
+ )
+ related_text = "None"
+ if related_memories:
+ related_text = "\n".join([f"- {r.memory}" for r in related_memories])
+
+ candidates_data.append(
+ {"idx": idx, "new_memory": mem.memory, "related_memories": related_text}
+ )
+ except Exception as e:
+ self.logger.error(
+ f"[add_before_search] Search error for memory '{mem.memory}': {e}"
+ )
+ # If search fails, we can either skip this check or treat related as empty
+ candidates_data.append(
+ {
+ "idx": idx,
+ "new_memory": mem.memory,
+ "related_memories": "None (Search Failed)",
+ }
+ )
+
+ if not candidates_data:
+ return memory_list
+
+ # 2. Build Prompt
+ messages_inline = "\n".join(
+ [
+ f"- [{message.get('role', 'unknown')}]: {message.get('content', '')}"
+ for message in messages
+ ]
+ )
+
+ candidates_inline_dict = {
+ str(item["idx"]): {
+ "new_memory": item["new_memory"],
+ "related_memories": item["related_memories"],
+ }
+ for item in candidates_data
+ }
+
+ candidates_inline = json.dumps(candidates_inline_dict, ensure_ascii=False, indent=2)
+
+ prompt = template.format(
+ messages_inline=messages_inline, candidates_inline=candidates_inline
+ )
+
+ # 3. Call LLM
+ try:
+ raw = self.mem_reader.llm.generate([{"role": "user", "content": prompt}])
+ success, parsed_result = parse_keep_filter_response(raw)
+
+ if not success:
+ self.logger.warning(
+ "[add_before_search] Failed to parse LLM response, keeping all."
+ )
+ return memory_list
+
+ # 4. Filter
+ filtered_list = []
+ for idx, mem in enumerate(memory_list):
+ # The prompt requests string keys ("0", "1", ...), so fall back to str(idx)
+ # in case the parser preserves them as strings.
+ res = parsed_result.get(idx) or parsed_result.get(str(idx))
+ if not res:
+ filtered_list.append(mem)
+ continue
+
+ if res.get("keep", True):
+ filtered_list.append(mem)
+ else:
+ self.logger.info(
+ f"[add_before_search] Dropping memory: '{mem.memory}', reason: '{res.get('reason')}'"
+ )
+
+ return filtered_list
+
+ except Exception as e:
+ self.logger.error(f"[add_before_search] LLM execution error: {e}")
+ return memory_list
+
+ @timed
def _process_text_mem(
self,
add_req: APIADDRequest,
diff --git a/src/memos/reranker/concat.py b/src/memos/reranker/concat.py
index 502af18b6..b39496a1c 100644
--- a/src/memos/reranker/concat.py
+++ b/src/memos/reranker/concat.py
@@ -83,10 +83,18 @@ def concat_original_source(
merge_field = ["sources"] if rerank_source is None else rerank_source.split(",")
documents = []
for item in graph_results:
- memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m
+ m = item.get("memory") if isinstance(item, dict) else getattr(item, "memory", None)
+
+ memory = _TAG1.sub("", m) if isinstance(m, str) else m
+
sources = []
for field in merge_field:
- source = getattr(item.metadata, field, None)
+ if isinstance(item, dict):
+ metadata = item.get("metadata", {})
+ source = metadata.get(field) if isinstance(metadata, dict) else None
+ else:
+ source = getattr(item.metadata, field, None) if hasattr(item, "metadata") else None
+
if source is None:
continue
sources.append((memory, source))
diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py
index 4e9054f1e..32034cf6d 100644
--- a/src/memos/reranker/http_bge.py
+++ b/src/memos/reranker/http_bge.py
@@ -129,7 +129,7 @@ def __init__(
def rerank(
self,
query: str,
- graph_results: list[TextualMemoryItem],
+ graph_results: list[TextualMemoryItem] | list[dict[str, Any]],
top_k: int,
search_priority: dict | None = None,
**kwargs,
@@ -164,11 +164,15 @@ def rerank(
if self.rerank_source:
documents = concat_original_source(graph_results, self.rerank_source)
else:
- documents = [
- (_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
- for item in graph_results
- ]
- documents = [d for d in documents if isinstance(d, str) and d]
+ documents = []
+ filtered_graph_results = []
+ for item in graph_results:
+ m = item.get("memory") if isinstance(item, dict) else getattr(item, "memory", None)
+
+ if isinstance(m, str) and m:
+ documents.append(_TAG1.sub("", m))
+ filtered_graph_results.append(item)
+ graph_results = filtered_graph_results
logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...")
diff --git a/src/memos/reranker/strategies/concat_docsource.py b/src/memos/reranker/strategies/concat_docsource.py
index 0fb471218..d90452995 100644
--- a/src/memos/reranker/strategies/concat_docsource.py
+++ b/src/memos/reranker/strategies/concat_docsource.py
@@ -54,6 +54,7 @@ def prepare_documents(
original_items = {}
tracker = DialogueRankingTracker()
documents = []
+ documents_set = set()
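+ # Track chunk texts already emitted so identical source chunks are appended only once.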
for item in graph_results:
memory = getattr(item, "memory", None)
if isinstance(memory, str):
@@ -66,7 +67,11 @@ def prepare_documents(
if source.type == "file":
chunk_text += source.content
if chunk_text:
- documents.append(f"{memory}\n\n[Sources]:\n{chunk_text}")
+ if chunk_text in documents_set:
+ continue
+ else:
+ documents_set.add(chunk_text)
+ documents.append(f"{memory}\n\n[Sources]:\n{chunk_text}")
else:
documents.append(memory)
return tracker, original_items, documents
diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py
index fef3ee6c0..20f8150b7 100644
--- a/src/memos/templates/mem_reader_prompts.py
+++ b/src/memos/templates/mem_reader_prompts.py
@@ -69,6 +69,26 @@
"summary": "Tom is currently focused on managing a new project with a tight schedule. After a team meeting on June 25, 2025, he realized the original deadline of December 15 might not be feasible due to backend delays. Concerned about insufficient testing time, he welcomed Jerry’s suggestion of proposing an extension. Tom plans to raise the idea of shifting the deadline to January 5, 2026 in the next morning’s meeting. His actions reflect both stress about timelines and a proactive, team-oriented problem-solving approach."
}
+Dialogue:
+assistant: [10:30 AM, August 15, 2025]: The book Deep Work you mentioned is
+indeed very suitable for your current situation. The book explains … (omitted). The author suggests setting aside 2–3 hours of focused work blocks each day and turning off all notifications during that time. Considering that you need to submit a report next week, you could try using the 9:00–11:00 AM time slot for focused work.
+
+Output:
+{
+ "memory list": [
+ {
+ "key": "Deep Work Book Recommendation",
+ "memory_type": "LongTermMemory",
+ "value": "On August 15, 2025, the assistant recommended the book 'Deep Work' to the user and introduced its suggestion of reserving 2–3 hours per day for focused work while turning off all notifications. Based on the user's need to submit a report the following week, the assistant also suggested trying 9:00–11:00 AM as a focused work time block.",
+ "tags": ["book recommendation", "deep work", "time management", "report"]
+ }
+ ],
+ "summary": "The assistant recommended the book 'Deep Work' to the user and introduced the work methods discussed in the book."
+}
+
+Note: When the dialogue contains only assistant messages, phrasing such as
+“assistant recommended” or “assistant suggested” should be used, rather than incorrectly attributing the content to the user’s statements or plans.
+
Another Example in Chinese (注意: 当user的语言为中文时,你就需要也输出中文):
{
"memory list": [
@@ -163,6 +183,25 @@
"summary": "Tom目前正专注于管理一个进度紧张的新项目。在2025年6月25日的团队会议后,他意识到原定2025年12月15日的截止日期可能无法实现,因为后端会延迟。由于担心测试时间不足,他接受了Jerry提出的延期建议。Tom计划在次日早上的会议上提出将截止日期推迟至2026年1月5日。他的行为反映出对时间线的担忧,以及积极、以团队为导向的问题解决方式。"
}
+对话:
+assistant: [2025年8月15日上午10:30]:
+你提到的那本《深度工作》确实很适合你现在的情况。这本书讲了......(略),作者建议每天留出2-3
+小时的专注时间块,期间关闭所有通知。考虑到你下周要交的报告,可以试试早上9点到11点这个时段。
+
+输出:
+{
+ "memory list": [
+ {
+ "key": "深度工作书籍推荐",
+ "memory_type": "LongTermMemory",
+ "value": "2025年8月15日助手向用户推荐了《深度工作》一书,并介绍了书中建议的每天留出2-3小时专注时间块、关闭所有通知的方法。助手还根据用户下周需要提交报告的情况,建议用户尝试早上9点到11点作为专注时段。",
+ "tags": ["书籍推荐", "深度工作", "时间管理", "报告"]
+ }
+ ],
+ "summary": "助手向用户推荐了《深度工作》一书,并介绍了了其中的工作方法"
+}
+注意:当对话仅有助手消息时,应使用"助手推荐"、"助手建议"等表述,而非将其错误归因为用户的陈述或计划。
+
另一个中文示例(注意:当用户语言为中文时,您也需输出中文):
{
"memory list": [
@@ -622,23 +661,21 @@
专注于从图像中提取事实性、可观察的信息。除非与用户记忆明显相关,否则避免推测。"""
-SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """
+SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT_BACKUP = """
You are a strict, language-preserving memory validator and rewriter.
Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content.
Rules:
1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching.
-2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, emotional labels, summaries, or generalizations.
-3. **Ambiguity Elimination**:
- - Replace vague pronouns (e.g., “he”, “it”, “they”) with clear, specific entities **only if** the messages identify them.
- - Convert relative time expressions (e.g., “yesterday”) to absolute dates **only if** the messages provide enough temporal context.
+2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED.
4. **Hallucination Removal**:
- - If a memory contains **any content not verbatim or directly implied by the user**, it must be rewritten.
- - Do **not** rephrase inferences as facts. Instead, either:
- - Remove the unsupported part and retain only the grounded core, or
- - If the entire memory is ungrounded, mark it for rewrite and make the lack of user support explicit.
+ - If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten.
+ - **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific.
+ - Do **not** rephrase inferences as facts; remove the unsupported part and retain only the grounded core.
5. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged.
+6. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite.
Inputs:
messages:
@@ -651,16 +688,188 @@
- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices.
- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }}
- The "reason" must be brief and precise, e.g.:
- - "contains unsupported inference"
- - "vague pronoun with no referent in messages"
- - "relative time resolved to 2025-12-16"
+ - "contains unsupported inference ...."
- "fully grounded and concise"
Important: Output **only** the JSON. No extra text, explanations, markdown, or fields.
"""
+SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT = """
+You are a strict, language-preserving memory validator and rewriter.
+
+Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content.
+
+Rules:
+1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching.
+2. **Strict Factual Grounding**: Include only what is explicitly stated by the user in messages marked as [user]. Remove or flag anything not directly present in the user’s utterances—no assumptions, interpretations, predictions, generalizations, or content originating solely from [assistant].
+3. **Source Attribution Requirement**:
+ - Every memory must be clearly traceable to its source:
+ - If a fact appears **only in [assistant] messages** and **is not affirmed by [user]**, label it as “[assistant] memory”.
+ - If [assistant] states something and [user] explicitly contradicts or denies it, label it as “[assistant] memory, but [user] [brief quote or summary of denial]”.
+ - If a fact is stated by [user] —whether or not [assistant] also mentions it— it is attributed to “[user]” and may be retained without qualification.
+4. **Timestamp Exception**: Memories may include timestamps (e.g., "On December 19, 2026") derived from conversation metadata. If such a date likely reflects the conversation time (even if not in the `messages` list), do NOT treat it as hallucinated—but still attribute it to “[user]” only if the user mentioned or confirmed the date.
+
+Inputs:
+messages:
+{messages_inline}
+
+memories:
+{memories_inline}
+
+Output Format:
+- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices.
+- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }}
+- The "reason" must be brief and precise, e.g.:
+ - "contains unsupported inference from [assistant]"
+ - "[assistant] memory, but [user] said 'I don't have a dog'"
+ - "fully grounded in [user]"
+
+Important: Output **only** the JSON. No extra text, explanations, markdown, or fields.
+"""
+
+SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT = """
+You are a strict, language-preserving memory validator and rewriter.
+
+Your task is to eliminate hallucinations and tighten memories by grounding them strictly in the user’s explicit messages. Memories must be factual, unambiguous, and free of any inferred or speculative content.
+
+Note: The provided messages contain only user messages. The assistant's responses are intentionally omitted, not because the assistant didn't answer, but to focus strictly on validating memories against user input.
+
+Rules:
+1. **Language Consistency**: Keep the exact original language of each memory—no translation or language switching.
+2. **Strict Factual Grounding**: Include only what the user explicitly stated. Remove or flag anything not directly present in the messages—no assumptions, interpretations, predictions, or generalizations NOT supported by the text. However, **you MUST retain specific details, reasons, explanations, and feelings if the user explicitly expressed them.** Minor formatting corrections (e.g., adding missing spaces between names, fixing obvious typos) are ALLOWED.
+3. **Hallucination Removal**:
+ - If a memory contains **any content not supported by the user's explicit statements**, it must be rewritten.
+ - **Do NOT remove** details, reasons, or explanations that the user explicitly provided, even if they are subjective or specific.
+ - Do **not** rephrase inferences as facts; remove the unsupported part and retain only the grounded core.
+4. **No Change if Fully Grounded**: If the memory is concise, unambiguous, and fully supported by the user’s messages, keep it unchanged.
+5. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination or require a rewrite.
+
+Inputs:
+messages:
+{messages_inline}
+
+memories:
+{memories_inline}
+
+Output Format:
+- Return a JSON object with string keys ("0", "1", "2", ...) matching input memory indices.
+- Each value must be: {{ "need_rewrite": boolean, "rewritten": string, "reason": string }}
+- The "reason" must be brief and precise, e.g.:
+ - "contains unsupported inference ...."
+ - "fully grounded and concise"
+
+Important: Output **only** the JSON. No extra text, explanations, markdown, or fields.
+"""
+
+SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT = """
+You are a strict memory validator.
+Your task is to identify and delete hallucinated memories that are not explicitly stated by the user in the provided messages.
+
+Rules:
+1. **Explicit Denial & Inconsistency**: If a memory claims something that the user explicitly denied or is clearly inconsistent with the user's statements, mark it for deletion.
+2. **Timestamp Exception**: Memories may include timestamps (e.g., dates like "On December 19, 2026") derived from conversation metadata. If the date in the memory is likely the conversation time (even if not shown in the `messages` list), do NOT treat it as a hallucination.
+
+Example:
+Messages:
+[user]: I'm planning a trip to Japan next month for about a week.
+[assistant]: That sounds great! Are you planning to visit Tokyo Disneyland?
+[user]: No, I won't be going to Tokyo this time. I plan to stay in Kyoto and Osaka to avoid crowds.
+
+Memories:
+{{
+ "0": "User plans to travel to Japan for a week next month.",
+ "1": "User intends to visit Tokyo Disneyland.",
+ "2": "User plans to stay in Kyoto and Osaka."
+}}
+
+Output:
+{{
+ "0": {{ "keep": true, "reason": "Explicitly stated by user." }},
+ "1": {{ "keep": false, "reason": "User explicitly denied visiting Tokyo." }},
+ "2": {{ "keep": true, "reason": "Explicitly stated by user." }}
+}}
+
+Inputs:
+Messages:
+{messages_inline}
+
+Memories:
+{memories_inline}
+
+Output Format:
+- Return a JSON object with string keys ("0", "1", "2", ...) matching the input memory indices.
+- Each value must be: {{ "keep": boolean, "reason": string }}
+- "keep": true only if the memory is a direct reflection of the user's explicit words.
+- "reason": brief, factual, and cites missing or unsupported content.
+
+Important: Output **only** the JSON. No extra text, explanations, markdown, or fields.
+"""
+
+
+SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT = """
+You are a memory manager.
+Your task is to decide if a new memory should be added to the long-term memory, given a list of existing related memories.
+
+Rules:
+1. **Redundancy Check**: If the new memory is completely redundant, already known, or covered by the existing memories, discard it.
+2. **New Information**: If the new memory provides new information, details, or updates compared to the existing memories, keep it.
+3. **Contradiction**: If the new memory contradicts existing memories but seems valid/newer, keep it (updates).
+4. **Context Check**: Use the provided conversation messages to verify if the new memory is grounded in the user's explicit statements.
+
+Inputs:
+Messages:
+{messages_inline}
+
+Candidate Memories (to be evaluated):
+{candidates_inline}
+
+Output Format:
+- Return a JSON object with string keys ("0", "1", "2", ...) matching the input candidate memory indices.
+- Each value must be: {{ "keep": boolean, "reason": string }}
+- "keep": true if the memory should be added.
+- "reason": brief explanation.
+
+Important: Output **only** the JSON. No extra text.
+"""
# Prompt mapping for specialized tasks (e.g., hallucination filtering)
PROMPT_MAPPING = {
"hallucination_filter": SIMPLE_STRUCT_HALLUCINATION_FILTER_PROMPT,
+ "rewrite": SIMPLE_STRUCT_REWRITE_MEMORY_PROMPT,
+ "rewrite_user_only": SIMPLE_STRUCT_REWRITE_MEMORY_USER_ONLY_PROMPT,
+ "add_before_search": SIMPLE_STRUCT_ADD_BEFORE_SEARCH_PROMPT,
}
diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py
index 3315e061e..a67f0c12c 100644
--- a/src/memos/templates/prefer_complete_prompt.py
+++ b/src/memos/templates/prefer_complete_prompt.py
@@ -77,6 +77,7 @@
* **Contextual signals**: What do the user's choices, comparisons, exclusions, or scenario selections reveal about their deeper preferences?
- Do not treat explicitly stated preferences as implicit preferences; this prompt is only for inferring preferences that are not directly mentioned.
- Go beyond surface-level facts to understand the user's hidden possibilities and underlying logic.
+- Content from the Assistant's responses or suggestions may be extracted as a user's implicit preference only if subsequent conversation shows the user implicitly accepting it (e.g., adopting it, agreeing with it, or acting on the suggestion). An Assistant suggestion alone does not constitute a user preference.
Requirements:
1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses.
@@ -117,6 +118,7 @@
* **情境信号**:用户的选择、比较、排除或场景选择揭示了什么样的深层偏好?
- 不要将明确陈述的偏好视为隐式偏好;此提示仅用于推断未直接提及的偏好。
- 超越表面事实,理解用户的隐藏可能性和背后的逻辑。
+- 对于Assistant的回答内容或建议,只有在后续对话中用户表现出隐含接受(如采纳、认同、按建议行动等)的情况下,才能将相关内容提取为用户的隐式偏好。单纯的Assistant建议本身不构成用户偏好。
要求:
1. 仅在对话中有充分证据时进行推断;避免无根据或牵强的猜测。
diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py
index ecbca5815..cc8909d34 100644
--- a/src/memos/vec_dbs/milvus.py
+++ b/src/memos/vec_dbs/milvus.py
@@ -457,14 +457,13 @@ def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None:
return None
entity = results[0]
- payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]}
return MilvusVecDBItem(
id=entity["id"],
memory=entity.get("memory"),
original_text=entity.get("original_text"),
vector=entity.get("vector"),
- payload=payload,
+ payload=entity.get("payload", {}),
)
def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBItem]:
@@ -479,14 +478,13 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBIt
items = []
for entity in results:
- payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]}
items.append(
MilvusVecDBItem(
id=entity["id"],
memory=entity.get("memory"),
original_text=entity.get("original_text"),
vector=entity.get("vector"),
- payload=payload,
+ payload=entity.get("payload", {}),
)
)
@@ -646,3 +644,11 @@ def delete(self, collection_name: str, ids: list[str]) -> None:
collection_name=collection_name,
ids=ids,
)
+
+ def delete_by_filter(self, collection_name: str, filter: dict[str, Any]) -> None:
+ """Delete items from the vector database by filter."""
+ expr = self._dict_to_expr(filter) if filter else ""
+ if not expr:
+ # Refuse to run a delete without a filter expression: Milvus rejects an
+ # empty filter, and it must never be interpreted as "delete everything".
+ return
+ self.client.delete(
+ collection_name=collection_name,
+ filter=expr,
+ )
diff --git a/src/memos/vec_dbs/qdrant.py b/src/memos/vec_dbs/qdrant.py
index 633cd3580..d0853c4af 100644
--- a/src/memos/vec_dbs/qdrant.py
+++ b/src/memos/vec_dbs/qdrant.py
@@ -138,14 +138,14 @@ def search(
List of search results with distance scores and payloads.
"""
qdrant_filter = self._dict_to_filter(filter) if filter else None
- response = self.client.search(
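+ # qdrant-client deprecated search() in favor of query_points(); hits are on .points.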
+ response = self.client.query_points(
collection_name=self.config.collection_name,
- query_vector=query_vector,
+ query=query_vector,
limit=top_k,
query_filter=qdrant_filter,
with_vectors=True,
with_payload=True,
- )
+ ).points
logger.info(f"Qdrant search completed with {len(response)} results.")
return [
VecDBItem(
diff --git a/tests/llms/test_hf.py b/tests/llms/test_hf.py
index 595995ad1..375bf2247 100644
--- a/tests/llms/test_hf.py
+++ b/tests/llms/test_hf.py
@@ -11,8 +11,8 @@
from memos.llms.hf import HFLLM
-@patch("memos.llms.hf.AutoModelForCausalLM", MagicMock())
-@patch("memos.llms.hf.AutoTokenizer", MagicMock())
+@patch("transformers.AutoModelForCausalLM", MagicMock())
+@patch("transformers.AutoTokenizer", MagicMock())
class TestHFLLM(unittest.TestCase):
def setUp(self):
self.mock_inputs = MagicMock()
diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py
index f81356886..fd07fbf41 100644
--- a/tests/mem_reader/test_simple_structure.py
+++ b/tests/mem_reader/test_simple_structure.py
@@ -1,4 +1,3 @@
-import json
import unittest
from unittest.mock import MagicMock, patch
@@ -8,6 +7,7 @@
from memos.embedders.factory import EmbedderFactory
from memos.llms.factory import LLMFactory
from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.mem_reader.utils import parse_json_result
from memos.memories.textual.item import TextualMemoryItem
@@ -57,7 +57,6 @@ def test_process_chat_data(self):
'"summary": "Tom is currently focused on managing a new project with a tight schedule."}'
)
self.reader.llm.generate.return_value = mock_response
- self.reader.parse_json_result = lambda x: json.loads(x)
result = self.reader._process_chat_data(scene_data_info, info)
@@ -105,7 +104,7 @@ def test_get_scene_data_info_with_chat(self):
def test_parse_json_result_success(self):
"""Test successful JSON parsing."""
raw_response = '{"summary": "Test summary", "tags": ["test"]}'
- result = self.reader.parse_json_result(raw_response)
+ result = parse_json_result(raw_response)
self.assertIsInstance(result, dict)
self.assertIn("summary", result)
@@ -113,7 +112,7 @@ def test_parse_json_result_success(self):
def test_parse_json_result_failure(self):
"""Test failure in JSON parsing."""
raw_response = "Invalid JSON string"
- result = self.reader.parse_json_result(raw_response)
+ result = parse_json_result(raw_response)
self.assertEqual(result, {})
diff --git a/tests/vec_dbs/test_qdrant.py b/tests/vec_dbs/test_qdrant.py
index f4bd276c3..67f76d463 100644
--- a/tests/vec_dbs/test_qdrant.py
+++ b/tests/vec_dbs/test_qdrant.py
@@ -70,13 +70,25 @@ def test_add_and_get_by_id(vec_db):
def test_search(vec_db):
id = str(uuid.uuid4())
- vec_db.client.search.return_value = [
- type(
- "obj",
- (object,),
- {"id": id, "vector": [0.1, 0.2, 0.3], "payload": {"tag": "search"}, "score": 0.9},
- )
- ]
+ mock_response = type(
+ "QueryResponse",
+ (object,),
+ {
+ "points": [
+ type(
+ "obj",
+ (object,),
+ {
+ "id": id,
+ "vector": [0.1, 0.2, 0.3],
+ "payload": {"tag": "search"},
+ "score": 0.9,
+ },
+ )
+ ]
+ },
+ )()
+ vec_db.client.query_points.return_value = mock_response
results = vec_db.search([0.1, 0.2, 0.3], top_k=1)
assert len(results) == 1
assert isinstance(results[0], VecDBItem)