From 165b151353e3509021100525326253ef74c330a4 Mon Sep 17 00:00:00 2001
From: wanghao
Date: Mon, 15 Dec 2025 15:57:54 +0800
Subject: [PATCH] add msprobe support

---
 docs/source/Megatron-SWIFT/Ascend.md     | 132 +++++++++++++++++++++++
 docs/source_en/Megatron-SWIFT/Ascend.md  | 132 +++++++++++++++++++++++
 msprobe_config.json                      |  15 +++
 requirements/framework.txt               |   1 +
 swift/megatron/argument/megatron_args.py |   4 +
 swift/megatron/trainers/base.py          |  15 ++-
 6 files changed, 297 insertions(+), 2 deletions(-)
 create mode 100644 msprobe_config.json

diff --git a/docs/source/Megatron-SWIFT/Ascend.md b/docs/source/Megatron-SWIFT/Ascend.md
index 1d822685c1..72ccad9567 100644
--- a/docs/source/Megatron-SWIFT/Ascend.md
+++ b/docs/source/Megatron-SWIFT/Ascend.md
@@ -41,3 +41,135 @@ while iteration < args.train_iters:
     ...
     prof.stop()
 ```
+
+## NPU Accuracy Data Collection
+### Configuration
+
+Modify configuration items such as `dump_path` and `level` in the `msprobe_config.json` file under the ms-swift directory as needed.
+For more options, see the [Configuration Examples](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_examples.md) and the [Configuration File Introduction](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_introduct.md).
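+
+For example, to collect data only from rank 0 during the first two training steps, you could narrow the `rank` and `step` lists of the shipped config as below (a minimal illustration; step indices are assumed to be zero-based here, see the configuration introduction linked above for the exact semantics):
+
+```json
+{
+    "task": "statistics",
+    "dump_path": "./dump_path",
+    "rank": [0],
+    "step": [0, 1],
+    "level": "mix",
+    "async_dump": false,
+    "statistics": {
+        "scope": [],
+        "list": [],
+        "tensor_list": [],
+        "data_mode": ["all"],
+        "summary_mode": "statistics"
+    }
+}
+```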
+
+### Code Modification
+To support accuracy debugging with the msprobe tool, the `_patch_word_embeddings` function in `swift/megatron/model/mm_gpt_model.py` needs to be modified. The main change is to adjust the function's parameters and internal logic so that it patches the embedding layer on the module instance rather than on the class.
+
+The specific modifications are as follows:
+
+Before modification:
+```python
+def _patch_word_embeddings(self, kwargs):
+    origin_forward = VocabParallelEmbedding.forward
+
+    def forward(_self, input_):
+        from ..trainers.utils import split_cp_inputs
+        args = get_args()
+        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
+        _self.reduce_scatter_embeddings = False
+        input_ = torch.masked_fill(input_, input_ < 0, 0)
+        res = origin_forward(_self, input_)
+        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
+        packed_seq_params = kwargs.get('packed_seq_params')
+        # ...other logic...
+        return res
+
+    VocabParallelEmbedding.forward = forward
+    try:
+        yield
+    finally:
+        VocabParallelEmbedding.forward = origin_forward
+
+def forward(
+    self,
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    attention_mask: torch.Tensor = None,
+    decoder_input: torch.Tensor = None,
+    labels: torch.Tensor = None,
+    inference_params: InferenceParams = None,
+    packed_seq_params: PackedSeqParams = None,
+    **kwargs,
+) -> torch.Tensor:
+    if decoder_input is not None:
+        pass
+    elif self.pre_process:
+        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
+        with self._patch_word_embeddings(kwargs):
+            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)
+
+    # ...other logic...
+```
+
+After modification:
+```python
+def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
+    origin_forward = emb.word_embeddings.forward  # Modification 2
+
+    def forward(input_):  # Modification 3
+        from ..trainers.utils import split_cp_inputs
+        args = get_args()
+        _self = emb.word_embeddings  # Modification 4
+        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
+        _self.reduce_scatter_embeddings = False
+        input_ = torch.masked_fill(input_, input_ < 0, 0)
+        res = origin_forward(input_)  # Modification 5
+        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
+        packed_seq_params = kwargs.get('packed_seq_params')
+        # ...other logic...
+        return res
+
+    emb.word_embeddings.forward = forward  # Modification 6
+    try:
+        yield
+    finally:
+        emb.word_embeddings.forward = origin_forward  # Modification 7
+
+def forward(
+    self,
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    attention_mask: torch.Tensor = None,
+    decoder_input: torch.Tensor = None,
+    labels: torch.Tensor = None,
+    inference_params: InferenceParams = None,
+    packed_seq_params: PackedSeqParams = None,
+    **kwargs,
+) -> torch.Tensor:
+    if decoder_input is not None:
+        pass
+    elif self.pre_process:
+        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
+        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
+            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)
+
+    # ...other logic...
+```
+
+The main changes are:
+1. `_patch_word_embeddings` takes an additional `emb` parameter that receives the embedding module instance
+2. The original forward is captured from `emb.word_embeddings.forward` instead of `VocabParallelEmbedding.forward`
+3. The inner `forward` signature changes from `(_self, input_)` to `(input_)`
+4. `_self` is obtained inside the function via `emb.word_embeddings`
+5. The original forward is called with `input_` alone
+6. Patching and restoration both go through `emb.word_embeddings.forward` (Modifications 6 and 7)
+7. The caller passes the `self.language_model.embedding` instance into `_patch_word_embeddings` (Modification 8)
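+
+The pattern behind these changes, swapping the `forward` of a single module instance inside a context manager and always restoring it, can be illustrated with a minimal self-contained sketch (it uses a plain `nn.Embedding` rather than the actual Megatron classes):
+
+```python
+from contextlib import contextmanager
+
+import torch
+from torch import nn
+
+
+@contextmanager
+def patch_forward(module: nn.Module):
+    """Temporarily wrap `module.forward` on this instance only, then restore it."""
+    origin_forward = module.forward  # bound method of this one instance
+
+    def forward(input_):
+        # Mirror the real patch: mask negative ids before the embedding lookup.
+        input_ = torch.masked_fill(input_, input_ < 0, 0)
+        return origin_forward(input_)
+
+    module.forward = forward  # instance attribute shadows the class method
+    try:
+        yield
+    finally:
+        module.forward = origin_forward  # other instances are never affected
+
+
+embedding = nn.Embedding(10, 4)
+ids = torch.tensor([1, -1, 3])
+with patch_forward(embedding):
+    out = embedding(ids)  # -1 is masked to 0 before the lookup
+print(out.shape)  # torch.Size([3, 4])
+```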
+
+### Enabling
+Add `--enable_msprobe True` to the startup script.
+
+In addition, since msprobe does not support fused kernels, you also need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False` to disable them.
+#### Example
+```shell
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+NPROC_PER_NODE=2 \
+CUDA_VISIBLE_DEVICES=0,1 \
+megatron sft \
+    --load Qwen2.5-7B-Instruct-mcore \
+    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
+              'AI-ModelScope/alpaca-gpt4-data-en#500' \
+              'swift/self-cognition#500' \
+    --tensor_model_parallel_size 2 \
+    ...
+    --no_bias_dropout_fusion True \
+    --no_bias_swiglu_fusion True \
+    --cross_entropy_loss_fusion False \
+    --enable_msprobe True
+```
diff --git a/docs/source_en/Megatron-SWIFT/Ascend.md b/docs/source_en/Megatron-SWIFT/Ascend.md
index ee27bd057b..47f5d0e0b1 100644
--- a/docs/source_en/Megatron-SWIFT/Ascend.md
+++ b/docs/source_en/Megatron-SWIFT/Ascend.md
@@ -41,3 +41,135 @@ while iteration < args.train_iters:
     ...
     prof.stop()
 ```
+
+## NPU Accuracy Data Collection
+### Configuration
+
+Modify configuration items such as `dump_path` and `level` in the `msprobe_config.json` file under the ms-swift directory as needed.
+For more options, see the [Configuration Examples](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_examples.md) and the [Configuration File Introduction](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_introduct.md).
+
+### Code Modification
+To support accuracy debugging with the msprobe tool, the `_patch_word_embeddings` function in `swift/megatron/model/mm_gpt_model.py` needs to be modified. The main change is to adjust the function's parameters and internal logic so that it patches the embedding layer on the module instance rather than on the class.
+
+The specific modifications are as follows:
+
+Before modification:
+```python
+def _patch_word_embeddings(self, kwargs):
+    origin_forward = VocabParallelEmbedding.forward
+
+    def forward(_self, input_):
+        from ..trainers.utils import split_cp_inputs
+        args = get_args()
+        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
+        _self.reduce_scatter_embeddings = False
+        input_ = torch.masked_fill(input_, input_ < 0, 0)
+        res = origin_forward(_self, input_)
+        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
+        packed_seq_params = kwargs.get('packed_seq_params')
+        # ...other logic...
+        return res
+
+    VocabParallelEmbedding.forward = forward
+    try:
+        yield
+    finally:
+        VocabParallelEmbedding.forward = origin_forward
+
+def forward(
+    self,
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    attention_mask: torch.Tensor = None,
+    decoder_input: torch.Tensor = None,
+    labels: torch.Tensor = None,
+    inference_params: InferenceParams = None,
+    packed_seq_params: PackedSeqParams = None,
+    **kwargs,
+) -> torch.Tensor:
+    if decoder_input is not None:
+        pass
+    elif self.pre_process:
+        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
+        with self._patch_word_embeddings(kwargs):
+            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)
+
+    # ...other logic...
+```
+
+After modification:
+```python
+def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
+    origin_forward = emb.word_embeddings.forward  # Modification 2
+
+    def forward(input_):  # Modification 3
+        from ..trainers.utils import split_cp_inputs
+        args = get_args()
+        _self = emb.word_embeddings  # Modification 4
+        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
+        _self.reduce_scatter_embeddings = False
+        input_ = torch.masked_fill(input_, input_ < 0, 0)
+        res = origin_forward(input_)  # Modification 5
+        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
+        packed_seq_params = kwargs.get('packed_seq_params')
+        # ...other logic...
+        return res
+
+    emb.word_embeddings.forward = forward  # Modification 6
+    try:
+        yield
+    finally:
+        emb.word_embeddings.forward = origin_forward  # Modification 7
+
+def forward(
+    self,
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    attention_mask: torch.Tensor = None,
+    decoder_input: torch.Tensor = None,
+    labels: torch.Tensor = None,
+    inference_params: InferenceParams = None,
+    packed_seq_params: PackedSeqParams = None,
+    **kwargs,
+) -> torch.Tensor:
+    if decoder_input is not None:
+        pass
+    elif self.pre_process:
+        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
+        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
+            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)
+
+    # ...other logic...
+```
+
+The main changes are:
+1. `_patch_word_embeddings` takes an additional `emb` parameter that receives the embedding module instance
+2. The original forward is captured from `emb.word_embeddings.forward` instead of `VocabParallelEmbedding.forward`
+3. The inner `forward` signature changes from `(_self, input_)` to `(input_)`
+4. `_self` is obtained inside the function via `emb.word_embeddings`
+5. The original forward is called with `input_` alone
+6. Patching and restoration both go through `emb.word_embeddings.forward` (Modifications 6 and 7)
+7. The caller passes the `self.language_model.embedding` instance into `_patch_word_embeddings` (Modification 8)
+
+### Enabling
+Add `--enable_msprobe True` to the startup script.
+
+In addition, since msprobe does not support fused kernels, you also need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False` to disable them.
+#### Example
+```shell
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+NPROC_PER_NODE=2 \
+CUDA_VISIBLE_DEVICES=0,1 \
+megatron sft \
+    --load Qwen2.5-7B-Instruct-mcore \
+    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
+              'AI-ModelScope/alpaca-gpt4-data-en#500' \
+              'swift/self-cognition#500' \
+    --tensor_model_parallel_size 2 \
+    ...
+    --no_bias_dropout_fusion True \
+    --no_bias_swiglu_fusion True \
+    --cross_entropy_loss_fusion False \
+    --enable_msprobe True
+```
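+
+With `--enable_msprobe True`, each training step is wrapped in msprobe's `PrecisionDebugger` (see the change to `swift/megatron/trainers/base.py` below). The sketch that follows shows the essential lifecycle using only the calls that appear in that change; `model`, `train_step`, and `data_iterator` are schematic placeholders, not ms-swift APIs:
+
+```python
+from msprobe.pytorch import PrecisionDebugger
+
+# Reads rank/step filters, dump level, etc. from the JSON config.
+debugger = PrecisionDebugger(config_path='./msprobe_config.json', model=model)
+
+for batch in data_iterator:   # schematic training loop
+    debugger.start()          # begin collecting data for this step
+    loss = train_step(batch)  # forward / backward / optimizer
+    debugger.stop()           # stop collecting
+    debugger.step()           # advance msprobe's step counter
+```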
diff --git a/msprobe_config.json b/msprobe_config.json
new file mode 100644
index 0000000000..024a11fdca
--- /dev/null
+++ b/msprobe_config.json
@@ -0,0 +1,15 @@
+{
+    "task": "statistics",
+    "dump_path": "./dump_path",
+    "rank": [],
+    "step": [],
+    "level": "mix",
+    "async_dump": false,
+    "statistics": {
+        "scope": [],
+        "list": [],
+        "tensor_list": [],
+        "data_mode": ["all"],
+        "summary_mode": "statistics"
+    }
+}
\ No newline at end of file
diff --git a/requirements/framework.txt b/requirements/framework.txt
index 874ae2bb46..f19e4cac2d 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -14,6 +14,7 @@ importlib_metadata
 jieba
 json_repair
 matplotlib
+mindstudio-probe
 modelscope>=1.23
 nltk
 numpy
diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py
index 5db3660a58..43673cba9b 100644
--- a/swift/megatron/argument/megatron_args.py
+++ b/swift/megatron/argument/megatron_args.py
@@ -359,6 +359,10 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
     # qwen3_vl, qwen3_omni
     mrope_interleaved: Optional[bool] = None
 
+    # dump
+    enable_msprobe: bool = False
+    msprobe_config: str = './msprobe_config.json'
+
     @staticmethod
     def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]:
         res = {}
diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py
index 0f0614e342..cd52e70694 100644
--- a/swift/megatron/trainers/base.py
+++ b/swift/megatron/trainers/base.py
@@ -522,8 +522,19 @@ def _all_reduce_metric(self,
     def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
                    **kwargs):
         new_data_iterator = self._replace_data_iterator(data_iterator, model)
-        return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
-                                       config, *args, **kwargs)
+        debugger_on = self.args.enable_msprobe
+        if debugger_on:
+            from msprobe.pytorch import PrecisionDebugger
+            debugger = PrecisionDebugger(config_path=self.args.msprobe_config, model=model)
+            debugger.start()
+        try:
+            origin_train_step_out = self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer,
+                                                            opt_param_scheduler, config, *args, **kwargs)
+        finally:
+            if debugger_on:
+                debugger.stop()
+                debugger.step()
+        return origin_train_step_out
 
     # Code borrowed from NVIDIA/Megatron-LM
     def evaluate(