diff --git a/.gitignore b/.gitignore index 7a35919..1698ba2 100644 --- a/.gitignore +++ b/.gitignore @@ -158,6 +158,9 @@ tutorial/example_deep_finance/yaml/* tutorial/example_deep_finance/config/* tutorial/example_deep_finance/scripts/* flash_attn-2.8.*.whl +tutorial/example_deep_finance/prepare_data/* +tutorial/example_deep_finance/judge/analytical_sufficiency/* + .dockerignore benchmark_datasets modelscope_cache diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index dc192aa..51f6d60 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -82,6 +82,18 @@ def extract_text_content_from_content_dict(self, msg): # }, # ], # } + # or tool_result format?? not observed yet: + # msg = { + # "role": "tool", + # "content": [ + # { + # "type": "tool_result", + # "id": "call_xxx", + # "output": "tool output content", + # "name": "tool_name" + # }, + # ], + # } str_content = "" diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index 88f9ab1..c261056 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -9,6 +9,7 @@ from ajet.schema.trajectory import Reward from ajet.task_runner.base_runner import BaseAgentRunner from ajet.utils.dynamic_import import dynamic_import +from ajet.utils.metric_helper.reward_metric_helper import populate_reward_metadata_from_stats class GeneralRunner(BaseAgentRunner): @@ -73,6 +74,10 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: madness=0, description="", ) + + # Populate reward metadata with deep_finance reward stats if available + if "reward_stats" in workflow_output.metadata: + populate_reward_metadata_from_stats(reward, workflow_output.metadata["reward_stats"]) context_tracker.process_reward(reward) # generate token before merging context_tracker.group_merge() diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py index 76d034b..ea951d5 100644 --- a/ajet/utils/metric_helper/reward_metric_helper.py +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -11,9 +11,12 @@ - judge_time/ Judge time consumption statistics """ -from typing import List, Dict, Any +from typing import List, Dict, Any, TYPE_CHECKING import numpy as np +if TYPE_CHECKING: + from ajet.schema.trajectory import Reward + def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]: """ @@ -72,22 +75,15 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties) metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0 - # ========== Detect OpenJudge Usage ========== + # ========== OpenJudge Metrics (PresentationQualityGrader, GroundingGrader) ========== openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('openjudge_enabled', False)) if openjudge_enabled_count > 0: - # ========== OpenJudge Metrics ========== - - # Dynamically extract OpenJudge grader fields - # Currently supported graders: report_resolution, trajectory_faithfulness, - # rubrics_performance, trajectory_comprehensive, information_gain, action_loop + # OpenJudge graders: presentation_quality, grounding openjudge_graders = [ - "report_resolution", - "trajectory_faithfulness", - "rubrics_performance", - "trajectory_comprehensive", - "information_gain", - "action_loop", + "presentation_quality", + 
"grounding", + "planning" ] for grader_name in openjudge_graders: @@ -151,3 +147,18 @@ def compute_reward_metrics_from_trajectories(trajectories: List[Any], prefix: st reward_stats_list = extract_reward_stats_from_trajectories(trajectories) return compute_reward_metrics(reward_stats_list, prefix=prefix) + +def populate_reward_metadata_from_stats(reward: "Reward", reward_stats: Dict[str, Any]) -> None: + """ + Populate Reward.metadata with all reward statistics. + + Args: + reward: The Reward object to populate + reward_stats: The reward_stats dictionary from judge + """ + if not reward_stats: + return + + # Directly copy all reward_stats into metadata + reward.metadata.update(reward_stats) + diff --git a/tutorial/example_deep_finance/__init__.py b/tutorial/example_deep_finance/__init__.py new file mode 100644 index 0000000..36e084c --- /dev/null +++ b/tutorial/example_deep_finance/__init__.py @@ -0,0 +1 @@ +# tutorial/example_deep_finance package diff --git a/tutorial/example_deep_finance/deep_finance.sh b/tutorial/example_deep_finance/deep_finance.sh index 6e3c13b..bee02ac 100644 --- a/tutorial/example_deep_finance/deep_finance.sh +++ b/tutorial/example_deep_finance/deep_finance.sh @@ -1,10 +1,10 @@ #!/bin/bash -set -e +set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="deep_finance" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 +SUFFIX="newjudge" # 实验后缀,影响所有日志和实验名称 +PREFIX="ajet_newjudge" # 实验前缀,影响日志和实验所在文件夹 # OpenJudge 模型配置 OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 @@ -12,10 +12,9 @@ RM_LLM='qwen-max' # RM Gallery 评分模型 JUDGE_CONCURRENCY=10 # 奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 +RM_WEIGHT=0.5 +PRESENTATION_QUALITY_WEIGHT=0.25 +GROUNDING_WEIGHT=0.25 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -23,7 +22,8 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 -# 主目录 +# 主目录(需要更改) +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" NNODES=${WORLD_SIZE} @@ -46,7 +46,7 @@ fi # 2. 
动态生成配置文件 (从yaml template生成yaml) #=============================================================================== # 修改:配置文件生成路径,现在动态生成到 yaml 目录下 -CONFIG_TEMPLATE="tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml" +CONFIG_TEMPLATE="tutorial/example_deep_finance/deep_finance.yaml" CONFIG_FILE="${AJET_ROOT}/tutorial/example_deep_finance/yaml/${SUFFIX}.yaml" mkdir -p $(dirname ${CONFIG_FILE}) @@ -55,12 +55,11 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ -e "s|{{NNODES}}|${NNODES}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ + -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ @@ -72,7 +71,7 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== # 3. 环境配置 @@ -106,7 +105,7 @@ export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS # 其他服务配置 HF_ENDPOINT="https://hf-mirror.com" ES_HOSTS="http://11.160.132.46:8200" -export HF_ENDPOINT ES_HOSTS +export HF_ENDPOINT ES_HOSTS # log 文件位置 CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") @@ -114,7 +113,7 @@ LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - +env_log_prefix="${SUFFIX}__${CURRENT_TIME}" # 多机训练参数配置 GPUS_PER_NODE=8 EXPECTED_WORKERS=$WORLD_SIZE @@ -156,6 +155,8 @@ export NCCL_ASYNC_ERROR_HANDLING=1 export PYTHONPATH="${AJET_ROOT}:${PYTHONPATH}" export RAY_CLUSTER_MODE="multi_node" +export DEEPFINANCE_PATH="${ENV_SERVICE_ROOT}" # AgentJet 内部可能使用此路径 +export DEEPFINANCE_SCRIPT="source /mnt/data/taoshuchang.tsc/anaconda3/etc/profile.d/conda.sh && conda activate finworld_1209 && cd ${ENV_SERVICE_ROOT} && DEEPFINANCE_TOOL_RESULT_MAX_CHARS=${DEEPFINANCE_TOOL_RESULT_MAX_CHARS} DEEPFINANCE_MCP_CONFIG=${DEEPFINANCE_MCP_CONFIG} CACHE_TYPE=${CACHE_TYPE} MONGO_URI=${MONGO_URI} MONGO_DB_NAME=${MONGO_DB_NAME} MONGO_COLLECTION_NAME=${MONGO_COLLECTION_NAME} python -m env_service.env_service --env finworld --portal 0.0.0.0 --port 8080" #=============================================================================== @@ -202,11 +203,12 @@ if [[ $HOSTNAME == *"-master-"* ]]; then # 启动训练任务(最核心) python ajet/launcher.py \ + --with-deepfinance \ --conf ${CONFIG_FILE} \ --backbone="verl" \ - --prefix=${SUFFIX} \ + --prefix=${env_log_prefix} \ 2>&1 | tee ${TRAIN_LOG} - + #=============================================================================== # 6.2 Worker 节点启动流程 @@ -218,4 +220,4 @@ else ray stop || true ray start --address $MASTER_ADDR:6379 --num-gpus 8 while true; do sleep 60; done -fi +fi \ No 
newline at end of file diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index 15dd566..33103fe 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -1,19 +1,18 @@ # ------------------ 主要配置 ------------------ ajet: - project_name: ajet_deep_finance - experiment_name: "ajet_deep_finance" + project_name: "{{PREFIX}}" + experiment_name: "{{SUFFIX}}" # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) judge: - openjudge_llm: qwen-flash # OpenJudge 模型 - rm_llm: qwen-max # RM Gallery 模型 - concurrency: 10 # Judge 并发数 + openjudge_llm: {{OPENJUDGE_LLM}} # OpenJudge 模型 + rm_llm: {{RM_LLM}} # RM Gallery 模型 + concurrency: {{JUDGE_CONCURRENCY}} # Judge 并发数 train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 # OpenJudge 权重配置 - report_resolution_weight: 0.2 # 报告质量评估 - trajectory_faithfulness_weight: 0.2 # 事实准确性评估 - citation_audit_weight: 0.2 # 引用审计评估 (覆盖率 + 真实性) - rm_weight: 0.4 # RM Gallery 权重 + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 + rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service) judge_protocol: tutorial.example_deep_finance.deep_finance_judge->DeepFinanceJudgeByOpenJudge @@ -21,7 +20,7 @@ ajet: # ✨✨✨✨ 设置待训练的模型 path: {{MODEL_PATH}} trainer_common: - nnodes: 8 + nnodes: {{NNODES}} n_gpus_per_node: 8 val_before_train: True val_pass_n: 8 @@ -32,10 +31,10 @@ ajet: rollout: # ✨✨✨✨ 编写并选择Agent user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol - force_disable_toolcalls: True + force_disable_toolcalls: False enable_oversample: False tensor_model_parallel_size: 8 - num_repeat: 4 + num_repeat: {{NUM_REPEAT}} max_env_worker: 64 # 增加环境并行数 max_num_seqs: 64 # 增加VLLM并发序列数 max_response_length_in_one_turn: 8000 @@ -43,14 +42,14 @@ ajet: agent_madness_reward: 0.0 compute_madness_checklist: None multi_turn: - max_steps: 6 + max_steps: {{NUM_STEPS}} interchange_server: interchange_method: 'tcp' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) debug: debug_max_parallel: 1 # 增加并行任务数,充分利用GPU debug_first_n_tasks: 100 # 增加处理的任务数 data: - train_batch_size: 32 + train_batch_size: {{TRAIN_BATCH_SIZE}} max_prompt_length: 8000 max_response_length: 41000 @@ -58,18 +57,16 @@ ajet: type: deep_finance # 数据从 JSON 加载并组装 init_messages,工具调用走 env_service deep_finance: training: - file_path: {{TRAIN_PATH}} + file_path: {{TRAIN_DATA_PATH}} validation: - file_path: {{VAL_PATH}} + file_path: {{VAL_DATA_PATH}} # env_service 仍需配置(用于工具调用) env_service: env_type: "finworld" env_url: {{ENV_SERVICE_URL}} env_action_preference: code - - trainer: - default_local_dir: {{CKPT_SAVE_PATH}} + default_local_dir: "{{CKPT_SAVE_PATH}}/{{PREFIX}}/{{SUFFIX}}" # resume_mode: disable # 禁用自动恢复,从头开始训练 actor_rollout_ref: rollout: diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py index 31e4be0..03f1013 100644 --- a/tutorial/example_deep_finance/deep_finance_judge.py +++ b/tutorial/example_deep_finance/deep_finance_judge.py @@ -1,5 +1,5 @@ """DeepFinance Task Judge - OpenJudge 版本 -集成: RM Gallery, OpenJudge Graders (含 CitationAudit) +集成: RM Gallery, PresentationQualityGrader """ import os @@ -13,32 +13,11 @@ from ajet.task_judge.base_judge import BaseJudge from ajet.workflow import WorkflowOutput, WorkflowTask -from 
openjudge.graders.agent.action.action_loop import ActionLoopDetectionGrader -from openjudge.graders.agent.observation.observation_information_gain import ( - ObservationInformationGainGrader, -) -from openjudge.graders.agent.trajectory.trajectory_comprehensive import ( - TrajectoryComprehensiveGrader, -) from openjudge.models.openai_chat_model import OpenAIChatModel -from openjudge.models.schema.prompt_template import LanguageEnum from openjudge.runner.grading_runner import GraderConfig, GradingRunner -from openjudge.scenarios.deep_research.graders.financial_report_resolution import ( - FinancialReportResolutionGrader, -) -from openjudge.scenarios.deep_research.graders.financial_trajectory_faithfulness import ( - FinancialTrajectoryFaithfulGrader, -) -from openjudge.scenarios.deep_research.graders.rubrics_based_trajectory_performance import ( - RubricsBasedTrajectoryPerformance, -) -from openjudge.scenarios.deep_research.graders.financial_report_citation_audit import ( - FinancialReportCitationAuditGrader, -) +from tutorial.example_deep_finance.judge import PresentationQualityGrader, GroundingGrader -# RewardStats 不再使用,OpenJudge 版本直接使用字典存储 -# Reference Answer 路径现在从 config 中读取,见 _init_reference_answers 方法 # OpenJudge imports # ============================================================================= @@ -88,7 +67,7 @@ def load_reference_answers_from_file(file_path: str) -> Tuple[Dict[str, str], Di class DeepFinanceJudgeByOpenJudge(BaseJudge): """ 使用 OpenJudge 框架的 DeepFinance Judge - 集成: RM Gallery, OpenJudge Graders (含 CitationAudit) + 集成: RM Gallery, PresentationQualityGrader 分析: - compute_reward 每次处理 **一条采样**(单个 workflow_output) @@ -116,26 +95,15 @@ def _setup_weights(self): 配置 OpenJudge 各 grader 的权重并归一化 graders 对应关系: - - financial_report_resolution: 报告质量和问题解决能力 - - financial_trajectory_faithfulness: 事实准确性(忠实度) - - citation_audit: 引用审计(覆盖率 + 真实性) - - rubrics_based_trajectory_performance: 基于 rubrics 的评估 - - trajectory_comprehensive: 轨迹综合评估 - - observation_information_gain: 信息增益(去重) - - action_loop_detection: 动作循环检测(惩罚项) + - presentation_quality: 报告呈现质量评估 """ cfg = getattr(self.config, "ajet", None) - # 定义各 grader 的权重(可从 config 中读取)- 与 deep_finance_judge.py 对齐 + # 定义各 grader 的权重(可从 config 中读取) self.w = { "rm": getattr(cfg, "rm_weight", 1.0) if cfg else 1.0, # RM Gallery 权重 - "citation_audit": getattr(cfg, "citation_audit_weight", 0.0) if cfg else 0.0, # CitationAudit 权重 - "report_resolution": getattr(cfg, "report_resolution_weight", 0.0) if cfg else 0.0, - "trajectory_faithfulness": getattr(cfg, "trajectory_faithfulness_weight", 0.0) if cfg else 0.0, - # "rubrics_performance": getattr(cfg, "rubrics_performance_weight", 0.2) if cfg else 0.2, - # "trajectory_comprehensive": getattr(cfg, "trajectory_comprehensive_weight", 0.2) if cfg else 0.2, - # "information_gain": getattr(cfg, "information_gain_weight", 0.1) if cfg else 0.1, - # "action_loop": getattr(cfg, "action_loop_weight", 0.1) if cfg else 0.1 + "presentation_quality": getattr(cfg, "presentation_quality_weight", 0.25) if cfg else 0.25, + "grounding": getattr(cfg, "grounding_weight", 0.25) if cfg else 0.25, } # 归一化(注意:action_loop 是惩罚项,不参与归一化;rm 需要参与归一化) @@ -244,15 +212,14 @@ def _create_runner_in_loop(self) -> GradingRunner: 注意:GradingRunner 内部的 Semaphore 会绑定到创建时的事件循环, 因此不能使用单例模式,必须在每次调用的事件循环中创建新实例。 """ - language = LanguageEnum.ZH - grader_configs = self._create_grader_configs(self.model, language) + grader_configs = self._create_grader_configs(self.model) return GradingRunner( grader_configs=grader_configs, 
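The normalization step referenced in the `_setup_weights` comment above is not shown in this hunk; a hedged sketch of what "rm participates in normalization, penalty terms do not" plausibly looks like (weights taken from the shell defaults above; the exact code may differ):

```python
# Penalty-style graders (e.g. action_loop) would be excluded from this dict.
w = {"rm": 0.5, "presentation_quality": 0.25, "grounding": 0.25}
total = sum(w.values())
w = {k: v / total for k, v in w.items()} if total > 0 else w
assert abs(sum(w.values()) - 1.0) < 1e-9  # fused reward is a convex combination
```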
max_concurrency=self.max_concurrency, show_progress=False ) - def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum) -> Dict[str, GraderConfig]: + def _create_grader_configs(self, model: OpenAIChatModel) -> Dict[str, GraderConfig]: """ 创建所有 grader 的配置 @@ -260,54 +227,35 @@ def _create_grader_configs(self, model: OpenAIChatModel, language: LanguageEnum) - key: grader 名称 - value: GraderConfig(grader=..., mapper=...) """ + + def extract_user_query(data: Dict) -> str: + """从 messages 中提取第一条 user 消息的 content""" + for msg in data.get("messages", []): + if msg.get("role") == "user": + return msg.get("content", "") + return "" + + def extract_report_content(data: Dict) -> str: + """从 messages 中提取最后一条 assistant 消息的 content""" + for msg in reversed(data.get("messages", [])): + if msg.get("role") == "assistant": + return msg.get("content", "") + return "" + return { - # 1. 报告质量评估 - 需要 messages 和 chat_date - "report_resolution": GraderConfig( - grader=FinancialReportResolutionGrader(model=model, language=language), + # 报告呈现质量评估 - 需要 user_query 和 report_content + "presentation_quality": GraderConfig( + grader=PresentationQualityGrader(model=model), mapper=lambda data: { - "messages": data["messages"], - "chat_date": data.get("chat_date") + "user_query": extract_user_query(data), + "report_content": extract_report_content(data), }, ), - - # 2. 事实准确性评估 - 需要 messages - "trajectory_faithfulness": GraderConfig( - grader=FinancialTrajectoryFaithfulGrader(model=model, language=language), - mapper=lambda data: {"messages": data["messages"]}, - ), - - # 3. 引用审计评估 - 需要 messages - "citation_audit": GraderConfig( - grader=FinancialReportCitationAuditGrader(model=model, language=language), - mapper=lambda data: {"messages": data["messages"]}, + # 引用规范性评估 - 需要完整的 traj + "grounding": GraderConfig( + grader=GroundingGrader(model=model), + mapper=lambda data: {"traj": data}, ), - - # 4. Rubrics 评估 - 需要 messages 和 rubrics - # "rubrics_performance": GraderConfig( - # grader=RubricsBasedTrajectoryPerformance(model=model, language=language), - # mapper=lambda data: { - # "messages": data["messages"], - # "rubrics": data.get("rubrics", []) - # }, - # ), - - # 5. 轨迹综合评估 - 需要 messages - # "trajectory_comprehensive": GraderConfig( - # grader=TrajectoryComprehensiveGrader(model=model, language=language), - # mapper=lambda data: {"messages": data["messages"]}, - # ), - - # 6. 信息增益评估 - 需要 messages(非 LLM grader) - # "information_gain": GraderConfig( - # grader=ObservationInformationGainGrader(similarity_threshold=0.5), - # mapper=lambda data: {"messages": data["messages"]}, - # ), - - # 7. 动作循环检测 - 需要 messages(非 LLM grader) - # "action_loop": GraderConfig( - # grader=ActionLoopDetectionGrader(similarity_threshold=1.0), - # mapper=lambda data: {"messages": data["messages"]}, - # ), } def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> Tuple[float, bool]: @@ -361,14 +309,28 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO chat_date=chat_date ) + if openjudge_sample.get('messages'): + last_msg = openjudge_sample['messages'][-1] + # 3. 调用 OpenJudge Runner.arun(异步) grading_start_time = time.time() grader_results = self._run_openjudge_evaluation([openjudge_sample]) grading_time = time.time() - grading_start_time + # 4. 
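To make the two mappers above concrete — a small sketch of an `openjudge_sample` and what each grader receives (the sample contents are invented):

```python
sample = {
    "messages": [
        {"role": "user", "content": "分析贵州茅台 2024Q3 业绩"},
        {"role": "assistant", "content": "# 报告\n...\n## References\n[1] 财报摘要 - https://example.com"},
    ],
    "chat_date": "2025-01-01",
}

def extract_user_query(data):  # first user message, as in the mapper above
    for msg in data.get("messages", []):
        if msg.get("role") == "user":
            return msg.get("content", "")
    return ""

def extract_report_content(data):  # last assistant message
    for msg in reversed(data.get("messages", [])):
        if msg.get("role") == "assistant":
            return msg.get("content", "")
    return ""

pq_kwargs = {"user_query": extract_user_query(sample), "report_content": extract_report_content(sample)}
grounding_kwargs = {"traj": sample}  # GroundingGrader gets the whole trajectory dict
```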
提取各 grader 分数(arun 返回 Dict[str, List[GraderScore]],这里取第一条) grader_scores, quota_exceeded_flags = self._extract_grader_scores(grader_results) + # 4.5 如果有分数为0的grader,保存调试信息到单独文件 + self._save_zero_score_debug( + grader_scores=grader_scores, + grader_results=grader_results, + query=query, + history=history, + report=assistants[-1] if assistants else "", + task_id=task_id + ) + # 5. 加权融合(包含 RM Gallery 和 OpenJudge Graders) fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) @@ -552,8 +514,7 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ 输入: - grader_results: Dict[str, List[GraderScore]] { - "report_resolution": [GraderScore(score=0.88, reason="...", metadata={...})], - "trajectory_faithfulness": [GraderScore(score=1.0, ...)], + "presentation_quality": [GraderScore(score=0.88, reason="...", metadata={...})], ... } @@ -570,6 +531,10 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ if score_list and len(score_list) > 0: # 取第一条采样的分数(因为每次只评估一条) grader_score = score_list[0] + + # DEBUG: 记录详细信息 + reason_str = getattr(grader_score, 'reason', None) + print(f" [DEBUG] {grader_name}: score={getattr(grader_score, 'score', 'N/A')}, reason={str(reason_str)[:300] if reason_str else 'N/A'}") if hasattr(grader_score, "score"): scores[grader_name] = grader_score.score # 检测错误类型:分数为0且有错误信息 @@ -581,6 +546,7 @@ def _extract_grader_scores(self, grader_results: Dict[str, List[Any]]) -> Tuple[ else: # 如果出错,设为 0 scores[grader_name] = 0.0 + print(f" [DEBUG] {grader_name}: no 'score' attr, grader_score={grader_score}") else: scores[grader_name] = 0.0 @@ -657,6 +623,69 @@ def _save_rm_log(self, result, query: str, task_id: str): except Exception: pass + def _save_zero_score_debug( + self, + grader_scores: Dict[str, float], + grader_results: Dict[str, List[Any]], + query: str, + history: List[Dict], + report: str, + task_id: str + ): + """ + 当有 grader 分数为 0 时,保存详细调试信息到单独文件 + + 保存内容包括: + - query: 用户查询 + - traj: 对话历史 + - report: 最终报告(前500字) + - zero_score_reasons: 得 0 分的原因 + """ + try: + # 检查是否有分数为 0 的 grader + zero_score_graders = [name for name, score in grader_scores.items() if score == 0.0] + if not zero_score_graders: + return + + # 提取得 0 分的原因 + zero_score_reasons = {} + for grader_name in zero_score_graders: + if grader_name in grader_results: + score_list = grader_results[grader_name] + if score_list and len(score_list) > 0: + grader_score = score_list[0] + reason = getattr(grader_score, 'reason', None) + zero_score_reasons[grader_name] = str(reason) if reason else "N/A" + else: + zero_score_reasons[grader_name] = "empty score_list" + else: + zero_score_reasons[grader_name] = "grader not in results" + + # 构建调试日志 + debug_log = { + "task_id": task_id, + "timestamp": datetime.now().isoformat(), + "query": query, + "report": report if report else "", + "trajectory": history, + "grader_scores": grader_scores, + "zero_score_graders": zero_score_graders, + "zero_score_reasons": zero_score_reasons + } + + # 保存到单独文件 + save_dir = "/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new/tutorial/example_deep_finance/outputs/reward_zero_debug" + os.makedirs(save_dir, exist_ok=True) + log_file = os.path.join(save_dir, f"zeroscore_{datetime.now().strftime('%Y%m%d')}.jsonl") + with open(log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(debug_log, ensure_ascii=False) + "\n") + + print(f" [ZERO SCORE DEBUG] task_id={task_id}, zero_graders={zero_score_graders}, saved to {log_file}") + + except Exception as e: + print(f"⚠️ Failed 
to save zero score debug: {e}") + pass + def _compute_penalty(self, tool_calls: int) -> float: """ 计算工具调用惩罚(保留原有逻辑) @@ -689,13 +718,7 @@ def _update_metadata_stats( 更新 metadata["reward_stats"] - 直接使用 OpenJudge 原始字段 OpenJudge graders(按实际启用情况): - - report_resolution: 报告质量和问题解决能力 - - trajectory_faithfulness: 事实准确性(忠实度) - - citation_audit: 引用审计(覆盖率 + 真实性) - - rubrics_performance: 基于 rubrics 的评估(可选) - - trajectory_comprehensive: 轨迹综合评估(可选) - - information_gain: 信息增益/去重(可选) - - action_loop: 动作循环检测(惩罚项,可选) + - presentation_quality: 报告呈现质量评估 注意:不再硬套 RewardStats 的字段名,直接使用 openjudge_ 前缀 """ @@ -712,10 +735,6 @@ def _update_metadata_stats( "penalty": penalty, "step_reward": step_reward, "openjudge_enabled": True, - # Quota exceeded (429) 统计 - "quota_exceeded_any": quota_exceeded_any, # 是否有任何 grader 超额 - "quota_exceeded_count": quota_exceeded_count, # 超额的 grader 数量 - "quota_exceeded_graders": quota_exceeded_flags, # 各 grader 的超额标记 # RM Gallery 相关 "rm_enabled": self._rm_enabled, "rm_raw": rm_raw, diff --git a/tutorial/example_deep_finance/deep_finance_single.sh b/tutorial/example_deep_finance/deep_finance_single.sh index 6b27d0f..e794dff 100644 --- a/tutorial/example_deep_finance/deep_finance_single.sh +++ b/tutorial/example_deep_finance/deep_finance_single.sh @@ -3,8 +3,8 @@ set -e #=============================================================================== # 1. 配置区域 - 用户只需修改这里 #=============================================================================== -SUFFIX="ajet_deep_finance" # 实验后缀,影响所有日志和实验名称 -PREFIX="open" # 实验前缀,影响日志和实验所在文件夹 +SUFFIX="newjudge" # 实验后缀,影响所有日志和实验名称 +PREFIX="ajet_newjudge" # 实验前缀,影响日志和实验所在文件夹 # OpenJudge 模型配置 OPENJUDGE_LLM='qwen-flash' # OpenJudge 评分模型 @@ -12,10 +12,9 @@ RM_LLM='qwen-max' # RM Gallery 评分模型 JUDGE_CONCURRENCY=10 # 奖励权重配置 -RM_WEIGHT=0.4 -CITATION_AUDIT_WEIGHT=0.2 -REPORT_RESOLUTION_WEIGHT=0.2 -TRAJECTORY_FAITHFULNESS_WEIGHT=0.2 +RM_WEIGHT=0.5 +PRESENTATION_QUALITY_WEIGHT=0.25 +GROUNDING_WEIGHT=0.25 # 训练参数配置 NUM_REPEAT=4 # group size,每个query rollout NUM_REPEAT次 @@ -23,7 +22,8 @@ TRAIN_BATCH_SIZE=32 # 训练batchsize NUM_STEPS=6 # 每个样本step轮数 DEEPFINANCE_TOOL_RESULT_MAX_CHARS=10000 -# 主目录 +# 主目录(需要更改) +export AJET_ROOT="/mnt/data_cpfs/taoshuchang.tsc/deepresearch/AgentJet_new" NNODES=${WORLD_SIZE} @@ -55,70 +55,23 @@ sed -e "s|{{SUFFIX}}|${SUFFIX}|g" \ -e "s|{{MODEL_PATH}}|${MODEL_PATH}|g" \ -e "s|{{NNODES}}|${NNODES}|g" \ -e "s|{{RM_WEIGHT}}|${RM_WEIGHT}|g" \ - -e "s|{{CITATION_AUDIT_WEIGHT}}|${CITATION_AUDIT_WEIGHT}|g" \ + -e "s|{{PRESENTATION_QUALITY_WEIGHT}}|${PRESENTATION_QUALITY_WEIGHT}|g" \ + -e "s|{{GROUNDING_WEIGHT}}|${GROUNDING_WEIGHT}|g" \ -e "s|{{OPENJUDGE_LLM}}|${OPENJUDGE_LLM}|g" \ -e "s|{{RM_LLM}}|${RM_LLM}|g" \ -e "s|{{JUDGE_CONCURRENCY}}|${JUDGE_CONCURRENCY}|g" \ - -e "s|{{REPORT_RESOLUTION_WEIGHT}}|${REPORT_RESOLUTION_WEIGHT}|g" \ - -e "s|{{TRAJECTORY_FAITHFULNESS_WEIGHT}}|${TRAJECTORY_FAITHFULNESS_WEIGHT}|g" \ -e "s|{{NUM_REPEAT}}|${NUM_REPEAT}|g" \ -e "s|{{NUM_STEPS}}|${NUM_STEPS}|g" \ -e "s|{{TRAIN_BATCH_SIZE}}|${TRAIN_BATCH_SIZE}|g" \ -e "s|{{TRAIN_DATA_PATH}}|${TRAIN_DATA_PATH}|g" \ -e "s|{{VAL_DATA_PATH}}|${VAL_DATA_PATH}|g" \ - -e "s|{{ENV_SERVICE_URL}}|${ENV_SERVICE_URL}|g" \ -e "s|{{TRAIN_REF_ANS_PATH}}|${TRAIN_REF_ANS_PATH}|g" \ -e "s|{{VAL_REF_ANS_PATH}}|${VAL_REF_ANS_PATH}|g" \ -e "s|{{CKPT_SAVE_PATH}}|${CKPT_SAVE_PATH}|g" \ ${AJET_ROOT}/${CONFIG_TEMPLATE} > ${CONFIG_FILE} echo "配置文件已生成: ${CONFIG_FILE}" -echo "参数确认: RM=${RM_WEIGHT}, Citation=${CITATION_AUDIT_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" - 
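For readers less familiar with the sed pipeline above, the placeholder substitution it performs is equivalent to this small Python sketch (template and values abbreviated):

```python
template = 'project_name: "{{PREFIX}}"\nrm_weight: {{RM_WEIGHT}}\ngrounding_weight: {{GROUNDING_WEIGHT}}\n'
values = {"PREFIX": "ajet_newjudge", "RM_WEIGHT": "0.5", "GROUNDING_WEIGHT": "0.25"}
config = template
for key, val in values.items():
    config = config.replace("{{" + key + "}}", val)  # same effect as sed -e "s|{{K}}|V|g"
print(config)
```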
-#=============================================================================== -# 3. 环境配置 -#=============================================================================== -# MongoDB 缓存配置 -CACHE_TYPE="mongodb" -MONGO_URI="mongodb://${ADDR}:27117/" -MONGO_DB_NAME="finworld_cache" -MONGO_COLLECTION_NAME="tool_cache" -export CACHE_TYPE MONGO_URI MONGO_DB_NAME MONGO_COLLECTION_NAME - -# DeepFinance MCP 配置 -DEEPFINANCE_MCP_CONFIG="${AJET_ROOT}/tutorial/example_deep_finance/config/mcp_finance_tool_generated.json" - -# 动态生成 MCP 配置文件 -mkdir -p $(dirname ${DEEPFINANCE_MCP_CONFIG}) -cat > ${DEEPFINANCE_MCP_CONFIG} << EOF -{ - "mcpServers": { - "flowllm": { - "transport": "sse", - "url": "http://${ADDR}:${MCP_PORT}/sse", - "timeout": 600, - "sse_read_timeout": 1200 - } - } -} -EOF -export DEEPFINANCE_MCP_CONFIG DEEPFINANCE_TOOL_RESULT_MAX_CHARS - -# 其他服务配置 -HF_ENDPOINT="https://hf-mirror.com" -ES_HOSTS="http://11.160.132.46:8200" -export HF_ENDPOINT ES_HOSTS - -# log 文件位置 -CURRENT_TIME=$(date "+%Y%m%d_%H%M%S") -LOG_DIR="${AJET_ROOT}/logs/${PREFIX}" -MASTER_IP_FILE="${LOG_DIR}/master-ip_${SUFFIX}.log" -ENV_SERVICE_LOG="${LOG_DIR}/env_service_${SUFFIX}_${CURRENT_TIME}.log" -TRAIN_LOG="${LOG_DIR}/train_${SUFFIX}_${CURRENT_TIME}.log" - -# 多机训练参数配置 -GPUS_PER_NODE=8 -EXPECTED_WORKERS=$WORLD_SIZE +echo "参数确认: RM=${RM_WEIGHT}, PresentationQuality=${PRESENTATION_QUALITY_WEIGHT}, Grounding=${GROUNDING_WEIGHT}, OpenJudge=${OPENJUDGE_LLM}, RM_LLM=${RM_LLM}" #=============================================================================== @@ -162,7 +115,6 @@ export RAY_CLUSTER_MODE="multi_node" #=============================================================================== # 6. 主流程 #=============================================================================== -log "开始多机多卡训练: ${SUFFIX}" log "节点数: ${NNODES}, 每节点GPU数: ${GPUS_PER_NODE}" mkdir -p ${LOG_DIR} mkdir -p $(dirname ${CONFIG_FILE}) diff --git a/tutorial/example_deep_finance/judge/__init__.py b/tutorial/example_deep_finance/judge/__init__.py new file mode 100644 index 0000000..75c8cef --- /dev/null +++ b/tutorial/example_deep_finance/judge/__init__.py @@ -0,0 +1,11 @@ +# 使得可以通过 from judge import PresentationQualityGrader 直接引用 +from .grounding.grader import GroundingGrader +from .presentation_quality.grader import PresentationQualityGrader +# from .research_depth.grader import ResearchDepthGrader +# from .research_breadth.grader import ResearchBreadthGrader + +# 以后添加了其他 grader 也可以加在这里 +# from .grounding.grader import GroundingGrader +# from .research_breadth.grader import ResearchBreadthGrader +# __all__ = ["PresentationQualityGrader", "GroundingGrader", "ResearchDepthGrader", "ResearchBreadthGrader"] +__all__ = ["PresentationQualityGrader", "GroundingGrader"] diff --git a/tutorial/example_deep_finance/judge/grounding/__init__.py b/tutorial/example_deep_finance/judge/grounding/__init__.py new file mode 100644 index 0000000..1123382 --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/__init__.py @@ -0,0 +1,4 @@ +"""Grounding Grader - 引用规范性评估""" +from .grader import GroundingGrader + +__all__ = ["GroundingGrader"] diff --git a/tutorial/example_deep_finance/judge/grounding/grader.py b/tutorial/example_deep_finance/judge/grounding/grader.py new file mode 100644 index 0000000..599ccc9 --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/grader.py @@ -0,0 +1,222 @@ +"""Grounding Grader - 引用规范性评估 (OpenJudge 版本)""" +from __future__ import annotations + +import os +from typing import Any, Dict, List, Tuple + +from 
openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +# import path 兼容两种写法 +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import GROUNDING_SYSTEM_PROMPT, GROUNDING_USER_PROMPT_TEMPLATE +from .json_utils import strict_load_json, validate_shape, construct_reward_prompt + + +class GroundingGrader(BaseGrader): + """ + 引用规范性评估 Grader + + - 输入:traj(完整对话轨迹) + - 输出:GraderScore(name, score, reason) + - score:综合分数,范围[0,1] + - citation_coverage_score: 引用覆盖率(0.5 权重) + - grounding_score: 引用真实性(0.5 权重) + - invalid_penalty: 无效引用惩罚(最多扣 0.5) + - determinism:建议用 temperature=0 + disable thinking + - 解析失败:score=0,并在 reason 显示报错 + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "grounding", + **kwargs: Any, + ): + super().__init__(name=name, **kwargs) + self.model = model + + @staticmethod + def create_default_model( + model_name: str, + api_key: str | None = None, + base_url: str | None = None, + deterministic: bool = True, + enable_thinking: bool = False, + seed: int = 0, + ) -> OpenAIChatModel: + """ + 创建默认模型 + 也可以不调用这个工厂,自己在外面 new OpenAIChatModel + """ + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + extra_body: Dict[str, Any] = {} + if deterministic: + extra_body.update( + { + "temperature": 0, + "top_p": 1, + "seed": seed, + "presence_penalty": 0, + "frequency_penalty": 0, + } + ) + if enable_thinking is False: + extra_body["enable_thinking"] = False + + kwargs: Dict[str, Any] = {"model": model_name} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["base_url"] = base_url + if extra_body: + kwargs["extra_body"] = extra_body + + return OpenAIChatModel(**kwargs) + + async def aevaluate( + self, + traj: Any, + **_: Any, + ) -> GraderScore: + """ + 入口:必须喂 traj(完整对话轨迹) + + Args: + traj: 对话轨迹,格式为 [{"role": ..., "content": ...}, ...] + 或者 {"messages": [...]} 格式 + + Returns: + GraderScore(name, score, reason) + """ + # 1. 提取 messages(兼容两种格式) + if isinstance(traj, dict): + messages_list = traj.get("messages", []) + elif isinstance(traj, list): + messages_list = traj + else: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: traj must be list or dict with 'messages'", + ) + + if not messages_list: + return GraderScore( + name=self.name, + score=0.0, + reason="BadInput: empty trajectory", + ) + + # 2. 构建 prompt + user_prompt = construct_reward_prompt(messages_list, GROUNDING_USER_PROMPT_TEMPLATE) + + messages = [ + {"role": "system", "content": GROUNDING_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt} + ] + + # 3. 调用模型 + try: + resp = await self.model.achat(messages) + raw_text = getattr(resp, "content", None) + if raw_text is None: + raw_text = str(resp) + except Exception as e: + return GraderScore( + name=self.name, + score=0.0, + reason=f"ModelCallError: {type(e).__name__}: {e}", + ) + + # 4. 解析 JSON + obj, jerr = strict_load_json(str(raw_text)) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"ParseError: {jerr}; raw[:200]={snippet}", + ) + + obj, serr = validate_shape(obj) + if obj is None: + snippet = str(raw_text)[:200].replace("\n", " ") + return GraderScore( + name=self.name, + score=0.0, + reason=f"SchemaError: {serr}; raw[:200]={snippet}", + ) + + # 5. 
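A hedged usage sketch for the grader defined in this file — it assumes a reachable OpenAI-compatible endpoint via OPENAI_API_KEY/OPENAI_BASE_URL, and the model name and trajectory content are placeholders:

```python
import asyncio
from tutorial.example_deep_finance.judge.grounding.grader import GroundingGrader

async def main() -> None:
    model = GroundingGrader.create_default_model("qwen-flash")  # model name is a placeholder
    grader = GroundingGrader(model=model)
    traj = [
        {"role": "user", "content": "分析某公司 2024Q3 营收情况"},
        {"role": "assistant", "content": "# 报告\n营收同比增长 20%[1]\n## References\n[1] 财报 - https://example.com"},
    ]
    result = await grader.aevaluate(traj=traj)
    print(result.score, result.reason)

asyncio.run(main())
```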
计算分数 + score, reason = self._compute_scores(obj) + return GraderScore(name=self.name, score=score, reason=reason) + + def _compute_scores(self, obj: Dict[str, Any]) -> Tuple[float, str]: + """ + 根据 LLM 返回的结果计算评分 + + Args: + obj: LLM 返回的 JSON,包含 total_key_facts, cited_key_facts, fake_count 等 + + Returns: + (score, reason) 元组 + """ + total_key_facts = obj.get('total_key_facts', 0) + cited_key_facts = obj.get('cited_key_facts', 0) + fake_count = obj.get('fake_count', 0) + missing_count = obj.get('missing_count', 0) + + # invalid refs: 结构化/可追溯性问题 + invalid_reference_nums = obj.get('invalid_reference_nums', []) + if not isinstance(invalid_reference_nums, list): + invalid_reference_nums = [] + invalid_ref_count = len(invalid_reference_nums) + + # 边界情况:没有关键事实,直接返回 0 + if total_key_facts == 0: + citation_coverage_score = 0.0 + grounding_score = 0.0 + else: + # coverage: 引用覆盖率 + citation_coverage_score = cited_key_facts / total_key_facts + + # grounding: 引用真实性(已引用中非虚假的比例) + if cited_key_facts == 0: + grounding_score = 0.0 + else: + grounding_score = max(0.0, 1 - fake_count / cited_key_facts) + + # 轻量惩罚:存在 invalid refs 会降低 reward + # 每个 invalid 号扣 0.1,最多扣 0.5 + invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + + # final_reward: 综合分数(权重 0.5:0.5),再叠加 invalid 惩罚 + final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score + final_reward = max(0.0, final_reward - invalid_penalty) + + # 构建 reason + good_citations = obj.get('good_citations', []) + good_str = "; ".join(str(x)[:50] for x in good_citations[:2]) if good_citations else "" + + parts: List[str] = [ + f"total={total_key_facts}", + f"cited={cited_key_facts}", + f"missing={missing_count}", + f"fake={fake_count}", + f"invalid={invalid_ref_count}", + f"coverage={citation_coverage_score:.3f}", + f"grounding={grounding_score:.3f}", + f"penalty={invalid_penalty:.2f}", + ] + if good_str: + parts.append(f"good:[{good_str}]") + + reason = " | ".join(parts) + return round(final_reward, 6), reason[:800] diff --git a/tutorial/example_deep_finance/judge/grounding/json_utils.py b/tutorial/example_deep_finance/judge/grounding/json_utils.py new file mode 100644 index 0000000..a3f793a --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/json_utils.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Tuple + +_JSON_RE = re.compile(r"\{.*\}", re.DOTALL) + + +def extract_first_json_object(text: str) -> str | None: + """ + Best-effort: extract the first {...} block. + If none found, return None. + """ + if not text: + return None + m = _JSON_RE.search(text.strip()) + if not m: + return None + return m.group(0) + + +def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]: + """ + Return (obj, error). 
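A worked example of the scoring rule in `_compute_scores` above, with made-up audit counts:

```python
total, cited, fake, invalid = 10, 8, 1, 2   # hypothetical counts from the LLM audit
coverage = cited / total                    # 0.8
grounding = max(0.0, 1 - fake / cited)      # 0.875
penalty = min(0.1 * invalid, 0.5)           # 0.2
final = max(0.0, 0.5 * coverage + 0.5 * grounding - penalty)
assert abs(final - 0.6375) < 1e-9
```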
Any parse failure => (None, error_msg) + """ + js = extract_first_json_object(text) + if js is None: + return None, "No JSON object found in model output" + try: + obj = json.loads(js) + if not isinstance(obj, dict): + return None, f"Top-level JSON is not an object: {type(obj).__name__}" + return obj, None + except Exception as e: + return None, f"{type(e).__name__}: {e}" + + +def get_bool_pass(item: Any) -> bool: + if isinstance(item, dict): + v = item.get("pass") + else: + v = item + if isinstance(v, bool): + return v + if isinstance(v, (int, float)): + return bool(v) + if isinstance(v, str): + return v.strip().lower() in {"true", "1", "yes", "y"} + return False + + +def get_note(item: Any) -> str: + if isinstance(item, dict): + note = item.get("note", "") + else: + note = "" + note = "" if note is None else str(note) + note = note.strip() + # 最多给点余量,避免reason爆长 + return note[:120] + + +def validate_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]: + """ + 验证 grounding JSON 结构 + + 必需字段: + - total_key_facts: int + - cited_key_facts: int + - missing_count: int + - fake_count: int + - good_citations: list + - invalid_reference_nums: list + """ + # 必需的 int 字段 + int_fields = ["total_key_facts", "cited_key_facts", "missing_count", "fake_count"] + for field in int_fields: + if field not in obj: + return None, f"Missing field: {field}" + val = obj[field] + # 尝试转换为 int + if isinstance(val, (int, float)): + obj[field] = int(val) + elif isinstance(val, str) and val.isdigit(): + obj[field] = int(val) + elif not isinstance(val, int): + return None, f"Field '{field}' must be int, got {type(val).__name__}" + + # good_citations 必须是 list + if "good_citations" not in obj: + obj["good_citations"] = [] + elif not isinstance(obj["good_citations"], list): + obj["good_citations"] = [] + else: + # 确保每个元素是字符串,最多保留 2 条 + obj["good_citations"] = [str(x) for x in obj["good_citations"][:2]] + + # invalid_reference_nums 必须是 list + if "invalid_reference_nums" not in obj: + obj["invalid_reference_nums"] = [] + elif not isinstance(obj["invalid_reference_nums"], list): + obj["invalid_reference_nums"] = [] + else: + # 确保每个元素是 int,最多保留 5 个 + nums = [] + for x in obj["invalid_reference_nums"][:5]: + if isinstance(x, int): + nums.append(x) + elif isinstance(x, (float, str)): + try: + nums.append(int(x)) + except ValueError: + pass + obj["invalid_reference_nums"] = sorted(nums) + + return obj, None + + + + +# ============================================================================= +# Trajectory 处理辅助函数 +# ============================================================================= + +def _extract_text_content(content) -> str: + """统一提取纯文本内容""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + out = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + out.append(item.get("text", "")) + elif isinstance(item, str): + out.append(item) + return "\n".join(out) + return str(content) + + +def _strip_think(text: str) -> str: + """去除 ... 
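To show what the coercions in `validate_shape` above do in practice (values invented):

```python
from tutorial.example_deep_finance.judge.grounding.json_utils import validate_shape

obj, err = validate_shape({
    "total_key_facts": "12",                  # digit string -> int
    "cited_key_facts": 9.0,                   # float -> int
    "missing_count": 3,
    "fake_count": 0,
    "good_citations": ["a", "b", "c"],        # clipped to the first 2
    "invalid_reference_nums": ["7", 2, "x"],  # coerced to int, junk dropped, sorted
})
assert err is None
assert obj["total_key_facts"] == 12
assert obj["good_citations"] == ["a", "b"]
assert obj["invalid_reference_nums"] == [2, 7]
```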
标签""" + return re.sub(r".*?\s*", "", text, flags=re.S).strip() + + +def _strip_markdown_fences(text: str) -> str: + """ + 清理 markdown 代码块标记 + - 移除开头的 ```markdown / ```md / ``` 等 + - 移除结尾的 ``` + """ + text = text.strip() + # 移除开头的 ```xxx + text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE) + # 移除结尾的 ``` + text = re.sub(r'\n?```\s*$', '', text) + return text.strip() + + +def _normalize_traj(trajectory): + """兼容 [[...]] 格式""" + if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list): + return trajectory[0] + return trajectory + + +def _extract_tool_call_json(text: str) -> str: + """提取工具调用 JSON""" + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: + return m.group(1).strip() + l, r = text.find("["), text.rfind("]") + if l != -1 and r != -1 and r > l: + cand = text[l:r+1].strip() + if ("tool_name" in cand) and ("tool_args" in cand): + return cand + return "" + + +def _looks_like_tool_result(text: str) -> bool: + """判断是否为工具返回结果""" + t = text.strip() + if t.startswith("Tool:") or t.startswith("Result:"): + return True + if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t): + return True + if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "): + return True + return False + + +def _is_probably_final_report(text: str) -> bool: + """判断是否为最终报告""" + t = text.strip() + return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ") + + +def construct_reward_prompt(trajectory: List[Dict[str, Any]], user_prompt_template: str) -> str: + """ + 从 trajectory 构建 reward prompt + + Args: + trajectory: 对话轨迹 [{"role": ..., "content": ...}, ...] + + Returns: + 构建好的 user prompt 字符串 + """ + traj = _normalize_traj(trajectory) + if not traj: + traj = [] + + user_query = "" + tool_calls: List[str] = [] + evidence: List[str] = [] + final_report = "" + + # 找到 final report(从后往前找第一个符合条件的 assistant 消息) + for i in range(len(traj) - 1, -1, -1): + step = traj[i] + if step.get("role") == "assistant": + txt = _strip_think(_extract_text_content(step.get("content"))) + if _is_probably_final_report(txt): + final_report = txt + break + if not final_report: + for i in range(len(traj) - 1, -1, -1): + if traj[i].get("role") == "assistant": + final_report = _strip_think(_extract_text_content(traj[i].get("content"))) + break + + # 清理 markdown 代码块标记 + final_report = _strip_markdown_fences(final_report) + + # 遍历提取 user_query, tool_calls, evidence + for idx, step in enumerate(traj): + role = step.get("role") + raw = _extract_text_content(step.get("content")) + txt = _strip_think(raw) + if not raw: + continue + + if role == "user" and not user_query and (not _looks_like_tool_result(raw)): + user_query = txt + continue + + if role == "assistant": + call_json = _extract_tool_call_json(raw) + if call_json: + tool_calls.append(f"[Step {idx}] TOOL_CALL:\n{call_json}") + + if role in ("tool", "user"): + if _looks_like_tool_result(raw): + evidence.append(f"[Step {idx}] EVIDENCE_TOOL_RESULT:\n{raw}") + else: + # query 之后的用户补充也保留为 evidence + if user_query: + evidence.append(f"[Step {idx}] EVIDENCE_USER_CONTEXT:\n{txt}") + + evidence_text = "\n\n".join(tool_calls + evidence) + + return user_prompt_template.format( + user_query=user_query, + evidence_text=evidence_text, + final_report=final_report + ).strip() diff --git a/tutorial/example_deep_finance/judge/grounding/prompt.py b/tutorial/example_deep_finance/judge/grounding/prompt.py new file mode 100644 index 0000000..24bea13 --- /dev/null +++ 
b/tutorial/example_deep_finance/judge/grounding/prompt.py @@ -0,0 +1,117 @@ +"""Grounding Grader Prompt - 引用规范性评估""" + +GROUNDING_SYSTEM_PROMPT = """你是一位"引用审计员",负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +======================== +一、引用规范(以此为准) +======================== +1) 关键事实句必须引用: + - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 + - 不确定或推断性表述必须显式写"推测/可能/假设/预计/或有风险"等,不得用引用把推断包装成既定事实。 + +2) 引用位置规则(严格执行): + - 关键事实句必须在"句末"出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 + - 若引用出现在句中但句末没有引用编号,则该句仍按"缺引用(missing)"处理。 + +3) References 必须存在且可追溯: + - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 + - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 + +4) References 条目两种合法形式(必须满足其一): + A) URL 形式:`[n] 标题或简述 - https://...` + - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 + B) no-url 形式:`[n] 简述,工具:<工具名>,参数:<参数>,数据日期/报告期:<日期/报告期> - (no-url)` + - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 + - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 + +======================== +二、输入 +======================== +你会收到: +- User Query +- Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) +- AI Report(待审计报告,含正文与 References) + +真实性核对原则: +- 以 Evidence 为准:只有在"明显矛盾"或"Evidence 明显找不到任何依据且该句仍把内容写成确定事实"时,才判 fake。 +- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 + +======================== +三、统计与判定口径(严格遵守) +======================== +【文本范围】 +- 只审计 AI Report 的"正文部分"(不包含 References 区块内部的文字)。 +- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 + +【句子/条目如何计数】 +- "句子/条目"包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 +- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 +- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 + +【关键事实句识别(务求稳定)】 +- 满足任一条件可视为关键事实句: + (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; + (b) 含具体日期或期间(如 "2024Q3/2025年/截至XX日"); + (c) 对具体公司/行业/政策做了可验证的确定性陈述; + (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 + +【引用是否"句末"】【重要】 +- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: + - "……增长 20%[3]" + - "……增长 20% [3][4]" +- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 + +【invalid_reference_nums 的定义】 +- 统计"正文中出现过"的编号 n(去重),若满足任一条件则判为 invalid: + (a) References 中不存在该编号条目; + (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); + (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 +- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 + +【missing_count 的定义】 +- 关键事实句中"句末没有任何 [n]"的数量(即使句中出现 [n] 也算 missing)。 + +【cited_key_facts 的定义】 +- 关键事实句中"句末包含至少一个 [n]"的数量(不要求该引用有效)。 + +【fake_count 的定义(只在明显时计数)】 +- 关键事实句若"句末带引用",但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 +- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 + +【good_citations 的定义】 +- 从报告原文中抽取最多 2 条"引用做得正确"的关键事实句,要求同时满足: + - 是关键事实句; + - 句末有 [n]; + - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 +- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 + +======================== +四、输出(只输出 JSON,字段固定) +======================== +{ + "total_key_facts": <int>, + "cited_key_facts": <int>, + "good_citations": ["...", "..."], + "missing_count": <int>, + "fake_count": <int>, + "invalid_reference_nums": [<int>, ...] 
+} + +只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +""" + +# ============================================================================= +# User Prompt Template +# ============================================================================= + +GROUNDING_USER_PROMPT_TEMPLATE = """请审计以下 AI 研究报告的引用规范性,只输出 JSON。 + +### User Query +{user_query} + +### Evidence +{evidence_text} + +### AI Report(待审计报告) +{final_report} +""" diff --git a/tutorial/example_deep_finance/judge/grounding/reference.py b/tutorial/example_deep_finance/judge/grounding/reference.py new file mode 100644 index 0000000..6e67a38 --- /dev/null +++ b/tutorial/example_deep_finance/judge/grounding/reference.py @@ -0,0 +1,363 @@ +GROUNDING_SYSTEM_PROMPT = """你是一位“引用审计员”,负责审计金融研究报告是否遵守引用规范,并输出用于训练的 JSON 结果(只输出 JSON)。 + +======================== +一、引用规范(以此为准) +======================== +1) 关键事实句必须引用: + - 关键事实句包括:数字(金额/比例/增速/同比环比/份额/排名等)、日期/期间、财务指标、估值倍数、明确事实结论、具体事件、具体公司/行业的可验证陈述、政策/条款等。 + - 不确定或推断性表述必须显式写“推测/可能/假设/预计/或有风险”等,不得用引用把推断包装成既定事实。 + +2) 引用位置规则(严格执行): + - 关键事实句必须在“句末”出现引用编号:[1] 或 [1][2](可以多个,但必须紧贴句末)。 + - 若引用出现在句中但句末没有引用编号,则该句仍按“缺引用(missing)”处理。 + +3) References 必须存在且可追溯: + - 报告末尾必须包含标题 `## References`(大小写/空格差异可容忍,但必须是一个清晰的 References 区块)。 + - 正文出现的每个 [n] 必须能在 References 中找到对应条目。 + +4) References 条目两种合法形式(必须满足其一): + A) URL 形式:`[n] 标题或简述 - https://...` + - URL 必须为可用的 http/https 链接,不能为空,也不能是 `javascript:void(0)` 之类的伪链接。 + B) no-url 形式:`[n] 简述,工具:<工具名>,参数:<参数>,数据日期/报告期:<日期/报告期> - (no-url)` + - no-url 必须同时包含:工具名、参数、日期/报告期 三者(缺一即不合规)。 + - `javascript:void(0)` 等无效链接视为无效 URL(会进入 invalid_reference_nums),若要合规应改为 no-url 记录来源。 + +======================== +二、输入 +======================== +你会收到: +- User Query +- Evidence(从完整 trajectory 提取的工具调用/工具返回/用户补充信息) +- AI Report(待审计报告,含正文与 References) + +真实性核对原则: +- 以 Evidence 为准:只有在“明显矛盾”或“Evidence 明显找不到任何依据且该句仍把内容写成确定事实”时,才判 fake。 +- 无法确认/证据缺失/证据不充分时,不要判 fake(宁可不判)。 + +======================== +三、统计与判定口径(严格遵守) +======================== +【文本范围】 +- 只审计 AI Report 的“正文部分”(不包含 References 区块内部的文字)。 +- References 区块仅用于校验编号是否存在、格式是否合规、URL 是否有效。 + +【句子/条目如何计数】 +- “句子/条目”包括:普通句号/分号/换行分点(如列表项、段落中的 bullet)、表格中的单元格陈述(若表达了关键事实,也算关键事实句)。 +- 一句包含多个数字/多个事实点:仍按 1 条关键事实句计数(不要过度拆分)。 +- 同一句若重复出现多次(复制粘贴重复段落):按出现次数计数。 + +【关键事实句识别(务求稳定)】 +- 满足任一条件可视为关键事实句: + (a) 含具体数值/比例/排名/区间/估值倍数/财务指标; + (b) 含具体日期或期间(如 “2024Q3/2025年/截至XX日”); + (c) 对具体公司/行业/政策做了可验证的确定性陈述; + (d) 给出明确结论且呈确定口吻并可被证据支持/反驳。 + +【引用是否“句末”】【重要】 +- 句末引用指:该句最后的可见字符为一个或多个连续的 [n](允许中间无空格或有极少空格),例如: + - “……增长 20%[3]” + - “……增长 20% [3][4]” +- 若 [n] 后面仍有正文内容(哪怕很短),则不算句末引用。 + +【invalid_reference_nums 的定义】 +- 统计“正文中出现过”的编号 n(去重),若满足任一条件则判为 invalid: + (a) References 中不存在该编号条目; + (b) 该编号条目为 URL 形式但 URL 无效(空/非 http(s)/javascript:void(0) 等); + (c) 该编号条目为 no-url 形式但缺少 工具名/参数/日期(报告期) 任意之一。 +- invalid_reference_nums 输出按数字升序;最多 5 个,超出截断。 + +【missing_count 的定义】 +- 关键事实句中“句末没有任何 [n]”的数量(即使句中出现 [n] 也算 missing)。 + +【cited_key_facts 的定义】 +- 关键事实句中“句末包含至少一个 [n]”的数量(不要求该引用有效)。 + +【fake_count 的定义(只在明显时计数)】 +- 关键事实句若“句末带引用”,但与 Evidence 明显矛盾,或 Evidence 明显找不到任何依据且该句仍用确定口吻陈述为事实,计为 fake。 +- 若只是 Evidence 未覆盖/不充分/不确定,不计 fake。 + +【good_citations 的定义】 +- 从报告原文中抽取最多 2 条“引用做得正确”的关键事实句,要求同时满足: + - 是关键事实句; + - 句末有 [n]; + - 所有句末 [n] 在 References 中均存在且条目合法(URL 有效或 no-url 字段齐全)。 +- good_citations 是原文截取,不要加解释;最多 2 条,超出截断。 + +======================== +四、输出(只输出 JSON,字段固定) +======================== +{ + "total_key_facts": <int>, + "cited_key_facts": <int>, + "good_citations": ["...", "..."], + "missing_count": <int>, + "fake_count": <int>, + "invalid_reference_nums": [<int>, ...] 
+} + +只输出 JSON,不要输出解释文字或 Markdown。确保 JSON 可被严格解析(双引号、逗号、方括号等格式正确)。 +""" + + + +import json +import re +from typing import Dict, Any, List + + +def _extract_text_content(content) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + out = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + out.append(item.get("text", "")) + elif isinstance(item, str): + out.append(item) + return "\n".join(out) + return str(content) + +def _strip_think(text: str) -> str: + return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S).strip() + +def _normalize_traj(trajectory): + # 兼容 [[...]] 格式 + if isinstance(trajectory, list) and trajectory and isinstance(trajectory[0], list): + return trajectory[0] + return trajectory + +def _extract_tool_call_json(text: str) -> str: + m = re.search(r"```json\s*(\[[\s\S]*?\])\s*```", text) + if m: + return m.group(1).strip() + l, r = text.find("["), text.rfind("]") + if l != -1 and r != -1 and r > l: + cand = text[l:r+1].strip() + if ("tool_name" in cand) and ("tool_args" in cand): + return cand + return "" + +def _looks_like_tool_result(text: str) -> bool: + t = text.strip() + if t.startswith("Tool:") or t.startswith("Result:"): + return True + if t.startswith("{") and ("query" in t) and ("search_results" in t or "response_content" in t): + return True + if ("股票代码 |" in t) or ("单位:" in t) or t.startswith("### "): + return True + return False + +def _is_probably_final_report(text: str) -> bool: + t = text.strip() + return ("## References" in t) or ("[TASK_COMPLETED]" in t) or t.lstrip().startswith("# ") + +def construct_reward_prompt(trajectory: List[Dict[str, Any]]) -> str: + traj = _normalize_traj(trajectory) + + user_query = "" + tool_calls: List[str] = [] + evidence: List[str] = [] + final_report = "" + + # final report + for i in range(len(traj) - 1, -1, -1): + step = traj[i] + if step.get("role") == "assistant": + txt = _strip_think(_extract_text_content(step.get("content"))) + if _is_probably_final_report(txt): + final_report = txt + break + if not final_report: + for i in range(len(traj) - 1, -1, -1): + if traj[i].get("role") == "assistant": + final_report = _strip_think(_extract_text_content(traj[i].get("content"))) + break + + # iterate + for idx, step in enumerate(traj): + role = step.get("role") + raw = _extract_text_content(step.get("content")) + txt = _strip_think(raw) + if not raw: + continue + + if role == "user" and not user_query and (not _looks_like_tool_result(raw)): + user_query = txt + continue + + if role == "assistant": + call_json = _extract_tool_call_json(raw) + if call_json: + tool_calls.append(f"[Step {idx}] TOOL_CALL:\n{call_json}") + + if role in ("tool", "user"): + if _looks_like_tool_result(raw): + evidence.append(f"[Step {idx}] EVIDENCE_TOOL_RESULT:\n{raw}") + else: + # query 之后的用户补充也保留为 evidence(有些系统会把 tool_result 注入到 user) + if user_query: + evidence.append(f"[Step {idx}] EVIDENCE_USER_CONTEXT:\n{txt}") + + evidence_text = "\n\n".join(tool_calls + evidence) + + return f"""请审计以下 AI 研究报告的引用规范性,只输出 JSON。 + +### User Query +{user_query} + +### Evidence(来自完整 trajectory) +{evidence_text} + +### AI Report(待审计报告) +{final_report} +""".strip() + + +class RefJudgeEvaluator: + """ + 引用规范性评估器 + + 使用 LLM 评估报告的引用覆盖率和引用真实性。 + """ + + def __init__(self, llm_client): + """ + 初始化评估器 + + Args: + llm_client: LLMJudgeClient 实例 + """ + self.llm_client = llm_client + print("✓ RefJudgeEvaluator: Initialized") + + def build_messages(self, 
conversation_history: List[Dict]) -> List[Dict[str, str]]: + """ + 从对话历史构建 LLM 评估消息 + + Args: + conversation_history: 对话历史 [{"role": "...", "content": "..."}] + + Returns: + LLM 消息列表 + """ + print(f"\n[RefJudgeEvaluator] 构建评估消息...") + print(f" - 对话历史轮数: {len(conversation_history)}") + + # 调用现有的 prompt 构建函数 + user_prompt = construct_reward_prompt(conversation_history) + + messages = [ + {"role": "system", "content": GROUNDING_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt} + ] + + print(f" ✓ 消息构建完成,system prompt 长度: {len(GROUNDING_SYSTEM_PROMPT)}") + print(f" ✓ user prompt 长度: {len(user_prompt)}") + + return messages + + def _compute_scores(self, raw_result: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据 LLM 返回的原始结果计算评分 + + Args: + raw_result: LLM 返回的 JSON,包含 total_key_facts, cited_key_facts, fake_count 等 + + Returns: + 包含 citation_coverage_score, grounding_score, final_reward 的字典 + """ + total_key_facts = raw_result.get('total_key_facts', 0) + cited_key_facts = raw_result.get('cited_key_facts', 0) + fake_count = raw_result.get('fake_count', 0) + + # invalid refs: 结构化/可追溯性问题(来自 prompt 的 invalid_reference_nums) + invalid_reference_nums = raw_result.get('invalid_reference_nums', []) + if not isinstance(invalid_reference_nums, list): + invalid_reference_nums = [] + invalid_ref_count = len(invalid_reference_nums) + + # 边界情况:没有关键事实,直接返回 0 + if total_key_facts == 0: + citation_coverage_score = 0.0 + grounding_score = 0.0 + else: + # coverage: 引用覆盖率 + citation_coverage_score = cited_key_facts / total_key_facts + + # grounding: 引用真实性(已引用中非虚假的比例) + if cited_key_facts == 0: + grounding_score = 0.0 + else: + grounding_score = max(0.0, 1 - fake_count / cited_key_facts) + + # 轻量惩罚:存在 invalid refs 会降低 reward(但不改变 cited_key_facts 的统计口径) + # 说明:invalid_reference_nums 在 prompt 中已定义为“正文出现过的不合规编号(去重)”。 + # 这里采用简单、确定性的惩罚:每个 invalid 号扣 0.1,最多扣 0.5。 + invalid_penalty = min(0.1 * invalid_ref_count, 0.5) + + # final_reward: 综合分数(代码计算,权重 0.5:0.5),再叠加 invalid 惩罚 + final_reward = 0.5 * citation_coverage_score + 0.5 * grounding_score + final_reward = max(0.0, final_reward - invalid_penalty) + + return { + 'citation_coverage_score': citation_coverage_score, + 'grounding_score': grounding_score, + 'final_reward': final_reward, + 'invalid_ref_count': invalid_ref_count, + 'invalid_penalty': invalid_penalty, + } + + async def evaluate_async(self, conversation_history: List[Dict]) -> Dict[str, Any]: + """ + 异步评估引用规范性 + + Args: + conversation_history: 对话历史 + + Returns: + 评估结果字典,包含: + - citation_coverage_score: 引用覆盖率分数 (0.0-1.0) + - grounding_score: 引用真实性分数 (0.0-1.0) + - final_reward: 最终奖励分数 (0.0-1.0) + - total_key_facts, cited_key_facts, fake_count 等原始字段 + """ + # print(f"\n开始评估引用规范性...") + + messages = self.build_messages(conversation_history) + raw_result = await self.llm_client.evaluate_async(messages) + + # 计算评分 + scores = self._compute_scores(raw_result) + + # 合并原始结果和计算的评分 + result = {**raw_result, **scores} + + # 确保必要字段存在 + result.setdefault('total_key_facts', 0) + result.setdefault('cited_key_facts', 0) + result.setdefault('missing_count', 0) + result.setdefault('fake_count', 0) + result.setdefault('invalid_reference_nums', []) + result.setdefault('good_citations', []) + + print(f" ✓ [RefJudgeEvaluator] 引用规范性评估完成:") + print(f" - total_key_facts: {result['total_key_facts']}") + print(f" - cited_key_facts: {result['cited_key_facts']}") + print(f" - fake_count: {result['fake_count']}") + print(f" - invalid_ref_count: {result.get('invalid_ref_count', 0)}") + print(f" - invalid_penalty: 
{result.get('invalid_penalty', 0.0):.4f}") + print(f" - citation_coverage_score: {result['citation_coverage_score']:.4f}") + print(f" - grounding_score: {result['grounding_score']:.4f}") + print(f" - final_reward: {result['final_reward']:.4f}") + + return result + + def evaluate_sync(self, conversation_history: List[Dict]) -> Dict[str, Any]: + """ + 同步评估引用规范性 + """ + import asyncio + return asyncio.run(self.evaluate_async(conversation_history)) diff --git a/tutorial/example_deep_finance/judge/presentation_quality/__init__.py b/tutorial/example_deep_finance/judge/presentation_quality/__init__.py new file mode 100644 index 0000000..2db690f --- /dev/null +++ b/tutorial/example_deep_finance/judge/presentation_quality/__init__.py @@ -0,0 +1,4 @@ +"""Grounding Grader - 引用规范性评估""" +from .grader import PresentationQualityGrader + +__all__ = ["PresentationQualityGrader"] diff --git a/tutorial/example_deep_finance/judge/presentation_quality/grader.py b/tutorial/example_deep_finance/judge/presentation_quality/grader.py new file mode 100644 index 0000000..c440c3e --- /dev/null +++ b/tutorial/example_deep_finance/judge/presentation_quality/grader.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import os +import re +from typing import Any, Dict, List, Tuple + +from openjudge.graders.base_grader import BaseGrader +from openjudge.graders.schema import GraderScore + +# import path 兼容两种写法(文档里两种都出现过) +try: + from openjudge.models import OpenAIChatModel +except Exception: # pragma: no cover + from openjudge.models.openai_chat_model import OpenAIChatModel + +from .prompt import ( + QUALITY_SYSTEM_PROMPT, + USER_PROMPT_TEMPLATE, + ALL_KEYS, + A_KEYS, + B_KEYS, + C_KEYS, +) +from .json_utils import strict_load_json, validate_shape, get_score, get_note + + +class PresentationQualityGrader(BaseGrader): + """ + - 输入:report_content(研究报告文本) + - 输出:GraderScore(name, score, reason) + - score:8项按1/3/5分制评分,总分归一化到[0,1](总分/40) + - determinism:建议用 temperature=0 + disable thinking 等(见 create_default_model) + - 解析失败:score=0,并在 reason 显示报错 + """ + + def __init__( + self, + model: OpenAIChatModel, + name: str = "presentation_quality", + **kwargs: Any, + ): + super().__init__(name=name, **kwargs) + self.model = model + + @staticmethod + def create_default_model( + model_name: str, + api_key: str | None = None, + base_url: str | None = None, + deterministic: bool = True, + enable_thinking: bool = False, + seed: int = 0, + ) -> OpenAIChatModel: + """ + 你也可以不调用这个工厂,自己在外面 new OpenAIChatModel。 + QuickStart 文档确认 OpenAIChatModel 会从 OPENAI_API_KEY/OPENAI_BASE_URL 读取。 + """ + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + extra_body: Dict[str, Any] = {} + if deterministic: + # OpenAI兼容接口常见字段;DashScope/Qwen 常用 enable_thinking + extra_body.update( + { + "temperature": 0, + "top_p": 1, + "seed": seed, + "presence_penalty": 0, + "frequency_penalty": 0, + } + ) + if enable_thinking is False: + extra_body["enable_thinking"] = False + + kwargs: Dict[str, Any] = {"model": model_name} + if api_key: + kwargs["api_key"] = api_key + if base_url: + kwargs["base_url"] = base_url + if extra_body: + kwargs["extra_body"] = extra_body + + return OpenAIChatModel(**kwargs) + + async def aevaluate( + self, + report_content: str, + user_query: str | None = None, + **_: Any, + ) -> GraderScore: + """ + 入口:直接喂 report_content(研究报告文本) + - user_query 可选:用于填充 prompt;不提供则用 "(unknown)" + """ + + + report = (report_content or "").strip() + + # 清理 markdown 代码块标记 + report = 
+    async def aevaluate(
+        self,
+        report_content: str,
+        user_query: str | None = None,
+        **_: Any,
+    ) -> GraderScore:
+        """
+        Entry point: feed report_content (the research report text) directly.
+        - user_query is optional: it fills the prompt; "(unknown)" is used when absent.
+        """
+        report = (report_content or "").strip()
+
+        # Strip markdown code-fence markers
+        report = self._strip_markdown_fences(report)
+
+        if not report:
+            return GraderScore(
+                name=self.name,
+                score=0.0,
+                reason="BadInput: empty report_content",
+            )
+
+        uq = (user_query or "").strip() or "(unknown)"
+
+        user_content = USER_PROMPT_TEMPLATE.format(
+            user_query=uq,
+            report_content=report,
+        )
+        messages = [
+            {"role": "system", "content": QUALITY_SYSTEM_PROMPT},
+            {"role": "user", "content": user_content},
+        ]
+
+        # Core call: OpenJudge's OpenAIChatModel supports `await model.achat([...])` and returns .content
+        try:
+            resp = await self.model.achat(messages)
+            raw_text = getattr(resp, "content", None)
+            if raw_text is None:
+                raw_text = str(resp)
+        except Exception as e:
+            return GraderScore(
+                name=self.name,
+                score=0.0,
+                reason=f"ModelCallError: {type(e).__name__}: {e}",
+            )
+
+        obj, jerr = strict_load_json(str(raw_text))
+        if obj is None:
+            snippet = str(raw_text)[:200].replace("\n", " ")
+            return GraderScore(
+                name=self.name,
+                score=0.0,
+                reason=f"ParseError: {jerr}; raw[:200]={snippet}",
+            )
+
+        obj, serr = validate_shape(obj)
+        if obj is None:
+            snippet = str(raw_text)[:200].replace("\n", " ")
+            return GraderScore(
+                name=self.name,
+                score=0.0,
+                reason=f"SchemaError: {serr}; raw[:200]={snippet}",
+            )
+
+        score, reason = self._score_and_reason(obj)
+
+        return GraderScore(name=self.name, score=score, reason=reason)
+
+    def _score_and_reason(self, obj: Dict[str, Any]) -> Tuple[float, str]:
+        scan = obj["scan"]
+        structuring = obj["structuring"]
+        editorial = obj["editorial"]
+        top_fixes = obj.get("top_fixes", [])
+
+        # Score the 8 items on the 1/3/5 scale (strongly deterministic: computed entirely in Python)
+        score_map: Dict[str, int] = {}
+        note_map: Dict[str, str] = {}
+
+        def take(section: Dict[str, Any], key: str):
+            item = section.get(key)
+            score_map[key] = get_score(item)
+            note_map[key] = get_note(item)
+
+        for k in A_KEYS:
+            take(scan, k)
+        for k in B_KEYS:
+            take(structuring, k)
+        for k in C_KEYS:
+            take(editorial, k)
+
+        # Total = sum of item scores / maximum possible score (8*5=40), normalized to [0, 1]
+        total_score = sum(score_map.get(k, 1) for k in ALL_KEYS)
+        max_score = len(ALL_KEYS) * 5  # 8 * 5 = 40
+        score = total_score / float(max_score)
+
+        # reason: sort by score and list the low-scoring items
+        low_items = [(k, score_map.get(k, 1)) for k in ALL_KEYS if score_map.get(k, 1) < 5]
+        low_items.sort(key=lambda x: x[1])  # ascending
+        low_str = ", ".join(f"{k}={s}({note_map.get(k, '')})" for k, s in low_items[:4])
+        fixes_str = " | ".join(str(x) for x in (top_fixes or [])[:3])
+
+        parts: List[str] = []
+        parts.append(f"Score {total_score}/{max_score}")
+        if low_items:
+            parts.append(f"Low: {low_str}")
+        if fixes_str:
+            parts.append(f"TopFixes: {fixes_str}")
+
+        reason = " ; ".join(parts)
+        return round(score, 6), reason[:800]
+
+    @staticmethod
+    def _strip_markdown_fences(text: str) -> str:
+        """
+        Strip markdown code-fence markers:
+        - remove a leading ```markdown / ```md / ``` marker
+        - remove a trailing ```
+        """
+        text = text.strip()
+        # Remove a leading ```xxx
+        text = re.sub(r'^```(?:markdown|md)?\s*\n?', '', text, flags=re.IGNORECASE)
+        # Remove a trailing ```
+        text = re.sub(r'\n?```\s*$', '', text)
+        return text.strip()
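As a sanity check on the normalization in `_score_and_reason`, the sketch below computes the final score for a hypothetical set of item ratings (the ratings are made up):

```python
# Hypothetical 1/3/5 ratings, one per rubric key.
ratings = {
    "A1_key_takeaways_top": 5, "A2_navigable_structure": 3, "A3_visual_hierarchy": 5,
    "B1_dense_info_structured": 1, "B2_comparisons_aligned": 3, "B3_consistency": 5,
    "C1_argument_chain_presented": 5, "C2_risk_and_actionability_clear": 3,
}
total = sum(ratings.values())              # 30
normalized = total / (len(ratings) * 5)    # 30 / 40 = 0.75
print(total, normalized)
```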
diff --git a/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py
new file mode 100644
index 0000000..2852ff8
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/presentation_quality/json_utils.py
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Any, Dict, Tuple
+
+
+_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)
+
+
+def extract_first_json_object(text: str) -> str | None:
+    """
+    Best-effort: extract the first {...} block.
+    If none found, return None.
+    """
+    if not text:
+        return None
+    m = _JSON_RE.search(text.strip())
+    if not m:
+        return None
+    return m.group(0)
+
+
+def strict_load_json(text: str) -> Tuple[Dict[str, Any] | None, str | None]:
+    """
+    Return (obj, error). Any parse failure => (None, error_msg)
+    """
+    js = extract_first_json_object(text)
+    if js is None:
+        return None, "No JSON object found in model output"
+    try:
+        obj = json.loads(js)
+        if not isinstance(obj, dict):
+            return None, f"Top-level JSON is not an object: {type(obj).__name__}"
+        return obj, None
+    except Exception as e:
+        return None, f"{type(e).__name__}: {e}"
+
+
+def get_bool_pass(item: Any) -> bool:
+    # Retained from the earlier pass/fail schema (the prompt has since moved from `pass` to `score`).
+    if isinstance(item, dict):
+        v = item.get("pass")
+    else:
+        v = item
+    if isinstance(v, bool):
+        return v
+    if isinstance(v, (int, float)):
+        return bool(v)
+    if isinstance(v, str):
+        return v.strip().lower() in {"true", "1", "yes", "y"}
+    return False
+
+
+def get_score(item: Any) -> int:
+    """
+    Extract numeric score (1, 3, 5) from item.
+    Returns 1 as default if invalid.
+    """
+    if isinstance(item, dict):
+        v = item.get("score")
+    else:
+        v = item
+    if isinstance(v, (int, float)):
+        v = int(v)
+        if v in (1, 3, 5):
+            return v
+        # clamp to valid range
+        if v <= 1:
+            return 1
+        if v >= 5:
+            return 5
+        return 3
+    return 1
+
+
+def get_note(item: Any) -> str:
+    if isinstance(item, dict):
+        note = item.get("note", "")
+    else:
+        note = ""
+    note = "" if note is None else str(note)
+    note = note.strip()
+    # Leave some headroom so reason strings do not blow up
+    return note[:120]
+
+
+def validate_shape(obj: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]:
+    """
+    Ensure required sections exist and are dicts; ensure top_fixes is list or str.
+    If missing required field => error.
+    """
+    for sec in ("scan", "structuring", "editorial"):
+        if sec not in obj:
+            return None, f"Missing field: {sec}"
+        if not isinstance(obj[sec], dict):
+            return None, f"Field '{sec}' is not an object"
+    if "top_fixes" not in obj:
+        return None, "Missing field: top_fixes"
+    # normalize top_fixes
+    tf = obj.get("top_fixes")
+    if isinstance(tf, list):
+        obj["top_fixes"] = [str(x) for x in tf][:3]
+    elif tf is None:
+        obj["top_fixes"] = []
+    else:
+        obj["top_fixes"] = [str(tf)]
+    return obj, None
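To see how these helpers fit together, here is a small sketch that parses a noisy model reply (the reply text is fabricated for illustration; the import assumes `json_utils.py` is on `sys.path`):

```python
from json_utils import strict_load_json, validate_shape, get_score, get_note

# Fabricated model output: chatter before the JSON, top_fixes as a bare string.
raw_reply = (
    'Sure, here is the audit:\n'
    '{"scan": {"A1_key_takeaways_top": {"score": 5, "note": "TL;DR on top"}}, '
    '"structuring": {}, "editorial": {}, "top_fixes": "tighten section B"}'
)

obj, err = strict_load_json(raw_reply)   # extracts the {...} block, then json.loads
assert err is None
obj, err = validate_shape(obj)           # checks sections, coerces top_fixes to a list
assert err is None

item = obj["scan"]["A1_key_takeaways_top"]
print(get_score(item), get_note(item))   # 5 TL;DR on top
print(obj["top_fixes"])                  # ['tighten section B']
```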
diff --git a/tutorial/example_deep_finance/judge/presentation_quality/prompt.py b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py
new file mode 100644
index 0000000..5e945bf
--- /dev/null
+++ b/tutorial/example_deep_finance/judge/presentation_quality/prompt.py
@@ -0,0 +1,108 @@
+# 8 presentation-quality checks: A(3)+B(3)+C(2)=8
+QUALITY_SYSTEM_PROMPT = """
+You are a "deep-research report presentation reviewer". Your task is to assess the report's **user experience and information architecture (Presentation & UX)** and provide a reward signal for reinforcement learning.
+
+**Do NOT assess**: factual correctness, citation accuracy (handled by the Grounding model), or breadth/depth of content.
+**Core focus**: **cognitive-load management**, **scannability of information**, **visualization of logic**, **Markdown rendering quality**.
+
+========================
+Scoring rubric (1/3/5 scale)
+========================
+Score the following 8 dimensions.
+- **1 (Fail)**: severely impedes reading; formatting broken or missing.
+- **3 (Pass)**: barely passing; basic structure exists but is mediocre, verbose, or not intuitive.
+- **5 (Excellent)**: publication quality; excellent structure; the core is graspable at a glance, lowering the reader's cognitive cost.
+
+For each item, give a score (1, 3, 5) and a Note (≤25 chars, pointing at the specific location or symptom).
+
+### A) Scan & Navigation
+**A1 Key Takeaways Top**
+- 5: a standalone "Key Takeaways / TL;DR" block at the top with clear bullets; the reader gets the main conclusions without scrolling.
+- 3: a summary exists but reads as a run-on paragraph, or is buried in the body and not prominent.
+- 1: no summary; the report opens straight into details or background.
+
+**A2 Navigable Structure**
+- 5: clear hierarchy (H1/H2/H3); long texts have clear "signposts" (subheadings) supporting skimming and jumping.
+- 3: sectioned, but paragraphs are too long (wall of text) and lack internal visual guidance.
+- 1: chaotic structure; heading levels wrong or missing; hard to navigate.
+
+**A3 Visual Hierarchy**
+- 5: uses **bold**, *italics* or `code blocks` to precisely highlight core insights; high signal-to-noise ratio.
+- 3: emphasis exists but is overused (bold everywhere) or poorly targeted (irrelevant words highlighted).
+- 1: flat prose throughout, no visual emphasis at all.
+
+### B) Information Structuring
+**B1 Dense Info Structured**
+- 5: complex data / multi-condition logic is turned into Markdown **tables** or **nested lists**, readable at a glance.
+- 3: lists are used, but the content is still a pile of long sentences; information not truly decomposed.
+- 1: key numbers or complex parameters drowned in long paragraphs.
+
+**B2 Comparisons Aligned**
+- 5: comparisons (option A vs B / past vs present) use tables or aligned structures; dimensions are directly comparable.
+- 3: comparison intended, but scattered across paragraphs; the reader must cross-reference back and forth.
+- 1: comparison dimensions chaotic or missing; no direct comparison possible.
+
+**B3 Consistency & Rendering**
+- 5: consistent formatting (symbols/units); Markdown renders perfectly (no broken tables, no garbled formulas).
+- 3: minor formatting inconsistencies or slight rendering flaws that do not hinder understanding.
+- 1: misaligned tables, unclosed formulas, chaotic list nesting; reading is severely impaired.
+
+### C) Editorial Clarity
+**C1 Argument Chain Presented**
+- 5: the logical chain is visible (e.g. a `claim -> evidence -> conclusion` structure) with clear citation anchors `[1]`.
+- 3: the logic exists but is buried in prose, lacking connectives or visual guidance.
+- 1: material piled up with no clear line of reasoning.
+
+**C2 Risk & Actionability Clear**
+- 5: standalone sections clearly list "risks/limitations" and "next steps"; highly actionable.
+- 3: risks or suggestions are mentioned but vague, or mixed into the conclusion.
+- 1: no mention of risk boundaries or next actions at all.
+
+**Anti-Gaming principle**:
+- Empty tables, meaningless repeated lists, formatting for its own sake (e.g. force-splitting one simple sentence into a list) -> score **1** directly, with the Note marking "over-formatting".
+
+========================
+Output requirements (Strict JSON)
+========================
+You must output parseable JSON.
+**Note**: to provide a gradient signal, the per-item field is `score` (not `pass`); the value must be 1, 3, or 5.
+
+JSON template:
+{
+  "scan": {
+    "A1_key_takeaways_top": {"score": 0, "note": "≤25-char locating reason"},
+    "A2_navigable_structure": {"score": 0, "note": "≤25-char locating reason"},
+    "A3_visual_hierarchy": {"score": 0, "note": "≤25-char locating reason"}
+  },
+  "structuring": {
+    "B1_dense_info_structured": {"score": 0, "note": "≤25-char locating reason"},
+    "B2_comparisons_aligned": {"score": 0, "note": "≤25-char locating reason"},
+    "B3_consistency": {"score": 0, "note": "≤25-char locating reason"}
+  },
+  "editorial": {
+    "C1_argument_chain_presented": {"score": 0, "note": "≤25-char locating reason"},
+    "C2_risk_and_actionability_clear": {"score": 0, "note": "≤25-char locating reason"}
+  },
+  "top_fixes": ["at most 3 items, presentation-level improvements only, targeting the lowest-scoring items"]
+}
+"""
+
+USER_PROMPT_TEMPLATE = """
+Please audit the [presentation quality] of the research report below (presentation/layout/structure only; no factual correctness, citation support, coverage, or depth).
+
+### User Query
+{user_query}
+
+### AI Report
+{report_content}
+
+-----
+Output JSON strictly following the anchors in the System Prompt; do not output Markdown; do not add extra fields.
+""".strip()
+
+# The 8 check-item keys (used for the Python-side scoring; strongly deterministic)
+A_KEYS = ["A1_key_takeaways_top", "A2_navigable_structure", "A3_visual_hierarchy"]
+B_KEYS = ["B1_dense_info_structured", "B2_comparisons_aligned", "B3_consistency"]
+C_KEYS = ["C1_argument_chain_presented", "C2_risk_and_actionability_clear"]
+
+ALL_KEYS = A_KEYS + B_KEYS + C_KEYS
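A quick sketch of how these prompt pieces are assembled into a chat request, mirroring the message construction in `PresentationQualityGrader.aevaluate` (the query and report texts are placeholders; the import assumes the module is run from inside the package directory):

```python
from prompt import QUALITY_SYSTEM_PROMPT, USER_PROMPT_TEMPLATE, ALL_KEYS

user_content = USER_PROMPT_TEMPLATE.format(
    user_query="How did Company X's margins evolve in 2023?",  # placeholder
    report_content="# Report\n\n## Key Takeaways\n- ...",      # placeholder
)
messages = [
    {"role": "system", "content": QUALITY_SYSTEM_PROMPT},
    {"role": "user", "content": user_content},
]
print(len(ALL_KEYS))  # 8 rubric items, so the maximum raw total is 8 * 5 = 40
```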
+ - {"trajectory": [...]} + - {"messages": [...]} + """ + if isinstance(traj, list): + return traj + if isinstance(traj, dict): + if isinstance(traj.get("trajectory"), list): + return traj["trajectory"] + if isinstance(traj.get("messages"), list): + return traj["messages"] + return [] + + +def infer_user_query(trajectory: List[Dict[str, Any]]) -> str: + for step in trajectory: + if step.get("role") == "user": + txt = extract_text_content(step.get("content")) + if txt.strip(): + return txt.strip() + return "" + + +def find_final_report(trajectory: List[Dict[str, Any]]) -> str: + """ + Heuristic: last assistant long text or markdown-like content. + """ + for step in reversed(trajectory): + if step.get("role") == "assistant": + txt = extract_text_content(step.get("content", "")) + if len(txt) > 120 or "#" in txt: + return txt + return "" + + + diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 48824e1..33103fe 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -1,6 +1,6 @@ # ------------------ 主要配置 ------------------ ajet: - project_name: ajet_deep_finance + project_name: "{{PREFIX}}" experiment_name: "{{SUFFIX}}" # Judge 配置(嵌套结构,对应 self.config.ajet.judge.*) judge: @@ -10,9 +10,8 @@ ajet: train_ref_ans_path: {{TRAIN_REF_ANS_PATH}} # 训练集 Reference Answer 路径 val_ref_ans_path: {{VAL_REF_ANS_PATH}} # 验证集 Reference Answer 路径 # OpenJudge 权重配置 - report_resolution_weight: {{REPORT_RESOLUTION_WEIGHT}} # 报告质量评估 - trajectory_faithfulness_weight: {{TRAJECTORY_FAITHFULNESS_WEIGHT}} # 事实准确性评估 - citation_audit_weight: {{CITATION_AUDIT_WEIGHT}} # 引用审计评估 (覆盖率 + 真实性) + presentation_quality_weight: {{PRESENTATION_QUALITY_WEIGHT}} # 报告呈现质量评估 + grounding_weight: {{GROUNDING_WEIGHT}} # 引用规范性评估 rm_weight: {{RM_WEIGHT}} # RM Gallery 权重 task_judge: # 使用本地 DeepFinanceJudge 进行评估(解耦远程 env_service)