132 changes: 132 additions & 0 deletions docs/source/Megatron-SWIFT/Ascend.md
@@ -41,3 +41,135 @@ while iteration < args.train_iters:
    ...
prof.stop()
```

## NPU Accuracy Data Collection
### Configuration

Modify `dump_path`, `level`, and other configuration items in the msprobe_config.json file under the ms-swift directory as needed.
More configurations can be found in [Configuration Examples](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_examples.md) and [Configuration File Introduction](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_introduct.md).
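If you prefer to adjust these fields from a script rather than by hand, the snippet below is a minimal, hedged sketch: it assumes only the field names present in the msprobe_config.json added by this PR, and the dump directory value is an example.

```python
# Illustrative sketch: tweak dump_path / level in msprobe_config.json programmatically.
import json
from pathlib import Path

config_file = Path('msprobe_config.json')  # relative to the ms-swift repository root
config = json.loads(config_file.read_text())

config['dump_path'] = './dump_path'  # directory where dump data will be written (example value)
config['level'] = 'mix'              # dump granularity; see the msprobe docs for available levels

config_file.write_text(json.dumps(config, indent=4, ensure_ascii=False) + '\n')
```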


### Code Modification
To support accuracy debugging with the msprobe tool, we need to modify the `_patch_word_embeddings` function in the `swift/megatron/model/mm_gpt_model.py` file. The main changes adjust the function parameters and internal implementation logic so that the embedding layer is patched correctly.

The specific modifications are as follows:

Before modification:
```python
def _patch_word_embeddings(self, kwargs):
    origin_forward = VocabParallelEmbedding.forward

    def forward(_self, input_):
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(_self, input_)
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    VocabParallelEmbedding.forward = forward
    try:
        yield
    finally:
        VocabParallelEmbedding.forward = origin_forward

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs):
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

After modification:
```python
def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
    origin_forward = emb.word_embeddings.forward  # Modification 2

    def forward(input_):  # Modification 3
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        _self = emb.word_embeddings  # Modification 4
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(input_)  # Modification 5
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    emb.word_embeddings.forward = forward  # Modification 6
    try:
        yield
    finally:
        emb.word_embeddings.forward = origin_forward  # Modification 7

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

The main changes include:
1. The `_patch_word_embeddings` method adds an `emb` parameter to receive the embedding module instance
2. `emb.word_embeddings.forward` is obtained directly instead of `VocabParallelEmbedding.forward`
3. The signature of the inner `forward` function changes from `(_self, input_)` to `(input_)`
4. `_self` is obtained via `emb.word_embeddings` inside the function
5. `input_` is passed directly when calling the original forward
6. `emb.word_embeddings.forward` is used for the patch and restore operations (Modifications 6 and 7)
7. The `self.language_model.embedding` instance is passed when calling `_patch_word_embeddings` (see the sketch below for the underlying pattern)
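The essence of the change is that the patch now targets the bound `forward` of one specific embedding instance rather than the shared `VocabParallelEmbedding` class, so only that module is affected and the original behavior is restored on exit. The snippet below is a minimal, self-contained sketch of this instance-level patching pattern; `patch_instance_forward` and the toy `nn.Embedding` are illustrative stand-ins, not the actual ms-swift code.

```python
# Minimal sketch of instance-level forward patching (illustrative only).
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def patch_instance_forward(module: nn.Module):
    """Temporarily wrap the forward of a single module instance (not its class)."""
    origin_forward = module.forward  # bound method of this one instance

    def forward(input_):
        # Example pre-processing: clamp negative (padding) ids before the embedding lookup.
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        return origin_forward(input_)

    module.forward = forward  # other instances and the class itself stay untouched
    try:
        yield
    finally:
        module.forward = origin_forward  # always restore the original forward


emb = nn.Embedding(10, 4)
ids = torch.tensor([[1, -1, 3]])
with patch_instance_forward(emb):
    out = emb(ids)  # -1 is masked to 0 inside the patched forward
print(out.shape)  # torch.Size([1, 3, 4])
```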


### Enablement
Add `--enable_msprobe True` to the startup script.

In addition, since msprobe does not support fused computation, you also need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False`.
#### Example
```shell
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
--load Qwen2.5-7B-Instruct-mcore \
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
'AI-ModelScope/alpaca-gpt4-data-en#500' \
'swift/self-cognition#500' \
--tensor_model_parallel_size 2 \
...
--no_bias_dropout_fusion True \
--no_bias_swiglu_fusion True \
--cross_entropy_loss_fusion False \
--enable_msprobe True
```
132 changes: 132 additions & 0 deletions docs/source_en/Megatron-SWIFT/Ascend.md
@@ -41,3 +41,135 @@ while iteration < args.train_iters:
    ...
prof.stop()
```

## NPU Accuracy Data Collection
### Configuration

Modify `dump_path`, `level`, and other configuration items in the msprobe_config.json file under the ms-swift directory as needed.
More configurations can be found in [Configuration Examples](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_examples.md) and [Configuration File Introduction](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_introduct.md).
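If you prefer to adjust these fields from a script rather than by hand, the snippet below is a minimal, hedged sketch: it assumes only the field names present in the msprobe_config.json added by this PR, and the dump directory value is an example.

```python
# Illustrative sketch: tweak dump_path / level in msprobe_config.json programmatically.
import json
from pathlib import Path

config_file = Path('msprobe_config.json')  # relative to the ms-swift repository root
config = json.loads(config_file.read_text())

config['dump_path'] = './dump_path'  # directory where dump data will be written (example value)
config['level'] = 'mix'              # dump granularity; see the msprobe docs for available levels

config_file.write_text(json.dumps(config, indent=4, ensure_ascii=False) + '\n')
```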


### Code Modification
To support accuracy debugging with the msprobe tool, we need to modify the `_patch_word_embeddings` function in the `swift/megatron/model/mm_gpt_model.py` file. The main changes are to adjust the function parameters and internal implementation logic so that it can correctly patch the embedding layer.

The specific modification content is as follows:

Before modification:
```python
def _patch_word_embeddings(self, kwargs):
    origin_forward = VocabParallelEmbedding.forward

    def forward(_self, input_):
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(_self, input_)
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    VocabParallelEmbedding.forward = forward
    try:
        yield
    finally:
        VocabParallelEmbedding.forward = origin_forward

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs):
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

After modification:
```python
def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
    origin_forward = emb.word_embeddings.forward  # Modification 2

    def forward(input_):  # Modification 3
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        _self = emb.word_embeddings  # Modification 4
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(input_)  # Modification 5
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    emb.word_embeddings.forward = forward  # Modification 6
    try:
        yield
    finally:
        emb.word_embeddings.forward = origin_forward  # Modification 7

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

Major changes include:
1. The `_patch_word_embeddings` method adds an `emb` parameter to receive the embedding module instance
2. Directly obtain `emb.word_embeddings.forward` instead of `VocabParallelEmbedding.forward`
3. The internal `forward` function signature changed from `(_self, input_)` to `(input_)`
4. Get `_self` through `emb.word_embeddings` inside the function
5. Pass `input_` directly when calling the original forward
6. Use `emb.word_embeddings.forward` for replacement and recovery operations (Modifications 6, 7)
7. Pass the `self.language_model.embedding` instance when calling `_patch_word_embeddings` (see the sketch below for the underlying pattern)
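The essence of the change is that the patch now targets the bound `forward` of one specific embedding instance rather than the shared `VocabParallelEmbedding` class, so only that module is affected and the original behavior is restored on exit. The snippet below is a minimal, self-contained sketch of this instance-level patching pattern; `patch_instance_forward` and the toy `nn.Embedding` are illustrative stand-ins, not the actual ms-swift code.

```python
# Minimal sketch of instance-level forward patching (illustrative only).
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def patch_instance_forward(module: nn.Module):
    """Temporarily wrap the forward of a single module instance (not its class)."""
    origin_forward = module.forward  # bound method of this one instance

    def forward(input_):
        # Example pre-processing: clamp negative (padding) ids before the embedding lookup.
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        return origin_forward(input_)

    module.forward = forward  # other instances and the class itself stay untouched
    try:
        yield
    finally:
        module.forward = origin_forward  # always restore the original forward


emb = nn.Embedding(10, 4)
ids = torch.tensor([[1, -1, 3]])
with patch_instance_forward(emb):
    out = emb(ids)  # -1 is masked to 0 inside the patched forward
print(out.shape)  # torch.Size([1, 3, 4])
```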


### Enablement
Add `--enable_msprobe True` to the startup script.

In addition, since msprobe does not support fused computation, you also need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False`.
#### Example
```shell
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
--load Qwen2.5-7B-Instruct-mcore \
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
'AI-ModelScope/alpaca-gpt4-data-en#500' \
'swift/self-cognition#500' \
--tensor_model_parallel_size 2 \
...
--no_bias_dropout_fusion True \
--no_bias_swiglu_fusion True \
--cross_entropy_loss_fusion False \
--enable_msprobe True
```
15 changes: 15 additions & 0 deletions msprobe_config.json
@@ -0,0 +1,15 @@
{
    "task": "statistics",
    "dump_path": "./dump_path",
    "rank": [],
    "step": [],
    "level": "mix",
    "async_dump": false,
    "statistics": {
        "scope": [],
        "list": [],
        "tensor_list": [],
        "data_mode": ["all"],
        "summary_mode": "statistics"
    }
}
1 change: 1 addition & 0 deletions requirements/framework.txt
@@ -14,6 +14,7 @@ importlib_metadata
jieba
json_repair
matplotlib
mindstudio-probe
modelscope>=1.23
nltk
numpy
4 changes: 4 additions & 0 deletions swift/megatron/argument/megatron_args.py
@@ -359,6 +359,10 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
    # qwen3_vl, qwen3_omni
    mrope_interleaved: Optional[bool] = None

    # dump
    enable_msprobe: bool = False
    msprobe_config: str = './msprobe_config.json'
Reviewer comment (medium): The msprobe_config is defined with a relative path (./msprobe_config.json). This can lead to issues if the working directory changes or if the script is executed from a different location. Consider resolving this path to an absolute path during initialization or making it configurable to ensure robustness.
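One possible way to address this is sketched below; it assumes the argument class has a post-init hook where paths can be normalized, and the helper name and placement are hypothetical rather than part of the PR.

```python
# Illustrative sketch only (not part of the PR): resolve msprobe_config to an
# absolute path once, so a later change of working directory cannot break it.
import os


def resolve_msprobe_config(path: str) -> str:
    abs_path = os.path.abspath(os.path.expanduser(path))
    if not os.path.isfile(abs_path):
        raise FileNotFoundError(f'msprobe config not found: {abs_path}')
    return abs_path


# Hypothetical placement, e.g. in the arguments' __post_init__:
# if self.enable_msprobe:
#     self.msprobe_config = resolve_msprobe_config(self.msprobe_config)
```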


    @staticmethod
    def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]:
        res = {}
15 changes: 13 additions & 2 deletions swift/megatron/trainers/base.py
@@ -522,8 +522,19 @@ def _all_reduce_metric(self,
    def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
                   **kwargs):
        new_data_iterator = self._replace_data_iterator(data_iterator, model)
        return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
                                       config, *args, **kwargs)
        debugger_on = self.args.enable_msprobe
        if debugger_on:
            from msprobe.pytorch import PrecisionDebugger
            debugger = PrecisionDebugger(config_path=self.args.msprobe_config, model=model)
            debugger.start()
        try:
            origin_train_step_out = self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer,
                                                            opt_param_scheduler, config, *args, **kwargs)
        finally:
            if debugger_on:
                debugger.stop()
                debugger.step()
        return origin_train_step_out

    # Code borrowed from NVIDIA/Megatron-LM
    def evaluate(
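For context, the start/stop/step calls above follow msprobe's PrecisionDebugger interface as used in this PR. The snippet below is a minimal, hedged sketch of the same cycle around a plain PyTorch loop; the toy model, optimizer, and iteration count are placeholders rather than the actual Megatron training loop.

```python
# Hedged sketch of the start/stop/step cycle that the patched train_step applies;
# the toy model and loop below are placeholders, not ms-swift code.
import torch
from msprobe.pytorch import PrecisionDebugger

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
debugger = PrecisionDebugger(config_path='./msprobe_config.json', model=model)

for _ in range(3):  # a few toy iterations
    inputs = torch.randn(4, 8)
    debugger.start()                  # begin collection for this step
    try:
        loss = model(inputs).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    finally:
        debugger.stop()               # flush the data collected during the step
        debugger.step()               # advance msprobe's step counter
```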