25 changes: 25 additions & 0 deletions README.md
100644 → 100755
@@ -140,6 +140,31 @@ print(results)
- FireRedASR-AED supports audio input up to 60s. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors.
- FireRedASR-LLM supports audio input up to 30s. The behavior for longer input is currently unknown.
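For input longer than these limits, one workaround is to split the audio into chunks before transcription and join the transcripts afterwards. A minimal sketch (`split_wav` is a hypothetical helper, assuming 16 kHz mono input; naive fixed-length cuts can split words mid-utterance):
```python
import soundfile as sf

def split_wav(wav_path, chunk_s=55, sr=16000):
    """Naively split a wav into <= chunk_s second pieces; returns the new paths."""
    audio, in_sr = sf.read(wav_path)
    assert in_sr == sr, "resample to 16 kHz first"
    step = chunk_s * sr
    paths = []
    for i in range(0, len(audio), step):
        part = f"{wav_path}.part{i // step}.wav"
        sf.write(part, audio[i:i + step], sr)
        paths.append(part)
    return paths
```
Each resulting path can then be fed to the model as a normal input file.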

### FireRedASR-AED Optimization with ROCm
1. Build the Docker image with `docker/Dockerfile.rocm` to set up the environment
```
docker build --network=host -f docker/Dockerfile.rocm -t rocm/firered-asr-opt .
```

2. Launch docker container
```
docker run -it --ipc=host --network=host --privileged --security-opt seccomp=unconfined --cap-add=CAP_SYS_ADMIN --cap-add=SYS_PTRACE --device=/dev/kfd --device=/dev/dri --device=/dev/mem rocm/firered-asr-opt
```

3. Run performance test with native MHA (baseline)
```
ATTENTION_BACKEND="NATIVE" python examples/benchmark_firered_asr.py
```

4. Run performance test with MHA using torch SDPA
```
ATTENTION_BACKEND="SDPA" python examples/benchmark_firered_asr.py
```

5. Run performance test with MHA using xFormers
```
ATTENTION_BACKEND="XFORMERS" python examples/benchmark_firered_asr.py
```
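The benchmark script also accepts a few flags (see `examples/benchmark_firered_asr.py`): `-b/--batch_sizes` takes one or more batch sizes, `-a/--audio_dir` points at a directory of `.wav` files, `-m/--model_path` selects the model directory, and `-p/--profile` enables the torch profiler. For example, to sweep several batch sizes with the xFormers backend:
```
ATTENTION_BACKEND="XFORMERS" python examples/benchmark_firered_asr.py -b 1 4 8 16 -a examples/wav
```
In the reported numbers, RPS is utterances processed per second and RTF is compute time divided by input audio duration, so RTF < 1 means faster than real time. Inside the container, `python -m xformers.info` is a quick way to confirm which attention kernels the xFormers build exposes.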

## Acknowledgements
Thanks to the following open-source works:
28 changes: 28 additions & 0 deletions docker/Dockerfile.rocm
@@ -0,0 +1,28 @@
ARG BASE_IMAGE=rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0
FROM ${BASE_IMAGE} AS base

ARG PYTORCH_ROCM_ARCH=gfx942
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV HIP_ARCHITECTURE=${PYTORCH_ROCM_ARCH}
ENV ROCM_PATH=/opt/rocm
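# Build flag for the ROCm xformers fork: enable the Composable Kernel (CK) flash-attention path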
ENV XFORMERS_CK_FLASH_ATTN=1
ARG PYTHON_VERSION=3.12
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /root

RUN set -ex && usermod -a -G video $(whoami)

RUN python3 -m pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython

RUN git clone https://github.com/FireRedTeam/FireRedASR /root/FireRedASR \
&& cd /root/FireRedASR \
&& python3 -m pip install -r requirements.txt

# Build the ROCm xformers fork from source, pinned to a known-good commit
RUN git clone https://github.com/ROCm/xformers.git /root/xformers \
&& cd /root/xformers \
&& git checkout 5f0419a \
&& git submodule update --init --recursive \
&& PYTORCH_ROCM_ARCH=$PYTORCH_ROCM_ARCH HIP_ARCHITECTURE=$HIP_ARCHITECTURE XFORMERS_CK_FLASH_ATTN=$XFORMERS_CK_FLASH_ATTN python3 setup.py install

WORKDIR /root/FireRedASR
187 changes: 187 additions & 0 deletions examples/benchmark_firered_asr.py
@@ -0,0 +1,187 @@
import argparse
import json
import os
import time

import librosa
import numpy as np
import soundfile as sf
import torch
from torch.profiler import ProfilerActivity
from torch.profiler import profile as torch_profiler
from tqdm import tqdm

from fireredasr.models.fireredasr import FireRedAsr


ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "XFORMERS")  # Options: "NATIVE", "SDPA", "XFORMERS"
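# Used below to tag output files; the patched attention layers in fireredasr are
# expected to read the same variable to select their implementation.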


def load_model(model_path="pretrained_models/FireRedASR-AED-L"):
print("==========Load model:========")
model = FireRedAsr.from_pretrained("aed", model_path)
model.model.to(torch.float16)
model.model.cuda()
model.model.eval()

return model

def load_audio(wav_path):
print("==========load audio:=========")
    audio, sr = sf.read(wav_path, dtype=np.float32)
print(len(audio), audio.dtype)
if sr != 16000:
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
return audio

def run(model, batch_wav_path, warmup=2, trials=10, enable_profile=False, offset=0):
    batch_size = len(batch_wav_path)
    batch_uttid = list(range(offset, offset + batch_size))
    # Shared decode arguments; repetition_penalty, decode_min_len and temperature
    # exist as well but apply only to the LLM variant.
    decode_args = dict(beam_size=3, nbest=1, decode_max_len=0,
                       softmax_smoothing=1.25, length_penalty=0.6, eos_penalty=1.0)

preprocess_start = time.time()
feats, lengths, durs = model.feat_extractor(batch_wav_path)
feats = feats.to(torch.float16)
feats, lengths = feats.cuda(), lengths.cuda()
preprocess_dur = time.time() - preprocess_start
print(f"preprocess duration: {preprocess_dur:.3f} s")
total_dur = sum(durs)
avg_audio_dur_per_sample = total_dur / len(durs)
print(f"total input audio duration: {total_dur:.3f} s, avg input audio duration per sample: {avg_audio_dur_per_sample:.3f} s")
# Warmup
print("==========warmup========")
    for _ in range(warmup):
        with torch.no_grad():
            _ = model.model.transcribe(feats, lengths, **decode_args)

# Benchmark
print("==========start benchmark========")
total_time = 0
results = []
rtf_list = []
    if enable_profile:
        trials = 1  # warmup already ran above; one profiled trial is enough
for _ in tqdm(range(trials)):
start = time.time()
with torch.no_grad():
if enable_profile:
with torch_profiler(
activities=[ProfilerActivity.CPU,
ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
profile_memory=False) as prof:
                        hyps = model.model.transcribe(feats, lengths, **decode_args)
                print(prof.key_averages().table(sort_by="cuda_time_total"))
                prof.export_chrome_trace(f"firered_asr_profile_bs{batch_size}_{ATTENTION_BACKEND}.json")
            else:
                hyps = model.model.transcribe(feats, lengths, **decode_args)
        elapsed = time.time() - start
        total_time += elapsed

        # Real-time factor: compute time / audio duration (lower is better)
        rtf = elapsed / total_dur if total_dur > 0 else 0
for uttid, wav, hyp in zip(batch_uttid, batch_wav_path, hyps):
hyp = hyp[0] # only return 1-best
hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
text = model.tokenizer.detokenize(hyp_ids)
results.append({"uttid": uttid, "text": text, "wav": wav,
"rtf": f"{rtf:.4f}"})
rtf_list.append(rtf)

    avg_latency = total_time / trials
    rps = batch_size / avg_latency
    # Only print the last run's results, for debugging
    print("Printing only the last run's results for debugging...")
    for res in results[-batch_size:]:
        print(res)
    avg_rtf = sum(rtf_list) / len(rtf_list)
    print(f"Finished benchmark for batch size: {batch_size}, average latency: {avg_latency:.3f} s | RPS: {rps:.2f}, avg RTF: {avg_rtf:.3f}")

    return rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, results[-batch_size:]

def benchmark(model, audio_dir, batch, warmup=2, trials=10, enable_profile=False):
    # Get the list of .wav files (case-insensitive)

# Collect file paths and durations
file_durations = []
input_wav_path_list = [os.path.join(audio_dir, f)
for f in os.listdir(audio_dir)
if f.lower().endswith('.wav')]

for file_path in input_wav_path_list:
try:
y, sr = librosa.load(file_path, sr=None) # keep original sampling rate
duration = len(y) / sr
file_durations.append((file_path, duration))
except Exception as e:
print(f"Error processing {file_path}: {e}")

# Sort by duration (longest first)
file_durations.sort(key=lambda x: x[1], reverse=True)
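    # (batching clips of similar length keeps padding overhead low within each batch)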
    # Optional: print every audio file with its duration
for i, (fp, dur) in enumerate(file_durations, start=1):
print(f"{i}. {fp} - {dur:.2f} sec")

dataset_size = len(input_wav_path_list)
# Loop through data in batches
benchmark_results = []
e2e_start = time.time()
for start in range(0, dataset_size - dataset_size % batch, batch):
batch_wav_path = [path for path, _ in file_durations[start:start + batch]]
print(f"Processing {batch} batched data from index {start} to {start + batch-1}")
rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, batch_wav_path, warmpup, trials, enable_profile, offset=start)
benchmark_results.append((batch, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results))

    # Process any remaining files
    remainder = dataset_size % batch
    if remainder:
        start = dataset_size - remainder  # avoids a NameError when dataset_size < batch
        last_batch_wav_path = [path for path, _ in file_durations[-remainder:]]
        print(f"Processing {remainder} remaining files: {last_batch_wav_path}")
        rps, avg_latency, avg_rtf, avg_audio_dur_per_sample, model_results = run(model, last_batch_wav_path, warmup, trials, enable_profile, offset=start)
benchmark_results.append((remainder, avg_audio_dur_per_sample, avg_latency, rps, avg_rtf, model_results))
e2e_duration = time.time() - e2e_start

return benchmark_results, e2e_duration

if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='Benchmark scripts for FireRedASR', usage='%(prog)s [options]')
    parser.add_argument('-b', '--batch_sizes', type=int, nargs='+', default=[1], help='List of batch sizes for performance evaluation')
parser.add_argument('-m', '--model_path', type=str, default="pretrained_models/FireRedASR-AED-L", help='Path to model directory')
parser.add_argument('-a', '--audio_dir', type=str, default='examples/wav', help="Path to input audio directory")
parser.add_argument('-d', '--device', type=str, default='cuda', help="Target inference device")
parser.add_argument('-p', '--profile', action='store_true', help='Enable torch profiler')
args = parser.parse_args()
audio_dir = args.audio_dir
model_path = args.model_path
device = args.device
enable_profile = args.profile
batch_sizes = args.batch_sizes # [1, 4, 8, 16, 32, 64, 128, 256]
model = load_model(model_path)

if enable_profile:
benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=1, enable_profile=enable_profile)
else:
for batch in batch_sizes:
print(f"*************************** batch size {batch} ***************************")
benchmark_results, e2e_duration = benchmark(model, audio_dir, batch=batch, enable_profile=enable_profile)

print(f"\nbatch size: {batch}, e2e latency: {e2e_duration} s")
save_results = []
save_path = f"ATTENTION_BACKEND_{ATTENTION_BACKEND}_bs_{batch}_output.json"
            for res in benchmark_results:
                print(res[5])
                save_results += res[5]
for res in benchmark_results:
print(f"batch size: {res[0]}, avg audio duration per sample: {res[1]:.3f} s, avg inference latency {res[2]:.3f} s | RPS: {res[3]:.2f}, avg RTF: {res[4]:.3f}")
with open(save_path, "w", encoding="utf-8") as final:
json.dump(save_results,
final,
indent=2,
ensure_ascii=False, # Keep non-ASCII characters intact
default=lambda x: list(x) if isinstance(x, tuple) else str(x)
)
print(f"Performance results written to {save_path}")
2 changes: 1 addition & 1 deletion fireredasr/models/fireredasr.py
@@ -107,7 +107,7 @@ def transcribe(self, batch_uttid, batch_wav_path, args={}):


def load_fireredasr_aed_model(model_path):
-    package = torch.load(model_path, map_location=lambda storage, loc: storage)
+    package = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
print("model args:", package["args"])
model = FireRedAsrAed.from_args(package["args"])
model.load_state_dict(package["model_state_dict"], strict=True)
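(PyTorch 2.6 changed the default of `torch.load` to `weights_only=True`; this checkpoint pickles non-tensor objects such as its `args` alongside the state dict, so it must be loaded with `weights_only=False`.)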