diff --git a/.gitignore b/.gitignore
index 7a35919..1698ba2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,6 +158,9 @@ tutorial/example_deep_finance/yaml/*
tutorial/example_deep_finance/config/*
tutorial/example_deep_finance/scripts/*
flash_attn-2.8.*.whl
+tutorial/example_deep_finance/prepare_data/*
+tutorial/example_deep_finance/judge/analytical_sufficiency/*
+
.dockerignore
benchmark_datasets
modelscope_cache
diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py
index dc192aa..51f6d60 100644
--- a/ajet/context_tracker/multiagent_tracking.py
+++ b/ajet/context_tracker/multiagent_tracking.py
@@ -82,6 +82,18 @@ def extract_text_content_from_content_dict(self, msg):
# },
# ],
# }
+    # Possible alternative tool_result format (not yet observed in practice):
+ # msg = {
+ # "role": "tool",
+ # "content": [
+ # {
+ # "type": "tool_result",
+ # "id": "call_xxx",
+ # "output": "tool output content",
+ # "name": "tool_name"
+ # },
+ # ],
+ # }
str_content = ""
diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py
index 88f9ab1..c261056 100644
--- a/ajet/task_runner/general_runner.py
+++ b/ajet/task_runner/general_runner.py
@@ -9,6 +9,7 @@
from ajet.schema.trajectory import Reward
from ajet.task_runner.base_runner import BaseAgentRunner
from ajet.utils.dynamic_import import dynamic_import
+from ajet.utils.metric_helper.reward_metric_helper import populate_reward_metadata_from_stats
class GeneralRunner(BaseAgentRunner):
@@ -73,6 +74,10 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
madness=0,
description="",
)
+
+ # Populate reward metadata with deep_finance reward stats if available
+ if "reward_stats" in workflow_output.metadata:
+ populate_reward_metadata_from_stats(reward, workflow_output.metadata["reward_stats"])
context_tracker.process_reward(reward)
# generate token before merging
context_tracker.group_merge()
diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py
index 76d034b..ea951d5 100644
--- a/ajet/utils/metric_helper/reward_metric_helper.py
+++ b/ajet/utils/metric_helper/reward_metric_helper.py
@@ -11,9 +11,12 @@
- judge_time/ Judge time consumption statistics
"""
-from typing import List, Dict, Any
+from typing import List, Dict, Any, TYPE_CHECKING
import numpy as np
+if TYPE_CHECKING:
+ from ajet.schema.trajectory import Reward
+
def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]:
"""
@@ -72,22 +75,15 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str
metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties)
metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0
- # ========== Detect OpenJudge Usage ==========
+    # ========== OpenJudge Metrics (PresentationQuality, Grounding, Planning graders) ==========
openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('openjudge_enabled', False))
if openjudge_enabled_count > 0:
- # ========== OpenJudge Metrics ==========
-
- # Dynamically extract OpenJudge grader fields
- # Currently supported graders: report_resolution, trajectory_faithfulness,
- # rubrics_performance, trajectory_comprehensive, information_gain, action_loop
+    # OpenJudge graders: presentation_quality, grounding, planning
openjudge_graders = [
- "report_resolution",
- "trajectory_faithfulness",
- "rubrics_performance",
- "trajectory_comprehensive",
- "information_gain",
- "action_loop",
+ "presentation_quality",
+ "grounding",
+ "planning"
]
for grader_name in openjudge_graders:
@@ -151,3 +147,18 @@ def compute_reward_metrics_from_trajectories(trajectories: List[Any], prefix: st
reward_stats_list = extract_reward_stats_from_trajectories(trajectories)
return compute_reward_metrics(reward_stats_list, prefix=prefix)
+
+def populate_reward_metadata_from_stats(reward: "Reward", reward_stats: Dict[str, Any]) -> None:
+ """
+ Populate Reward.metadata with all reward statistics.
+
+ Args:
+ reward: The Reward object to populate
+ reward_stats: The reward_stats dictionary from judge
+ """
+ if not reward_stats:
+ return
+
+ # Directly copy all reward_stats into metadata
+ reward.metadata.update(reward_stats)
+
diff --git a/tutorial/example_deep_finance/__init__.py b/tutorial/example_deep_finance/__init__.py
new file mode 100644
index 0000000..36e084c
--- /dev/null
+++ b/tutorial/example_deep_finance/__init__.py
@@ -0,0 +1 @@
+# tutorial/example_deep_finance package
diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh
index 6e3c13b..bee02ac 100644
--- a/tutorial/example_deep_finance/deep_finance.sh
+++ b/tutorial/example_deep_finance/deep_finance.sh
@@ -1,10 +1,10 @@
#!/bin/bash
-set -e
+set -e
#===============================================================================
# 1. 配置区域 - 用户只需修改这里
#===============================================================================
-SUFFIX="deep_finance" # 实验后缀,影响所有日志和实验名称
-PREFIX="open" # 实验前缀,影响日志和实验所在文件夹
+SUFFIX="newjudge" # 实验后缀,影响所有日志和实验名称
+PREFIX="ajet_newjudge" # 实验前缀,影响日志和实验所在文件夹
# OpenJudge 模型配置
OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型
@@ -12,10 +12,9 @@ RM_LLM='qwen-max' # RM Gallery 评分模型
JUDGE_CONCURRENCY=10
# 奖励权重配置
-RM_WEIGHT=0.4
-CITATION_AUDIT_WEIGHT=0.2
-REPORT_RESOLUTION_WEIGHT=0.2
-TRAJECTORY_FAITHFULNESS_WEIGHT=0.2
+RM_WEIGHT=0.5
+PRESENTATION_QUALITY_WEIGHT=0.25
+GROUNDING_WEIGHT=0.25
# 训练参数配置
NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次
@@ -23,7 +22,8 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize
NUM_STEPS=6 # 每个样本step轮数
DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000
-# 主目录
+# 主目录(需要更改)
+export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new"
NNODES=${WORLD_SIZE}
@@ -46,7 +46,7 @@ fi
# 2. 动态生成配置文件 (从yaml template生成yaml)
#===============================================================================
# 修改:配置文件生成路径,现在动态生成到 yaml 目录下
-CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml"
+CONFIG_TEMPLATE="tutorial/example_deep_finance/deep_finance.yaml"
CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/${SUFFIX}.yaml"
mkdir -p $(dirname ${CONFIG_FILE})
@@ -55,12 +55,11 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \
-e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \
-e "s|{{NNODES}}|${NNODES}|g" \
-e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \
- -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \
+ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \
+ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \
-e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \
-e "s|{{RM_LLM}}|${RM_LLM}|g" \
-e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \
- -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \
- -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \
-e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \
-e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \
-e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \
@@ -72,7 +71,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \
${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE}
echo "配置文件已生成: ${CONFIG_FILE}"
-echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}"
+echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}"
#===============================================================================
# 3. 环境配置
@@ -106,7 +105,7 @@ export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS
# 其他服务配置
HF_ENDPOINT="https://hf-mirror.com"
ES_HOSTS="http://11.160.132.46:8200"
-export HF_ENDPOINT ES_HOSTS
+export HF_ENDPOINT ES_HOSTS
# log 文件位置
CURRENT_TIME=$(date "+%Y%m%d_%H%M%S")
@@ -114,7 +113,7 @@ LOG_DIR="${AJET_ROOT}/logs/${PREFIX}"
MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log"
ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log"
TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log"
-
+env_log_prefix="${SUFFIX}__${CURRENT_TIME}"
# 多机训练参数配置
GPUS_PER_NODE=8
EXPECTED_WORKERS=$WORLD_SIZE
@@ -156,6 +155,8 @@ export NCCL_ASYNC_ERROR_HANDLING=1
export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}"
export RAY_CLUSTER_MODE="multi_node"
+export DEEPFINANCE_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径
+export DEEPFINANCE_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && DEEPFINANCE_TOOL_RESULT_MAX_CHARS=${DEEPFINANCE_TOOL_RESULT_MAX_CHARS} DEEPFINANCE_MCP_CONFIG=${DEEPFINANCE_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080"
#===============================================================================
@@ -202,11 +203,12 @@ if [[ $HOSTNAME == *"-master-"* ]]; then
# 启动训练任务(最核心)
python ajet/launcher.py \
+ --with-deepfinance \
--conf ${CONFIG_FILE} \
--backbone="verl" \
- --prefix=${SUFFIX} \
+ --prefix=${env_log_prefix} \
2>&1 | tee ${TRAIN_LOG}
-
+
#===============================================================================
# 6.2 Worker 节点启动流程
@@ -218,4 +220,4 @@ else
ray stop || true
ray start --address $MASTER_ADDR:6379 --num-gpus 8
while true; do sleep 60; done
-fi
+fi
\ No newline at end of file
diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml
index 15dd566..33103fe 100644
--- a/tutorial/example_deep_finance/deep_finance.yaml
+++ b/tutorial/example_deep_finance/deep_finance.yaml
@@ -1,19 +1,18 @@
# ------------------ 主要配置 ------------------
ajet:
- project_name: ajet_deep_finance
- experiment_name: "ajet_deep_finance"
+ project_name: "{{PREFIX}}"
+ experiment_name: "{{SUFFIX}}"
# Judge 配置(嵌套结构,对应 self.config.ajet.judge.*)
judge:
- openjudge_llm: qwen-flash # OpenJudge 模型
- rm_llm: qwen-max # RM Gallery 模型
- concurrency: 10 # Judge 并发数
+ openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型
+ rm_llm: {{RM_LLM}} # RM Gallery 模型
+ concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数
train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径
val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径
# OpenJudge 权重配置
- report_resolution_weight: 0.2 # 报告质量评估
- trajectory_faithfulness_weight: 0.2 # 事实准确性评估
- citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性)
- rm_weight: 0.4 # RM Gallery 权重
+ presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估
+ grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估
+ rm_weight: {{RM_WEIGHT}} # RM Gallery 权重
task_judge:
# 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service)
judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge
@@ -21,7 +20,7 @@ ajet:
# ✨✨✨✨ 设置待训练的模型
path: {{MODEL_PATH}}
trainer_common:
- nnodes: 8
+ nnodes: {{NNODES}}
n_gpus_per_node: 8
val_before_train: True
val_pass_n: 8
@@ -32,10 +31,10 @@ ajet:
rollout:
# ✨✨✨✨ 编写并选择Agent
user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol
- force_disable_toolcalls: True
+ force_disable_toolcalls: False
enable_oversample: False
tensor_model_parallel_size: 8
- num_repeat: 4
+ num_repeat: {{NUM_REPEAT}}
max_env_worker: 64 # 增加环境并行数
max_num_seqs: 64 # 增加VLLM并发序列数
max_response_length_in_one_turn: 8000
@@ -43,14 +42,14 @@ ajet:
agent_madness_reward: 0.0
compute_madness_checklist: None
multi_turn:
- max_steps: 6
+ max_steps: {{NUM_STEPS}}
interchange_server:
interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
debug:
debug_max_parallel: 1 # 增加并行任务数,充分利用GPU
debug_first_n_tasks: 100 # 增加处理的任务数
data:
- train_batch_size: 32
+ train_batch_size: {{TRAIN_BATCH_SIZE}}
max_prompt_length: 8000
max_response_length: 41000
@@ -58,18 +57,16 @@ ajet:
type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service
deep_finance:
training:
- file_path: {{TRAIN_PATH}}
+ file_path: {{TRAIN_DATA_PATH}}
validation:
- file_path: {{VAL_PATH}}
+ file_path: {{VAL_DATA_PATH}}
# env_service 仍需配置(用于工具调用)
env_service:
env_type: "finworld"
env_url: {{ENV_SERVICE_URL}}
env_action_preference: code
-
-
trainer:
- default_local_dir: {{CKPT_SAVE_PATH}}
+ default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}"
# resume_mode: disable # 禁用自动恢复,从头开始训练
actor_rollout_ref:
rollout:
diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py
index 31e4be0..03f1013 100644
--- a/tutorial/example_deep_finance/deep_finance_judge.py
+++ b/tutorial/example_deep_finance/deep_finance_judge.py
@@ -1,5 +1,5 @@
"""DeepFinance Task Judge - OpenJudge 版本
-集成: RM Gallery, OpenJudge Graders (含 CitationAudit)
+集成: RM Gallery, PresentationQualityGrader, GroundingGrader
"""
import os
@@ -13,32 +13,11 @@
from ajet.task_judge.base_judge import BaseJudge
from ajet.workflow import WorkflowOutput, WorkflowTask
-from openjudge.graders.agent.action.action_loop import ActionLoopDetectionGrader
-from openjudge.graders.agent.observation.observation_information_gain import (
- ObservationInformationGainGrader,
-)
-from openjudge.graders.agent.trajectory.trajectory_comprehensive import (
- TrajectoryComprehensiveGrader,
-)
from openjudge.models.openai_chat_model import OpenAIChatModel
-from openjudge.models.schema.prompt_template import LanguageEnum
from openjudge.runner.grading_runner import GraderConfig, GradingRunner
-from openjudge.scenarios.deep_research.graders.financial_report_resolution import (
- FinancialReportResolutionGrader,
-)
-from openjudge.scenarios.deep_research.graders.financial_trajectory_faithfulness import (
- FinancialTrajectoryFaithfulGrader,
-)
-from openjudge.scenarios.deep_research.graders.rubrics_based_trajectory_performance import (
- RubricsBasedTrajectoryPerformance,
-)
-from openjudge.scenarios.deep_research.graders.financial_report_citation_audit import (
- FinancialReportCitationAuditGrader,
-)
+from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader
-# RewardStats 不再使用,OpenJudge 版本直接使用字典存储
-# Reference Answer 路径现在从 config 中读取,见 _init_reference_answers 方法
# OpenJudge imports
# =============================================================================
@@ -88,7 +67,7 @@ def load_reference_answers_from_file(file_path: str) -> Tuple[Dict[str, str], Di
class DeepFinanceJudgeByOpenJudge(BaseJudge):
"""
使用 OpenJudge 框架的 DeepFinance Judge
- 集成: RM Gallery, OpenJudge Graders (含 CitationAudit)
+    集成: RM Gallery, PresentationQualityGrader, GroundingGrader
分析:
- compute_reward 每次处理 **一条采样**(单个 workflow_output)
@@ -116,26 +95,15 @@ def _setup_weights(self):
配置 OpenJudge 各 grader 的权重并归一化
graders 对应关系:
- - financial_report_resolution: 报告质量和问题解决能力
- - financial_trajectory_faithfulness: 事实准确性(忠实度)
- - citation_audit: 引用审计(覆盖率 + 真实性)
- - rubrics_based_trajectory_performance: 基于 rubrics 的评估
- - trajectory_comprehensive: 轨迹综合评估
- - observation_information_gain: 信息增益(去重)
- - action_loop_detection: 动作循环检测(惩罚项)
+    - presentation_quality: 报告呈现质量评估; grounding: 引用规范性评估
"""
cfg = getattr(self.config, "ajet", None)
- # 定义各 grader 的权重(可从 config 中读取)- 与 deep_finance_judge.py 对齐
+ # 定义各 grader 的权重(可从 config 中读取)
self.w = {
"rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重
- "citation_audit": getattr(cfg, "citation_audit_weight", 0.0) if cfg else 0.0, # CitationAudit 权重
- "report_resolution": getattr(cfg, "report_resolution_weight", 0.0) if cfg else 0.0,
- "trajectory_faithfulness": getattr(cfg, "trajectory_faithfulness_weight", 0.0) if cfg else 0.0,
- # "rubrics_performance": getattr(cfg, "rubrics_performance_weight", 0.2) if cfg else 0.2,
- # "trajectory_comprehensive": getattr(cfg, "trajectory_comprehensive_weight", 0.2) if cfg else 0.2,
- # "information_gain": getattr(cfg, "information_gain_weight", 0.1) if cfg else 0.1,
- # "action_loop": getattr(cfg, "action_loop_weight", 0.1) if cfg else 0.1
+ "presentation_quality": getattr(cfg, "presentation_quality_weight", 0.25) if cfg else 0.25,
+ "grounding": getattr(cfg, "grounding_weight", 0.25) if cfg else 0.25,
}
# 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化)
@@ -244,15 +212,14 @@ def _create_runner_in_loop(self) -> GradingRunner:
注意:GradingRunner 内部的 Semaphore 会绑定到创建时的事件循环,
因此不能使用单例模式,必须在每次调用的事件循环中创建新实例。
"""
- language = LanguageEnum.ZH
- grader_configs = self._create_grader_configs(self.model, language)
+ grader_configs = self._create_grader_configs(self.model)
return GradingRunner(
grader_configs=grader_configs,
max_concurrency=self.max_concurrency,
show_progress=False
)
- def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum) -> Dict[str, GraderConfig]:
+ def _create_grader_configs(self, model: OpenAIChatModel) -> Dict[str, GraderConfig]:
"""
创建所有 grader 的配置
@@ -260,54 +227,35 @@ def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum)
- key: grader 名称
- value: GraderConfig(grader=..., mapper=...)
"""
+
+ def extract_user_query(data: Dict) -> str:
+ """从 messages 中提取第一条 user 消息的 content"""
+ for msg in data.get("messages", []):
+ if msg.get("role") == "user":
+ return msg.get("content", "")
+ return ""
+
+ def extract_report_content(data: Dict) -> str:
+ """从 messages 中提取最后一条 assistant 消息的 content"""
+ for msg in reversed(data.get("messages", [])):
+ if msg.get("role") == "assistant":
+ return msg.get("content", "")
+ return ""
+
return {
- # 1. 报告质量评估 - 需要 messages 和 chat_date
- "report_resolution": GraderConfig(
- grader=FinancialReportResolutionGrader(model=model, language=language),
+ # 报告呈现质量评估 - 需要 user_query 和 report_content
+ "presentation_quality": GraderConfig(
+ grader=PresentationQualityGrader(model=model),
mapper=lambda data: {
- "messages": data["messages"],
- "chat_date": data.get("chat_date")
+ "user_query": extract_user_query(data),
+ "report_content": extract_report_content(data),
},
),
-
- # 2. 事实准确性评估 - 需要 messages
- "trajectory_faithfulness": GraderConfig(
- grader=FinancialTrajectoryFaithfulGrader(model=model, language=language),
- mapper=lambda data: {"messages": data["messages"]},
- ),
-
- # 3. 引用审计评估 - 需要 messages
- "citation_audit": GraderConfig(
- grader=FinancialReportCitationAuditGrader(model=model, language=language),
- mapper=lambda data: {"messages": data["messages"]},
+ # 引用规范性评估 - 需要完整的 traj
+ "grounding": GraderConfig(
+ grader=GroundingGrader(model=model),
+ mapper=lambda data: {"traj": data},
),
-
- # 4. Rubrics 评估 - 需要 messages 和 rubrics
- # "rubrics_performance": GraderConfig(
- # grader=RubricsBasedTrajectoryPerformance(model=model, language=language),
- # mapper=lambda data: {
- # "messages": data["messages"],
- # "rubrics": data.get("rubrics", [])
- # },
- # ),
-
- # 5. 轨迹综合评估 - 需要 messages
- # "trajectory_comprehensive": GraderConfig(
- # grader=TrajectoryComprehensiveGrader(model=model, language=language),
- # mapper=lambda data: {"messages": data["messages"]},
- # ),
-
- # 6. 信息增益评估 - 需要 messages(非 LLM grader)
- # "information_gain": GraderConfig(
- # grader=ObservationInformationGainGrader(similarity_threshold=0.5),
- # mapper=lambda data: {"messages": data["messages"]},
- # ),
-
- # 7. 动作循环检测 - 需要 messages(非 LLM grader)
- # "action_loop": GraderConfig(
- # grader=ActionLoopDetectionGrader(similarity_threshold=1.0),
- # mapper=lambda data: {"messages": data["messages"]},
- # ),
}
def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]:
@@ -361,14 +309,28 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO
chat_date=chat_date
)
+ if openjudge_sample.get('messages'):
+ last_msg = openjudge_sample['messages'][-1]
+
# 3. 调用 OpenJudge Runner.arun(异步)
grading_start_time = time.time()
grader_results = self._run_openjudge_evaluation([openjudge_sample])
grading_time = time.time() - grading_start_time
+
# 4. 提取各 grader 分数(arun 返回 Dict[str, List[GraderScore]],这里取第一条)
grader_scores, quota_exceeded_flags = self._extract_grader_scores(grader_results)
+ # 4.5 如果有分数为0的grader,保存调试信息到单独文件
+ self._save_zero_score_debug(
+ grader_scores=grader_scores,
+ grader_results=grader_results,
+ query=query,
+ history=history,
+ report=assistants[-1] if assistants else "",
+ task_id=task_id
+ )
+
# 5. 加权融合(包含 RM Gallery 和 OpenJudge Graders)
fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw)
@@ -552,8 +514,7 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[
输入:
- grader_results: Dict[str, List[GraderScore]]
{
- "report_resolution": [GraderScore(score=0.88, reason="...", metadata={...})],
- "trajectory_faithfulness": [GraderScore(score=1.0, ...)],
+ "presentation_quality": [GraderScore(score=0.88, reason="...", metadata={...})],
...
}
@@ -570,6 +531,10 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[
if score_list and len(score_list) > 0:
# 取第一条采样的分数(因为每次只评估一条)
grader_score = score_list[0]
+
+ # DEBUG: 记录详细信息
+ reason_str = getattr(grader_score, 'reason', None)
+ print(f" [DEBUG] {grader_name}: score={getattr(grader_score, 'score', 'N/A')}, reason={str(reason_str)[:300] if reason_str else 'N/A'}")
if hasattr(grader_score, "score"):
scores[grader_name] = grader_score.score
# 检测错误类型:分数为0且有错误信息
@@ -581,6 +546,7 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[
else:
# 如果出错,设为 0
scores[grader_name] = 0.0
+ print(f" [DEBUG] {grader_name}: no 'score' attr, grader_score={grader_score}")
else:
scores[grader_name] = 0.0
@@ -657,6 +623,69 @@ def _save_rm_log(self, result, query: str, task_id: str):
except Exception:
pass
+ def _save_zero_score_debug(
+ self,
+ grader_scores: Dict[str, float],
+ grader_results: Dict[str, List[Any]],
+ query: str,
+ history: List[Dict],
+ report: str,
+ task_id: str
+ ):
+ """
+ 当有 grader 分数为 0 时,保存详细调试信息到单独文件
+
+ 保存内容包括:
+ - query: 用户查询
+ - traj: 对话历史
+ - report: 最终报告(前500字)
+ - zero_score_reasons: 得 0 分的原因
+ """
+ try:
+ # 检查是否有分数为 0 的 grader
+ zero_score_graders = [name for name, score in grader_scores.items() if score == 0.0]
+ if not zero_score_graders:
+ return
+
+ # 提取得 0 分的原因
+ zero_score_reasons = {}
+ for grader_name in zero_score_graders:
+ if grader_name in grader_results:
+ score_list = grader_results[grader_name]
+ if score_list and len(score_list) > 0:
+ grader_score = score_list[0]
+ reason = getattr(grader_score, 'reason', None)
+ zero_score_reasons[grader_name] = str(reason) if reason else "N/A"
+ else:
+ zero_score_reasons[grader_name] = "empty score_list"
+ else:
+ zero_score_reasons[grader_name] = "grader not in results"
+
+ # 构建调试日志
+ debug_log = {
+ "task_id": task_id,
+ "timestamp": datetime.now().isoformat(),
+ "query": query,
+ "report": report if report else "",
+ "trajectory": history,
+ "grader_scores": grader_scores,
+ "zero_score_graders": zero_score_graders,
+ "zero_score_reasons": zero_score_reasons
+ }
+
+ # 保存到单独文件
+ save_dir = "/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new/tutorial/example_deep_finance/outputs/reward_zero_debug"
+ os.makedirs(save_dir, exist_ok=True)
+ log_file = os.path.join(save_dir, f"zeroscore_{datetime.now().strftime('%Y%m%d')}.jsonl")
+ with open(log_file, "a", encoding="utf-8") as f:
+ f.write(json.dumps(debug_log, ensure_ascii=False) + "\n")
+
+ print(f" [ZERO SCORE DEBUG] task_id={task_id}, zero_graders={zero_score_graders}, saved to {log_file}")
+
+ except Exception as e:
+ print(f"⚠️ Failed to save zero score debug: {e}")
+ pass
+
def _compute_penalty(self, tool_calls: int) -> float:
"""
计算工具调用惩罚(保留原有逻辑)
@@ -689,13 +718,7 @@ def _update_metadata_stats(
更新 metadata["reward_stats"] - 直接使用 OpenJudge 原始字段
OpenJudge graders(按实际启用情况):
- - report_resolution: 报告质量和问题解决能力
- - trajectory_faithfulness: 事实准确性(忠实度)
- - citation_audit: 引用审计(覆盖率 + 真实性)
- - rubrics_performance: 基于 rubrics 的评估(可选)
- - trajectory_comprehensive: 轨迹综合评估(可选)
- - information_gain: 信息增益/去重(可选)
- - action_loop: 动作循环检测(惩罚项,可选)
+    - presentation_quality: 报告呈现质量评估; grounding: 引用规范性评估
注意:不再硬套 RewardStats 的字段名,直接使用 openjudge_ 前缀
"""
@@ -712,10 +735,6 @@ def _update_metadata_stats(
"penalty": penalty,
"step_reward": step_reward,
"openjudge_enabled": True,
- # Quota exceeded (429) 统计
- "quota_exceeded_any": quota_exceeded_any, # 是否有任何 grader 超额
- "quota_exceeded_count": quota_exceeded_count, # 超额的 grader 数量
- "quota_exceeded_graders": quota_exceeded_flags, # 各 grader 的超额标记
# RM Gallery 相关
"rm_enabled": self._rm_enabled,
"rm_raw": rm_raw,
diff --git a/tutorial/example_deep_finance/deep_finance_single.sh b/tutorial/example_deep_finance/deep_finance_single.sh
index 6b27d0f..e794dff 100644
--- a/tutorial/example_deep_finance/deep_finance_single.sh
+++ b/tutorial/example_deep_finance/deep_finance_single.sh
@@ -3,8 +3,8 @@ set -e
#===============================================================================
# 1. 配置区域 - 用户只需修改这里
#===============================================================================
-SUFFIX="ajet_deep_finance" # 实验后缀,影响所有日志和实验名称
-PREFIX="open" # 实验前缀,影响日志和实验所在文件夹
+SUFFIX="newjudge" # 实验后缀,影响所有日志和实验名称
+PREFIX="ajet_newjudge" # 实验前缀,影响日志和实验所在文件夹
# OpenJudge 模型配置
OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型
@@ -12,10 +12,9 @@ RM_LLM='qwen-max' # RM Gallery 评分模型
JUDGE_CONCURRENCY=10
# 奖励权重配置
-RM_WEIGHT=0.4
-CITATION_AUDIT_WEIGHT=0.2
-REPORT_RESOLUTION_WEIGHT=0.2
-TRAJECTORY_FAITHFULNESS_WEIGHT=0.2
+RM_WEIGHT=0.5
+PRESENTATION_QUALITY_WEIGHT=0.25
+GROUNDING_WEIGHT=0.25
# 训练参数配置
NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次
@@ -23,7 +22,8 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize
NUM_STEPS=6 # 每个样本step轮数
DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000
-# 主目录
+# 主目录(需要更改)
+export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new"
NNODES=${WORLD_SIZE}
@@ -55,70 +55,23 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \
-e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \
-e "s|{{NNODES}}|${NNODES}|g" \
-e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \
- -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \
+ -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \
+ -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \
-e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \
-e "s|{{RM_LLM}}|${RM_LLM}|g" \
-e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \
- -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \
- -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \
-e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \
-e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \
-e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \
-e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \
-e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \
- -e "s|{{ENV_SERVICE_URL}}|${ENV_SERVICE_URL}|g" \
-e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \
-e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \
-e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \
${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE}
echo "配置文件已生成: ${CONFIG_FILE}"
-echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}"
-
-#===============================================================================
-# 3. 环境配置
-#===============================================================================
-# MongoDB 缓存配置
-CACHE_TYPE="mongodb"
-MONGO_URI="mongodb://${ADDR}:27117/"
-MONGO_DB_NAME="finworld_cache"
-MONGO_COLLECTION_NAME="tool_cache"
-export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME
-
-# DeepFinance MCP 配置
-DEEPFINANCE_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json"
-
-# 动态生成 MCP 配置文件
-mkdir -p $(dirname ${DEEPFINANCE_MCP_CONFIG})
-cat > ${DEEPFINANCE_MCP_CONFIG} << EOF
-{
- "mcpServers": {
- "flowllm": {
- "transport": "sse",
- "url": "http://${ADDR}:${MCP_PORT}/sse",
- "timeout": 600,
- "sse_read_timeout": 1200
- }
- }
-}
-EOF
-export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS
-
-# 其他服务配置
-HF_ENDPOINT="https://hf-mirror.com"
-ES_HOSTS="http://11.160.132.46:8200"
-export HF_ENDPOINT ES_HOSTS
-
-# log 文件位置
-CURRENT_TIME=$(date "+%Y%m%d_%H%M%S")
-LOG_DIR="${AJET_ROOT}/logs/${PREFIX}"
-MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log"
-ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log"
-TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log"
-
-# 多机训练参数配置
-GPUS_PER_NODE=8
-EXPECTED_WORKERS=$WORLD_SIZE
+echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}"
#===============================================================================
@@ -162,7 +115,6 @@ export RAY_CLUSTER_MODE="multi_node"
#===============================================================================
# 6. 主流程
#===============================================================================
-log "开始多机多卡训练: ${SUFFIX}"
log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}"
mkdir -p ${LOG_DIR}
mkdir -p $(dirname ${CONFIG_FILE})
diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py
new file mode 100644
index 0000000..75c8cef
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/__init__.py
@@ -0,0 +1,11 @@
+# 使得可以通过 from judge import PresentationQualityGrader 直接引用
+from .grounding.grader import GroundingGrader
+from .presentation_quality.grader import PresentationQualityGrader
+# from .research_depth.grader import ResearchDepthGrader
+# from .research_breadth.grader import ResearchBreadthGrader
+
+# 以后添加了其他 grader 也可以加在这里
+# from .grounding.grader import GroundingGrader
+# from .research_breadth.grader import ResearchBreadthGrader
+# __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"]
+__all__ = ["PresentationQualityGrader", "GroundingGrader"]
diff --git a/tutorial/example_deep_finance/judge/grounding/__init__.py b/tutorial/example_deep_finance/judge/grounding/__init__.py
new file mode 100644
index 0000000..1123382
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/grounding/__init__.py
@@ -0,0 +1,4 @@
+"""Grounding Grader - 引用规范性评估"""
+from .grader import GroundingGrader
+
+__all__ = ["GroundingGrader"]
diff --git a/tutorial/example_deep_finance/judge/grounding/grader.py b/tutorial/example_deep_finance/judge/grounding/grader.py
new file mode 100644
index 0000000..599ccc9
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/grounding/grader.py
@@ -0,0 +1,222 @@
+"""Grounding Grader - 引用规范性评估 (OpenJudge 版本)"""
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, List, Tuple
+
+from openjudge.graders.base_grader import BaseGrader
+from openjudge.graders.schema import GraderScore
+
+# import path 兼容两种写法
+try:
+ from openjudge.models import OpenAIChatModel
+except Exception: # pragma: no cover
+ from openjudge.models.openai_chat_model import OpenAIChatModel
+
+from .prompt import GROUNDING_SYSTEM_PROMPT, GROUNDING_USER_PROMPT_TEMPLATE
+from .json_utils import strict_load_json, validate_shape, construct_reward_prompt
+
+
+class GroundingGrader(BaseGrader):
+ """
+ 引用规范性评估 Grader
+
+ - 输入:traj(完整对话轨迹)
+ - 输出:GraderScore(name, score, reason)
+ - score:综合分数,范围[0,1]
+ - citation_coverage_score: 引用覆盖率(0.5 权重)
+ - grounding_score: 引用真实性(0.5 权重)
+ - invalid_penalty: 无效引用惩罚(最多扣 0.5)
+ - determinism:建议用 temperature=0 + disable thinking
+ - 解析失败:score=0,并在 reason 显示报错
+ """
+
+ def __init__(
+ self,
+ model: OpenAIChatModel,
+ name: str = "grounding",
+ **kwargs: Any,
+ ):
+ super().__init__(name=name, **kwargs)
+ self.model = model
+
+ @staticmethod
+ def create_default_model(
+ model_name: str,
+ api_key: str | None = None,
+ base_url: str | None = None,
+ deterministic: bool = True,
+ enable_thinking: bool = False,
+ seed: int = 0,
+ ) -> OpenAIChatModel:
+ """
+ 创建默认模型
+ 也可以不调用这个工厂,自己在外面 new OpenAIChatModel
+ """
+ api_key = api_key or os.getenv("OPENAI_API_KEY")
+ base_url = base_url or os.getenv("OPENAI_BASE_URL")
+
+ extra_body: Dict[str, Any] = {}
+ if deterministic:
+ extra_body.update(
+ {
+ "temperature": 0,
+ "top_p": 1,
+ "seed": seed,
+ "presence_penalty": 0,
+ "frequency_penalty": 0,
+ }
+ )
+ if enable_thinking is False:
+ extra_body["enable_thinking"] = False
+
+ kwargs: Dict[str, Any] = {"model": model_name}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["base_url"] = base_url
+ if extra_body:
+ kwargs["extra_body"] = extra_body
+
+ return OpenAIChatModel(**kwargs)
+
+ async def aevaluate(
+ self,
+ traj: Any,
+ **_: Any,
+ ) -> GraderScore:
+ """
+ 入口:必须喂 traj(完整对话轨迹)
+
+ Args:
+ traj: 对话轨迹,格式为 [{"role": ..., "content": ...}, ...]
+ 或者 {"messages": [...]} 格式
+
+ Returns:
+ GraderScore(name, score, reason)
+ """
+ # 1. 提取 messages(兼容两种格式)
+ if isinstance(traj, dict):
+ messages_list = traj.get("messages", [])
+ elif isinstance(traj, list):
+ messages_list = traj
+ else:
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason="BadInput: traj must be list or dict with 'messages'",
+ )
+
+ if not messages_list:
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason="BadInput: empty trajectory",
+ )
+
+ # 2. 构建 prompt
+ user_prompt = construct_reward_prompt(messages_list, GROUNDING_USER_PROMPT_TEMPLATE)
+
+ messages = [
+ {"role": "system", "content": GROUNDING_SYSTEM_PROMPT},
+ {"role": "user", "content": user_prompt}
+ ]
+
+ # 3. 调用模型
+ try:
+ resp = await self.model.achat(messages)
+ raw_text = getattr(resp, "content", None)
+ if raw_text is None:
+ raw_text = str(resp)
+ except Exception as e:
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason=f"ModelCallError: {type(e).__name__}: {e}",
+ )
+
+ # 4. 解析 JSON
+ obj, jerr = strict_load_json(str(raw_text))
+ if obj is None:
+ snippet = str(raw_text)[:200].replace("\n", " ")
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason=f"ParseError: {jerr}; raw[:200]={snippet}",
+ )
+
+ obj, serr = validate_shape(obj)
+ if obj is None:
+ snippet = str(raw_text)[:200].replace("\n", " ")
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason=f"SchemaError: {serr}; raw[:200]={snippet}",
+ )
+
+ # 5. 计算分数
+ score, reason = self._compute_scores(obj)
+ return GraderScore(name=self.name, score=score, reason=reason)
+
+ def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]:
+ """
+ 根据 LLM 返回的结果计算评分
+
+ Args:
+ obj: LLM 返回的 JSON,包含 total_key_facts, cited_key_facts, fake_count 等
+
+ Returns:
+ (score, reason) 元组
+ """
+ total_key_facts = obj.get('total_key_facts', 0)
+ cited_key_facts = obj.get('cited_key_facts', 0)
+ fake_count = obj.get('fake_count', 0)
+ missing_count = obj.get('missing_count', 0)
+
+ # invalid refs: 结构化/可追溯性问题
+ invalid_reference_nums = obj.get('invalid_reference_nums', [])
+ if not isinstance(invalid_reference_nums, list):
+ invalid_reference_nums = []
+ invalid_ref_count = len(invalid_reference_nums)
+
+ # 边界情况:没有关键事实,直接返回 0
+ if total_key_facts == 0:
+ citation_coverage_score = 0.0
+ grounding_score = 0.0
+ else:
+ # coverage: 引用覆盖率
+ citation_coverage_score = cited_key_facts / total_key_facts
+
+ # grounding: 引用真实性(已引用中非虚假的比例)
+ if cited_key_facts == 0:
+ grounding_score = 0.0
+ else:
+ grounding_score = max(0.0, 1 - fake_count / cited_key_facts)
+
+ # 轻量惩罚:存在 invalid refs 会降低 reward
+ # 每个 invalid 号扣 0.1,最多扣 0.5
+ invalid_penalty = min(0.1 * invalid_ref_count, 0.5)
+
+ # final_reward: 综合分数(权重 0.5:0.5),再叠加 invalid 惩罚
+ final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score
+ final_reward = max(0.0, final_reward - invalid_penalty)
+
+ # 构建 reason
+ good_citations = obj.get('good_citations', [])
+ good_str = "; ".join(str(x)[:50] for x in good_citations[:2]) if good_citations else ""
+
+ parts: List[str] = [
+ f"total={total_key_facts}",
+ f"cited={cited_key_facts}",
+ f"missing={missing_count}",
+ f"fake={fake_count}",
+ f"invalid={invalid_ref_count}",
+ f"coverage={citation_coverage_score:.3f}",
+ f"grounding={grounding_score:.3f}",
+ f"penalty={invalid_penalty:.2f}",
+ ]
+ if good_str:
+ parts.append(f"good:[{good_str}]")
+
+ reason = " | ".join(parts)
+ return round(final_reward, 6), reason[:800]
diff --git a/tutorial/example_deep_finance/judge/grounding/json_utils.py b/tutorial/example_deep_finance/judge/grounding/json_utils.py
new file mode 100644
index 0000000..a3f793a
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/grounding/json_utils.py
@@ -0,0 +1,267 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Any, Dict, List, Tuple
+
+_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)
+
+
+def extract_first_json_object(text: str) -> str | None:
+ """
+ Best-effort: extract the first {...} block.
+ If none found, return None.
+ """
+ if not text:
+ return None
+ m = _JSON_RE.search(text.strip())
+ if not m:
+ return None
+ return m.group(0)
+
+
+def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]:
+ """
+ Return (obj, error). Any parse failure => (None, error_msg)
+ """
+ js = extract_first_json_object(text)
+ if js is None:
+ return None, "No JSON object found in model output"
+ try:
+ obj = json.loads(js)
+ if not isinstance(obj, dict):
+ return None, f"Top-level JSON is not an object: {type(obj).__name__}"
+ return obj, None
+ except Exception as e:
+ return None, f"{type(e).__name__}: {e}"
+
+
+def get_bool_pass(item: Any) -> bool:
+ if isinstance(item, dict):
+ v = item.get("pass")
+ else:
+ v = item
+ if isinstance(v, bool):
+ return v
+ if isinstance(v, (int, float)):
+ return bool(v)
+ if isinstance(v, str):
+ return v.strip().lower() in {"true", "1", "yes", "y"}
+ return False
+
+
+def get_note(item: Any) -> str:
+ if isinstance(item, dict):
+ note = item.get("note", "")
+ else:
+ note = ""
+ note = "" if note is None else str(note)
+ note = note.strip()
+ # 最多给点余量,避免reason爆长
+ return note[:120]
+
+
+def validate_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]:
+ """
+ 验证 grounding JSON 结构
+
+ 必需字段:
+ - total_key_facts: int
+ - cited_key_facts: int
+ - missing_count: int
+ - fake_count: int
+ - good_citations: list
+ - invalid_reference_nums: list
+ """
+ # 必需的 int 字段
+ int_fields = ["total_key_facts", "cited_key_facts", "missing_count", "fake_count"]
+ for field in int_fields:
+ if field not in obj:
+ return None, f"Missing field: {field}"
+ val = obj[field]
+ # 尝试转换为 int
+ if isinstance(val, (int, float)):
+ obj[field] = int(val)
+ elif isinstance(val, str) and val.isdigit():
+ obj[field] = int(val)
+ elif not isinstance(val, int):
+ return None, f"Field '{field}' must be int, got {type(val).__name__}"
+
+ # good_citations 必须是 list
+ if "good_citations" not in obj:
+ obj["good_citations"] = []
+ elif not isinstance(obj["good_citations"], list):
+ obj["good_citations"] = []
+ else:
+ # 确保每个元素是字符串,最多保留 2 条
+ obj["good_citations"] = [str(x) for x in obj["good_citations"][:2]]
+
+ # invalid_reference_nums 必须是 list
+ if "invalid_reference_nums" not in obj:
+ obj["invalid_reference_nums"] = []
+ elif not isinstance(obj["invalid_reference_nums"], list):
+ obj["invalid_reference_nums"] = []
+ else:
+ # 确保每个元素是 int,最多保留 5 个
+ nums = []
+ for x in obj["invalid_reference_nums"][:5]:
+ if isinstance(x, int):
+ nums.append(x)
+ elif isinstance(x, (float, str)):
+ try:
+ nums.append(int(x))
+ except ValueError:
+ pass
+ obj["invalid_reference_nums"] = sorted(nums)
+
+ return obj, None
+
+
+
+
+# =============================================================================
+# Trajectory 处理辅助函数
+# =============================================================================
+
+def _extract_text_content(content) -> str:
+ """统一提取纯文本内容"""
+ if content is None:
+ return ""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ out = []
+ for item in content:
+ if isinstance(item, dict) and item.get("type") == "text":
+ out.append(item.get("text", ""))
+ elif isinstance(item, str):
+ out.append(item)
+ return "\n".join(out)
+ return str(content)
+
+
+def _strip_think(text: str) -> str:
+    """去除 <think>...</think> 标签"""
+    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S).strip()
+
+
+def _strip_markdown_fences(text: str) -> str:
+ """
+ 清理 markdown 代码块标记
+ - 移除开头的 ```markdown / ```md / ``` 等
+ - 移除结尾的 ```
+ """
+ text = text.strip()
+ # 移除开头的 ```xxx
+ text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE)
+ # 移除结尾的 ```
+ text = re.sub(r'\n?```\s*$', '', text)
+ return text.strip()
+
+
+def _normalize_traj(trajectory):
+ """兼容 [[...]] 格式"""
+ if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list):
+ return trajectory[0]
+ return trajectory
+
+
+def _extract_tool_call_json(text: str) -> str:
+ """提取工具调用 JSON"""
+ m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text)
+ if m:
+ return m.group(1).strip()
+ l, r = text.find("["), text.rfind("]")
+ if l != -1 and r != -1 and r > l:
+ cand = text[l:r+1].strip()
+ if ("tool_name" in cand) and ("tool_args" in cand):
+ return cand
+ return ""
+
+
+def _looks_like_tool_result(text: str) -> bool:
+ """判断是否为工具返回结果"""
+ t = text.strip()
+ if t.startswith("Tool:") or t.startswith("Result:"):
+ return True
+ if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t):
+ return True
+ if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "):
+ return True
+ return False
+
+
+def _is_probably_final_report(text: str) -> bool:
+ """判断是否为最终报告"""
+ t = text.strip()
+ return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ")
+
+
+def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_template: str) -> str:
+ """
+ 从 trajectory 构建 reward prompt
+
+ Args:
+ trajectory: 对话轨迹 [{"role": ..., "content": ...}, ...]
+
+ Returns:
+ 构建好的 user prompt 字符串
+ """
+ traj = _normalize_traj(trajectory)
+ if not traj:
+ traj = []
+
+ user_query = ""
+ tool_calls: List[str] = []
+ evidence: List[str] = []
+ final_report = ""
+
+ # 找到 final report(从后往前找第一个符合条件的 assistant 消息)
+ for i in range(len(traj) - 1, -1, -1):
+ step = traj[i]
+ if step.get("role") == "assistant":
+ txt = _strip_think(_extract_text_content(step.get("content")))
+ if _is_probably_final_report(txt):
+ final_report = txt
+ break
+ if not final_report:
+ for i in range(len(traj) - 1, -1, -1):
+ if traj[i].get("role") == "assistant":
+ final_report = _strip_think(_extract_text_content(traj[i].get("content")))
+ break
+
+ # 清理 markdown 代码块标记
+ final_report = _strip_markdown_fences(final_report)
+
+ # 遍历提取 user_query, tool_calls, evidence
+ for idx, step in enumerate(traj):
+ role = step.get("role")
+ raw = _extract_text_content(step.get("content"))
+ txt = _strip_think(raw)
+ if not raw:
+ continue
+
+ if role == "user" and not user_query and (not _looks_like_tool_result(raw)):
+ user_query = txt
+ continue
+
+ if role == "assistant":
+ call_json = _extract_tool_call_json(raw)
+ if call_json:
+ tool_calls.append(f"[Step {idx}] TOOL_CALL:\n{call_json}")
+
+ if role in ("tool", "user"):
+ if _looks_like_tool_result(raw):
+ evidence.append(f"[Step {idx}] EVIDENCE_TOOL_RESULT:\n{raw}")
+ else:
+ # query 之后的用户补充也保留为 evidence
+ if user_query:
+ evidence.append(f"[Step {idx}] EVIDENCE_USER_CONTEXT:\n{txt}")
+
+ evidence_text = "\n\n".join(tool_calls + evidence)
+
+ return user_prompt_template.format(
+ user_query=user_query,
+ evidence_text=evidence_text,
+ final_report=final_report
+ ).strip()
diff --git a/tutorial/example_deep_finance/judge/grounding/prompt.py b/tutorial/example_deep_finance/judge/grounding/prompt.py
new file mode 100644
index 0000000..24bea13
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/grounding/prompt.py
@@ -0,0 +1,117 @@
+"""Grounding Grader Prompt - 引用规范性评估"""
+
+GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。
+
+========================
+一、引用规范(以此为准)
+========================
+1) 关键事实句必须引用:
+ - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。
+ - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。
+
+2) 引用位置规则(严格执行):
+ - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。
+ - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。
+
+3) References 必须存在且可追溯:
+ - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。
+ - 正文出现的每个 [n] 必须能在 References 中找到对应条目。
+
+4) References 条目两种合法形式(必须满足其一):
+ A) URL 形式:`[n] 标题或简述 - https://...`
+ - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。
+ B) no-url 形式:`[n] 简述,工具:,参数:,数据日期/报告期: - (no-url)`
+ - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。
+ - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。
+
+========================
+二、输入
+========================
+你会收到:
+- User Query
+- Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息)
+- AI Report(待审计报告,含正文与 References)
+
+真实性核对原则:
+- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。
+- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。
+
+========================
+三、统计与判定口径(严格遵守)
+========================
+【文本范围】
+- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。
+- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。
+
+【句子/条目如何计数】
+- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。
+- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。
+- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。
+
+【关键事实句识别(务求稳定)】
+- 满足任一条件可视为关键事实句:
+ (a) 含具体数值/比例/排名/区间/估值倍数/财务指标;
+ (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”);
+ (c) 对具体公司/行业/政策做了可验证的确定性陈述;
+ (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。
+
+【引用是否“句末”】【重要】
+- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如:
+ - “……增长 20%[3]”
+ - “……增长 20% [3][4]”
+- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。
+
+【invalid_reference_nums 的定义】
+- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid:
+ (a) References 中不存在该编号条目;
+ (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等);
+ (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。
+- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。
+
+【missing_count 的定义】
+- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。
+
+【cited_key_facts 的定义】
+- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。
+
+【fake_count 的定义(只在明显时计数)】
+- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。
+- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。
+
+【good_citations 的定义】
+- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足:
+ - 是关键事实句;
+ - 句末有 [n];
+ - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。
+- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。
+
+========================
+四、输出(只输出 JSON,字段固定)
+========================
+{
+  "total_key_facts": <int>,
+  "cited_key_facts": <int>,
+  "good_citations": ["...", "..."],
+  "missing_count": <int>,
+  "fake_count": <int>,
+  "invalid_reference_nums": [<int>, ...]
+}
+
+只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。
+"""
+
+# =============================================================================
+# User Prompt Template
+# =============================================================================
+
+GROUNDING_USER_PROMPT_TEMPLATE = """请审计以下 AI 研究报告的引用规范性,只输出 JSON。
+
+### User Query
+{user_query}
+
+### Evidence
+{evidence_text}
+
+### AI Report(待审计报告)
+{final_report}
+"""
diff --git a/tutorial/example_deep_finance/judge/grounding/reference.py b/tutorial/example_deep_finance/judge/grounding/reference.py
new file mode 100644
index 0000000..6e67a38
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/grounding/reference.py
@@ -0,0 +1,363 @@
+GROUNDING_SYSTEM_PROMPT = """你是一位“引用审计员”,负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。
+
+========================
+一、引用规范(以此为准)
+========================
+1) 关键事实句必须引用:
+ - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。
+ - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。
+
+2) 引用位置规则(严格执行):
+ - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。
+ - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。
+
+3) References 必须存在且可追溯:
+ - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。
+ - 正文出现的每个 [n] 必须能在 References 中找到对应条目。
+
+4) References 条目两种合法形式(必须满足其一):
+ A) URL 形式:`[n] 标题或简述 - https://...`
+ - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。
+   B) no-url 形式:`[n] 简述,工具:<工具名>,参数:<参数>,数据日期/报告期:<日期> - (no-url)`
+ - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。
+ - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。
+
+========================
+二、输入
+========================
+你会收到:
+- User Query
+- Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息)
+- AI Report(待审计报告,含正文与 References)
+
+真实性核对原则:
+- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。
+- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。
+
+========================
+三、统计与判定口径(严格遵守)
+========================
+【文本范围】
+- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。
+- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。
+
+【句子/条目如何计数】
+- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。
+- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。
+- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。
+
+【关键事实句识别(务求稳定)】
+- 满足任一条件可视为关键事实句:
+ (a) 含具体数值/比例/排名/区间/估值倍数/财务指标;
+ (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”);
+ (c) 对具体公司/行业/政策做了可验证的确定性陈述;
+ (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。
+
+【引用是否“句末”】【重要】
+- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如:
+ - “……增长 20%[3]”
+ - “……增长 20% [3][4]”
+- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。
+
+【invalid_reference_nums 的定义】
+- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid:
+ (a) References 中不存在该编号条目;
+ (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等);
+ (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。
+- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。
+
+【missing_count 的定义】
+- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。
+
+【cited_key_facts 的定义】
+- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。
+
+【fake_count 的定义(只在明显时计数)】
+- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。
+- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。
+
+【good_citations 的定义】
+- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足:
+ - 是关键事实句;
+ - 句末有 [n];
+ - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。
+- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。
+
+========================
+四、输出(只输出 JSON,字段固定)
+========================
+{
+  "total_key_facts": <int>,
+  "cited_key_facts": <int>,
+  "good_citations": ["...", "..."],
+  "missing_count": <int>,
+  "fake_count": <int>,
+  "invalid_reference_nums": [<int>, ...]
+}
+
+只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。
+"""
+
+
+
+import json
+import re
+from typing import Dict, Any, List
+
+
+def _extract_text_content(content) -> str:
+ if content is None:
+ return ""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ out = []
+ for item in content:
+ if isinstance(item, dict) and item.get("type") == "text":
+ out.append(item.get("text", ""))
+ elif isinstance(item, str):
+ out.append(item)
+ return "\n".join(out)
+ return str(content)
+
+def _strip_think(text: str) -> str:
+    return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S).strip()
+
+def _normalize_traj(trajectory):
+    # 兼容 [[...]] 嵌套格式
+ if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list):
+ return trajectory[0]
+ return trajectory
+
+def _extract_tool_call_json(text: str) -> str:
+ m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text)
+ if m:
+ return m.group(1).strip()
+ l, r = text.find("["), text.rfind("]")
+ if l != -1 and r != -1 and r > l:
+ cand = text[l:r+1].strip()
+ if ("tool_name" in cand) and ("tool_args" in cand):
+ return cand
+ return ""
+
+def _looks_like_tool_result(text: str) -> bool:
+ t = text.strip()
+ if t.startswith("Tool:") or t.startswith("Result:"):
+ return True
+ if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t):
+ return True
+ if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "):
+ return True
+ return False
+
+def _is_probably_final_report(text: str) -> bool:
+ t = text.strip()
+ return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ")
+
+def construct_reward_prompt(trajectory: List[Dict[str, Any]]) -> str:
+ traj = _normalize_traj(trajectory)
+
+ user_query = ""
+ tool_calls: List[str] = []
+ evidence: List[str] = []
+ final_report = ""
+
+ # final report
+ for i in range(len(traj) - 1, -1, -1):
+ step = traj[i]
+ if step.get("role") == "assistant":
+ txt = _strip_think(_extract_text_content(step.get("content")))
+ if _is_probably_final_report(txt):
+ final_report = txt
+ break
+ if not final_report:
+ for i in range(len(traj) - 1, -1, -1):
+ if traj[i].get("role") == "assistant":
+ final_report = _strip_think(_extract_text_content(traj[i].get("content")))
+ break
+
+ # iterate
+ for idx, step in enumerate(traj):
+ role = step.get("role")
+ raw = _extract_text_content(step.get("content"))
+ txt = _strip_think(raw)
+ if not raw:
+ continue
+
+ if role == "user" and not user_query and (not _looks_like_tool_result(raw)):
+ user_query = txt
+ continue
+
+ if role == "assistant":
+ call_json = _extract_tool_call_json(raw)
+ if call_json:
+ tool_calls.append(f"[Step {idx}] TOOL_CALL:\n{call_json}")
+
+ if role in ("tool", "user"):
+ if _looks_like_tool_result(raw):
+ evidence.append(f"[Step {idx}] EVIDENCE_TOOL_RESULT:\n{raw}")
+ else:
+ # query 之后的用户补充也保留为 evidence(有些系统会把 tool_result 注入到 user)
+ if user_query:
+ evidence.append(f"[Step {idx}] EVIDENCE_USER_CONTEXT:\n{txt}")
+
+ evidence_text = "\n\n".join(tool_calls + evidence)
+
+ return f"""请审计以下 AI 研究报告的引用规范性,只输出 JSON。
+
+### User Query
+{user_query}
+
+### Evidence(来自完整 trajectory)
+{evidence_text}
+
+### AI Report(待审计报告)
+{final_report}
+""".strip()
+
+
+class RefJudgeEvaluator:
+ """
+ 引用规范性评估器
+
+ 使用 LLM 评估报告的引用覆盖率和引用真实性。
+ """
+
+ def __init__(self, llm_client):
+ """
+ 初始化评估器
+
+ Args:
+ llm_client: LLMJudgeClient 实例
+ """
+ self.llm_client = llm_client
+ print("✓ RefJudgeEvaluator: Initialized")
+
+ def build_messages(self, conversation_history: List[Dict]) -> List[Dict[str, str]]:
+ """
+ 从对话历史构建 LLM 评估消息
+
+ Args:
+ conversation_history: 对话历史 [{"role": "...", "content": "..."}]
+
+ Returns:
+ LLM 消息列表
+ """
+ print(f"\n[RefJudgeEvaluator] 构建评估消息...")
+ print(f" - 对话历史轮数: {len(conversation_history)}")
+
+ # 调用现有的 prompt 构建函数
+ user_prompt = construct_reward_prompt(conversation_history)
+
+ messages = [
+ {"role": "system", "content": GROUNDING_SYSTEM_PROMPT},
+ {"role": "user", "content": user_prompt}
+ ]
+
+ print(f" ✓ 消息构建完成,system prompt 长度: {len(GROUNDING_SYSTEM_PROMPT)}")
+ print(f" ✓ user prompt 长度: {len(user_prompt)}")
+
+ return messages
+
+ def _compute_scores(self, raw_result: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ 根据 LLM 返回的原始结果计算评分
+
+ Args:
+ raw_result: LLM 返回的 JSON,包含 total_key_facts, cited_key_facts, fake_count 等
+
+ Returns:
+ 包含 citation_coverage_score, grounding_score, final_reward 的字典
+ """
+ total_key_facts = raw_result.get('total_key_facts', 0)
+ cited_key_facts = raw_result.get('cited_key_facts', 0)
+ fake_count = raw_result.get('fake_count', 0)
+
+ # invalid refs: 结构化/可追溯性问题(来自 prompt 的 invalid_reference_nums)
+ invalid_reference_nums = raw_result.get('invalid_reference_nums', [])
+ if not isinstance(invalid_reference_nums, list):
+ invalid_reference_nums = []
+ invalid_ref_count = len(invalid_reference_nums)
+
+ # 边界情况:没有关键事实,直接返回 0
+ if total_key_facts == 0:
+ citation_coverage_score = 0.0
+ grounding_score = 0.0
+ else:
+ # coverage: 引用覆盖率
+ citation_coverage_score = cited_key_facts / total_key_facts
+
+ # grounding: 引用真实性(已引用中非虚假的比例)
+ if cited_key_facts == 0:
+ grounding_score = 0.0
+ else:
+ grounding_score = max(0.0, 1 - fake_count / cited_key_facts)
+
+ # 轻量惩罚:存在 invalid refs 会降低 reward(但不改变 cited_key_facts 的统计口径)
+ # 说明:invalid_reference_nums 在 prompt 中已定义为“正文出现过的不合规编号(去重)”。
+ # 这里采用简单、确定性的惩罚:每个 invalid 号扣 0.1,最多扣 0.5。
+ invalid_penalty = min(0.1 * invalid_ref_count, 0.5)
+
+ # final_reward: 综合分数(代码计算,权重 0.5:0.5),再叠加 invalid 惩罚
+ final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score
+ final_reward = max(0.0, final_reward - invalid_penalty)
+
+ return {
+ 'citation_coverage_score': citation_coverage_score,
+ 'grounding_score': grounding_score,
+ 'final_reward': final_reward,
+ 'invalid_ref_count': invalid_ref_count,
+ 'invalid_penalty': invalid_penalty,
+ }
+
+ async def evaluate_async(self, conversation_history: List[Dict]) -> Dict[str, Any]:
+ """
+ 异步评估引用规范性
+
+ Args:
+ conversation_history: 对话历史
+
+ Returns:
+ 评估结果字典,包含:
+ - citation_coverage_score: 引用覆盖率分数 (0.0-1.0)
+ - grounding_score: 引用真实性分数 (0.0-1.0)
+ - final_reward: 最终奖励分数 (0.0-1.0)
+ - total_key_facts, cited_key_facts, fake_count 等原始字段
+ """
+ # print(f"\n开始评估引用规范性...")
+
+ messages = self.build_messages(conversation_history)
+ raw_result = await self.llm_client.evaluate_async(messages)
+
+ # 计算评分
+ scores = self._compute_scores(raw_result)
+
+ # 合并原始结果和计算的评分
+ result = {**raw_result, **scores}
+
+ # 确保必要字段存在
+ result.setdefault('total_key_facts', 0)
+ result.setdefault('cited_key_facts', 0)
+ result.setdefault('missing_count', 0)
+ result.setdefault('fake_count', 0)
+ result.setdefault('invalid_reference_nums', [])
+ result.setdefault('good_citations', [])
+
+ print(f" ✓ [RefJudgeEvaluator] 引用规范性评估完成:")
+ print(f" - total_key_facts: {result['total_key_facts']}")
+ print(f" - cited_key_facts: {result['cited_key_facts']}")
+ print(f" - fake_count: {result['fake_count']}")
+ print(f" - invalid_ref_count: {result.get('invalid_ref_count', 0)}")
+ print(f" - invalid_penalty: {result.get('invalid_penalty', 0.0):.4f}")
+ print(f" - citation_coverage_score: {result['citation_coverage_score']:.4f}")
+ print(f" - grounding_score: {result['grounding_score']:.4f}")
+ print(f" - final_reward: {result['final_reward']:.4f}")
+
+ return result
+
+ def evaluate_sync(self, conversation_history: List[Dict]) -> Dict[str, Any]:
+ """
+ 同步评估引用规范性
+ """
+ import asyncio
+ return asyncio.run(self.evaluate_async(conversation_history))
diff --git a/tutorial/example_deep_finance/judge/presentation_quality/__init__.py b/tutorial/example_deep_finance/judge/presentation_quality/__init__.py
new file mode 100644
index 0000000..2db690f
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/presentation_quality/__init__.py
@@ -0,0 +1,4 @@
+"""Presentation Quality Grader - 呈现质量评估"""
+from .grader import PresentationQualityGrader
+
+__all__ = ["PresentationQualityGrader"]
diff --git a/tutorial/example_deep_finance/judge/presentation_quality/grader.py b/tutorial/example_deep_finance/judge/presentation_quality/grader.py
new file mode 100644
index 0000000..c440c3e
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/presentation_quality/grader.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import os
+import re
+from typing import Any, Dict, List, Tuple
+
+from openjudge.graders.base_grader import BaseGrader
+from openjudge.graders.schema import GraderScore
+
+# import path 兼容两种写法(文档里两种都出现过)
+try:
+ from openjudge.models import OpenAIChatModel
+except Exception: # pragma: no cover
+ from openjudge.models.openai_chat_model import OpenAIChatModel
+
+from .prompt import (
+ QUALITY_SYSTEM_PROMPT,
+ USER_PROMPT_TEMPLATE,
+ ALL_KEYS,
+ A_KEYS,
+ B_KEYS,
+ C_KEYS,
+)
+from .json_utils import strict_load_json, validate_shape, get_score, get_note
+
+
+class PresentationQualityGrader(BaseGrader):
+ """
+ - 输入:report_content(研究报告文本)
+ - 输出:GraderScore(name, score, reason)
+ - score:8项按1/3/5分制评分,总分归一化到[0,1](总分/40)
+ - determinism:建议用 temperature=0 + disable thinking 等(见 create_default_model)
+ - 解析失败:score=0,并在 reason 显示报错
+ """
+
+ def __init__(
+ self,
+ model: OpenAIChatModel,
+ name: str = "presentation_quality",
+ **kwargs: Any,
+ ):
+ super().__init__(name=name, **kwargs)
+ self.model = model
+
+ @staticmethod
+ def create_default_model(
+ model_name: str,
+ api_key: str | None = None,
+ base_url: str | None = None,
+ deterministic: bool = True,
+ enable_thinking: bool = False,
+ seed: int = 0,
+ ) -> OpenAIChatModel:
+ """
+ 你也可以不调用这个工厂,自己在外面 new OpenAIChatModel。
+ QuickStart 文档确认 OpenAIChatModel 会从 OPENAI_API_KEY/OPENAI_BASE_URL 读取。
+ """
+ api_key = api_key or os.getenv("OPENAI_API_KEY")
+ base_url = base_url or os.getenv("OPENAI_BASE_URL")
+
+ extra_body: Dict[str, Any] = {}
+ if deterministic:
+ # OpenAI兼容接口常见字段;DashScope/Qwen 常用 enable_thinking
+ extra_body.update(
+ {
+ "temperature": 0,
+ "top_p": 1,
+ "seed": seed,
+ "presence_penalty": 0,
+ "frequency_penalty": 0,
+ }
+ )
+ if enable_thinking is False:
+ extra_body["enable_thinking"] = False
+
+ kwargs: Dict[str, Any] = {"model": model_name}
+ if api_key:
+ kwargs["api_key"] = api_key
+ if base_url:
+ kwargs["base_url"] = base_url
+ if extra_body:
+ kwargs["extra_body"] = extra_body
+
+ return OpenAIChatModel(**kwargs)
+
+ async def aevaluate(
+ self,
+ report_content: str,
+ user_query: str | None = None,
+ **_: Any,
+ ) -> GraderScore:
+ """
+ 入口:直接喂 report_content(研究报告文本)
+ - user_query 可选:用于填充 prompt;不提供则用 "(unknown)"
+ """
+
+
+ report = (report_content or "").strip()
+
+ # 清理 markdown 代码块标记
+ report = self._strip_markdown_fences(report)
+
+ if not report:
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason="BadInput: empty report_content",
+ )
+
+ uq = (user_query or "").strip() or "(unknown)"
+
+ user_content = USER_PROMPT_TEMPLATE.format(
+ user_query=uq,
+ report_content=report,
+ )
+ messages = [
+ {"role": "system", "content": QUALITY_SYSTEM_PROMPT},
+ {"role": "user", "content": user_content},
+ ]
+
+ # 核心:OpenJudge 的 OpenAIChatModel 支持 await model.achat([...]),并返回 .content
+ try:
+ resp = await self.model.achat(messages)
+ raw_text = getattr(resp, "content", None)
+ if raw_text is None:
+ raw_text = str(resp)
+ except Exception as e:
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason=f"ModelCallError: {type(e).__name__}: {e}",
+ )
+
+ obj, jerr = strict_load_json(str(raw_text))
+ if obj is None:
+ snippet = str(raw_text)[:200].replace("\n", " ")
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason=f"ParseError: {jerr}; raw[:200]={snippet}",
+ )
+
+ obj, serr = validate_shape(obj)
+ if obj is None:
+ snippet = str(raw_text)[:200].replace("\n", " ")
+ return GraderScore(
+ name=self.name,
+ score=0.0,
+ reason=f"SchemaError: {serr}; raw[:200]={snippet}",
+ )
+
+ score, reason = self._score_and_reason(obj)
+
+ return GraderScore(name=self.name, score=score, reason=reason)
+
+ def _score_and_reason(self, obj: Dict[str, Any]) -> Tuple[float, str]:
+ scan = obj["scan"]
+ structuring = obj["structuring"]
+ editorial = obj["editorial"]
+ top_fixes = obj.get("top_fixes", [])
+
+ # 8项按1/3/5分制计分(强确定性:完全由Python算)
+ score_map: Dict[str, int] = {}
+ note_map: Dict[str, str] = {}
+
+ def take(section: Dict[str, Any], key: str):
+ item = section.get(key)
+ score_map[key] = get_score(item)
+ note_map[key] = get_note(item)
+
+ for k in A_KEYS:
+ take(scan, k)
+ for k in B_KEYS:
+ take(structuring, k)
+ for k in C_KEYS:
+ take(editorial, k)
+
+ # 总分 = 各项得分之和 / 最高可能分 (8*5=40),归一化到[0,1]
+ total_score = sum(score_map.get(k, 1) for k in ALL_KEYS)
+ max_score = len(ALL_KEYS) * 5 # 8 * 5 = 40
+ score = total_score / float(max_score)
+
+ # reason:按分数排序,列出低分项
+ low_items = [(k, score_map.get(k, 1)) for k in ALL_KEYS if score_map.get(k, 1) < 5]
+ low_items.sort(key=lambda x: x[1]) # 从低到高
+ low_str = ", ".join(f"{k}={s}({note_map.get(k,'')})" for k, s in low_items[:4])
+ fixes_str = " | ".join(str(x) for x in (top_fixes or [])[:3])
+
+ parts: List[str] = []
+ parts.append(f"Score {total_score}/{max_score}")
+ if low_items:
+ parts.append(f"Low: {low_str}")
+ if fixes_str:
+ parts.append(f"TopFixes: {fixes_str}")
+
+ reason = " ; ".join(parts)
+ return round(score, 6), reason[:800]
+
+ @staticmethod
+ def _strip_markdown_fences(text: str) -> str:
+ """
+ 清理 markdown 代码块标记
+ - 移除开头的 ```markdown / ```md / ``` 等
+ - 移除结尾的 ```
+ """
+ text = text.strip()
+ # 移除开头的 ```xxx
+ text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE)
+ # 移除结尾的 ```
+ text = re.sub(r'\n?```\s*$', '', text)
+ return text.strip()
diff --git a/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py
new file mode 100644
index 0000000..2852ff8
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Any, Dict, Tuple
+
+
+_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)
+
+
+def extract_first_json_object(text: str) -> str | None:
+ """
+ Best-effort: extract the first {...} block.
+ If none found, return None.
+ """
+ if not text:
+ return None
+ m = _JSON_RE.search(text.strip())
+ if not m:
+ return None
+ return m.group(0)
+
+
+def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]:
+    """
+    Parse the first JSON object found in ``text``.
+
+    Return (obj, error). Any parse failure => (None, error_msg);
+    exactly one of the two values is non-None.
+    """
+    js = extract_first_json_object(text)
+    if js is None:
+        return None, "No JSON object found in model output"
+    try:
+        obj = json.loads(js)
+        if not isinstance(obj, dict):
+            return None, f"Top-level JSON is not an object: {type(obj).__name__}"
+        return obj, None
+    except Exception as e:
+        # Surface the exception type so callers can log why parsing failed.
+        return None, f"{type(e).__name__}: {e}"
+
+
+def get_bool_pass(item: Any) -> bool:
+    """Coerce a judge item (dict with a "pass" key, or a bare value) to bool.
+
+    Accepts bools, numbers (non-zero => True) and common truthy strings;
+    anything unrecognized defaults to False.
+    """
+    if isinstance(item, dict):
+        v = item.get("pass")
+    else:
+        v = item
+    if isinstance(v, bool):
+        return v
+    if isinstance(v, (int, float)):
+        return bool(v)
+    if isinstance(v, str):
+        return v.strip().lower() in {"true", "1", "yes", "y"}
+    return False
+
+
+def get_score(item: Any) -> int:
+    """
+    Extract numeric score (1, 3, 5) from item.
+
+    Accepts either {"score": n} or a bare number; values off the anchor
+    scale are clamped (<=1 -> 1, >=5 -> 5, otherwise 3).
+    Returns 1 as default if invalid.
+    """
+    if isinstance(item, dict):
+        v = item.get("score")
+    else:
+        v = item
+    # NOTE: bool is a subclass of int in Python, so True maps to int(True)=1.
+    if isinstance(v, (int, float)):
+        v = int(v)
+        if v in (1, 3, 5):
+            return v
+        # clamp to valid range
+        if v <= 1:
+            return 1
+        if v >= 5:
+            return 5
+        return 3
+    return 1
+
+
+def get_note(item: Any) -> str:
+    """Extract the free-text note from item; empty string if absent."""
+    if isinstance(item, dict):
+        note = item.get("note", "")
+    else:
+        note = ""
+    note = "" if note is None else str(note)
+    note = note.strip()
+    # Allow some slack but cap length so the final reason string stays short.
+    return note[:120]
+
+
+def validate_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]:
+ """
+ Ensure required sections exist and are dicts; ensure top_fixes is list or str.
+ If missing required field => error.
+ """
+ for sec in ("scan", "structuring", "editorial"):
+ if sec not in obj:
+ return None, f"Missing field: {sec}"
+ if not isinstance(obj[sec], dict):
+ return None, f"Field '{sec}' is not an object"
+ if "top_fixes" not in obj:
+ return None, "Missing field: top_fixes"
+ # normalize top_fixes
+ tf = obj.get("top_fixes")
+ if isinstance(tf, list):
+ obj["top_fixes"] = [str(x) for x in tf][:3]
+ elif tf is None:
+ obj["top_fixes"] = []
+ else:
+ obj["top_fixes"] = [str(tf)][:3]
+ return obj, None
diff --git a/tutorial/example_deep_finance/judge/presentation_quality/prompt.py b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py
new file mode 100644
index 0000000..5e945bf
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py
@@ -0,0 +1,108 @@
+# 8项呈现质量检查:A(3)+B(3)+C(2)=8
+QUALITY_SYSTEM_PROMPT = """
+你是一位“深度研究报告呈现评审官”。你的任务是评估报告的 **用户体验与信息架构 (Presentation & UX)**,为强化学习提供奖励信号。
+
+**严禁评估**:事实真伪、引用准确性(由 Grounding 模型负责)、内容广度与深度。
+**核心关注**:**认知负荷管理**、**信息的可扫读性**、**逻辑的可视化**、**Markdown 渲染质量**。
+
+========================
+评分标准 (1/3/5 分制)
+========================
+对以下 8 个维度进行打分。
+- **1分 (Fail)**:严重阻碍阅读,格式混乱或缺失。
+- **3分 (Pass)**:勉强及格,有基本结构,但平庸、啰嗦或不够直观。
+- **5分 (Excellent)**:出版级质量,结构极佳,一眼能抓取核心,降低了读者的认知成本。
+
+请针对每个子项给出分数(1, 3, 5)及 Note(≤25字,指出具体位置或症状)。
+
+### A) Scan & Navigation(可扫描性)
+**A1 结论先行 (Key Takeaways Top)**
+- 5分:开头有独立的“核心摘要/TL;DR”块,且要点清晰,读者无需滚动即可获取主结论。
+- 3分:有摘要,但写成了流水账段落,或混杂在正文中不够醒目。
+- 1分:无摘要,开篇即陷入细节或背景介绍。
+
+**A2 结构导航 (Navigable Structure)**
+- 5分:层级分明 (H1/H2/H3),长文有清晰的“路标”(小标题),支持快速跳读定位。
+- 3分:有分节,但段落过长(Wall of text),缺乏内部视觉引导。
+- 1分:结构混乱,标题层级错误或缺失,难以导航。
+
+**A3 视觉重点 (Visual Hierarchy)**
+- 5分:利用 **加粗**、*斜体* 或 `代码块` 精准强调核心洞察,信噪比高。
+- 3分:有强调,但过度使用(满篇加粗)或重点不突出(强调了无关词)。
+- 1分:全文平铺直叙,无任何视觉重点。
+
+### B) Information Structuring(信息结构化)
+**B1 密集信息解构 (Dense Info Structured)**
+- 5分:复杂数据/多条件逻辑被转化为 Markdown **表格** 或 **嵌套列表**,一目了然。
+- 3分:使用了列表,但内容仍是长难句堆砌,未真正拆解信息。
+- 1分:关键数字或复杂参数淹没在长段落文本中。
+
+**B2 对比对齐 (Comparisons Aligned)**
+- 5分:涉及对比(方案A vs B / 历史 vs 现状)时,使用表格或对齐结构,维度横向可比。
+- 3分:有对比意图,但分散在不同段落,读者需来回对照。
+- 1分:对比维度混乱或缺失,无法直观比较。
+
+**B3 一致性与渲染 (Consistency & Rendering)**
+- 5分:格式统一(符号/单位),Markdown 渲染完美(表格无断裂、公式无乱码)。
+- 3分:存在少量格式不统一,或轻微的渲染瑕疵但不影响理解。
+- 1分:表格错位、公式未闭合、列表层级混乱,严重影响阅读。
+
+### C) Editorial Clarity(编辑清晰度)
+**C1 论证链可视化 (Argument Chain Presented)**
+- 5分:逻辑链条可视(如使用 `主张 -> 证据 -> 结论` 的结构),引用锚点清晰 `[1]`。
+- 3分:逻辑存在,但淹没在文字中,缺乏连接词或视觉引导。
+- 1分:材料堆砌,缺乏清晰的推导线索。
+
+**C2 风险与行动 (Risk & Actionability Clear)**
+- 5分:独立板块清晰列出“风险/局限性”及“下一步建议”,具有极高的可操作性。
+- 3分:提到了风险或建议,但含糊其辞,或混杂在结论中。
+- 1分:完全未提及风险边界或下一步行动。
+
+**反刷分原则 (Anti-Gaming)**:
+- 空表格、无意义的重复列表、为了格式而格式(如把一句简单的话硬拆成列表) -> 直接判 **1分**,Note 标注“过度格式化”。
+
+========================
+输出要求 (Strict JSON)
+========================
+必须输出可解析 JSON。
+**注意**:为了提供梯度信号,字段由 `pass` 改为 `score`,值必须为 1, 3, or 5。
+
+JSON 模板:
+{
+ "scan": {
+ "A1_key_takeaways_top": {"score": 0, "note": "≤25字定位理由"},
+ "A2_navigable_structure": {"score": 0, "note": "≤25字定位理由"},
+ "A3_visual_hierarchy": {"score": 0, "note": "≤25字定位理由"}
+ },
+ "structuring": {
+ "B1_dense_info_structured": {"score": 0, "note": "≤25字定位理由"},
+ "B2_comparisons_aligned": {"score": 0, "note": "≤25字定位理由"},
+ "B3_consistency": {"score": 0, "note": "≤25字定位理由"}
+ },
+ "editorial": {
+ "C1_argument_chain_presented": {"score": 0, "note": "≤25字定位理由"},
+ "C2_risk_and_actionability_clear": {"score": 0, "note": "≤25字定位理由"}
+ },
+ "top_fixes": ["最多3条,仅谈呈现层面改进,针对最低分项"]
+}
+"""
+
+USER_PROMPT_TEMPLATE = """
+请审计以下研究报告的【呈现质量】(只谈呈现/排版/结构,不谈事实对错/引用支持/覆盖/深度)。
+
+### User Query
+{user_query}
+
+### AI Report
+{report_content}
+
+-----
+请严格按 System Prompt 的锚点输出 JSON;不要输出 Markdown;不要添加额外字段。
+""".strip()
+
+# Keys of the 8 check items (used by Python-side aggregation, strongly deterministic)
+A_KEYS = ["A1_key_takeaways_top", "A2_navigable_structure", "A3_visual_hierarchy"]
+B_KEYS = ["B1_dense_info_structured", "B2_comparisons_aligned", "B3_consistency"]
+C_KEYS = ["C1_argument_chain_presented", "C2_risk_and_actionability_clear"]
+
+ALL_KEYS = A_KEYS + B_KEYS + C_KEYS
diff --git a/tutorial/example_deep_finance/judge/traj_adapter.py b/tutorial/example_deep_finance/judge/traj_adapter.py
new file mode 100644
index 0000000..66df53f
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/traj_adapter.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Tuple
+
+
+def extract_text_content(content: Any) -> str:
+    """Extract plain text from common message schemas.
+
+    Handles a bare string, a list mixing plain strings and
+    {"type": "text", "text": ...} parts (content-block style),
+    and falls back to str() for anything else. None yields "".
+    """
+    if content is None:
+        return ""
+    if isinstance(content, str):
+        return content
+    if isinstance(content, list):
+        texts: List[str] = []
+        for item in content:
+            if isinstance(item, dict) and item.get("type") == "text":
+                texts.append(str(item.get("text", "")))
+            elif isinstance(item, str):
+                texts.append(item)
+        return "\n".join(texts)
+    return str(content)
+
+
+def normalize_traj(traj: Any) -> List[Dict[str, Any]]:
+    """
+    Normalize a trajectory payload to a flat message list.
+
+    Accept common traj shapes:
+    - list[{"role":..., "content":...}, ...]
+    - {"trajectory": [...]}
+    - {"messages": [...]}
+    Anything else yields an empty list.
+    """
+    if isinstance(traj, list):
+        return traj
+    if isinstance(traj, dict):
+        if isinstance(traj.get("trajectory"), list):
+            return traj["trajectory"]
+        if isinstance(traj.get("messages"), list):
+            return traj["messages"]
+    return []
+
+
+def infer_user_query(trajectory: List[Dict[str, Any]]) -> str:
+    """Return the first non-empty user message text, or "" if none found."""
+    for step in trajectory:
+        if step.get("role") == "user":
+            txt = extract_text_content(step.get("content"))
+            if txt.strip():
+                return txt.strip()
+    return ""
+
+
+def find_final_report(trajectory: List[Dict[str, Any]]) -> str:
+    """
+    Heuristic: last assistant long text or markdown-like content.
+
+    Scans from the end; a message longer than 120 chars or containing '#'
+    (a markdown heading marker) is treated as the report. Returns "" if none.
+    """
+    for step in reversed(trajectory):
+        if step.get("role") == "assistant":
+            txt = extract_text_content(step.get("content", ""))
+            if len(txt) > 120 or "#" in txt:
+                return txt
+    return ""
+
+
+
diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml
index 48824e1..33103fe 100644
--- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml
+++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml
@@ -1,6 +1,6 @@
# ------------------ 主要配置 ------------------
ajet:
- project_name: ajet_deep_finance
+ project_name: "{{PREFIX}}"
experiment_name: "{{SUFFIX}}"
# Judge 配置(嵌套结构,对应 self.config.ajet.judge.*)
judge:
@@ -10,9 +10,8 @@ ajet:
train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径
val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径
# OpenJudge 权重配置
- report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估
- trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估
- citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性)
+ presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估
+ grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估
rm_weight: {{RM_WEIGHT}} # RM Gallery 权重
task_judge:
# 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service)