5 changes: 2 additions & 3 deletions benchmarks/mlperf/README.md
@@ -1,4 +1,3 @@
-
 ## Create TPU VM.
 Follow these [instructions](https://cloud.google.com/tpu/docs/v5e-inference#tpu-vm) to create TPU v5e-8 VM and ssh into the VM

@@ -46,8 +45,8 @@ accelerate==0.21.0
 ```
 export DATA_DISK_DIR=~/loadgen_run_data
 mkdir -p ${DATA_DISK_DIR}
-gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl ${DATA_DISK_DIR}/processed-calibration-data.pkl
-gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl ${DATA_DISK_DIR}/processed-data.pkl
+gcloud storage cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl ${DATA_DISK_DIR}/processed-calibration-data.pkl
+gcloud storage cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl ${DATA_DISK_DIR}/processed-data.pkl
 ```

 ## Download Maxtext and Jetstream
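For single-object copies like the two above, the migration is a drop-in rename: `gcloud storage cp` takes the same source and destination arguments as `gsutil cp`, and it parallelizes transfers by default, so gsutil's `-m` flag has no counterpart to carry over. A minimal sketch of the equivalence, reusing the bucket path and `DATA_DISK_DIR` from the diff:

```bash
# Single-object copy: identical source/destination arguments in both CLIs;
# gcloud storage parallelizes by default, so no -m flag is needed.
export DATA_DISK_DIR=~/loadgen_run_data
mkdir -p "${DATA_DISK_DIR}"
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \
  "${DATA_DISK_DIR}/processed-data.pkl"
```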
4 changes: 2 additions & 2 deletions benchmarks/mlperf/scripts/download_loadgen_data.sh
@@ -2,8 +2,8 @@ export DATA_DISK_DIR=/loadgen_run_data

 mkdir -p ${DATA_DISK_DIR}
 cd ${DATA_DISK_DIR}
-gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl .
+gcloud storage cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl .
 mv open_orca_gpt4_tokenized_llama.calibration_1000.pkl processed-calibration-data.pkl

-gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl .
+gcloud storage cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl .
 mv open_orca_gpt4_tokenized_llama.sampled_24576.pkl processed-data.pkl
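A possible further simplification, not part of this PR: `gcloud storage cp`, like `gsutil cp`, accepts an explicit destination filename, so each cp-then-mv pair above could collapse into one command. A sketch using the same paths the script uses:

```bash
# Copy straight to the final name instead of copying into . and renaming.
export DATA_DISK_DIR=/loadgen_run_data
mkdir -p "${DATA_DISK_DIR}"
gcloud storage cp \
  gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \
  "${DATA_DISK_DIR}/processed-calibration-data.pkl"
```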
2 changes: 1 addition & 1 deletion benchmarks/mlperf/scripts/init.sh
@@ -2,6 +2,6 @@ git clone https://github.com/google/jax.git
 cd jax
 git reset 44359cb30ab5cdbe97e6b78c2c64fe9f8add29ca --hard
 pip install -e .
-gsutil cp gs://zhihaoshan-maxtext-profiling/jax_proxy_stream_buffer/jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64-mlperf_version_3.whl .
+gcloud storage cp gs://zhihaoshan-maxtext-profiling/jax_proxy_stream_buffer/jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64-mlperf_version_3.whl .
 mv jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64-mlperf_version_3.whl jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64.whl
 pip install jaxlib-0.4.31.dev20240719-cp310-cp310-manylinux2014_x86_64.whl
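Because init.sh installs an editable jax at a pinned commit and then overwrites jaxlib with a renamed custom wheel, a quick post-install check can confirm the pieces line up. This is a sketch, not part of the PR:

```bash
# Confirm the editable jax import works and the custom jaxlib wheel is the
# one actually installed (its version string comes from the wheel filename).
python3 -c "import jax; print(jax.__version__)"
pip show jaxlib | head -2                        # Name / Version of the installed wheel
python3 -c "import jax; print(jax.devices())"    # should list TPU devices on a TPU VM
```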
4 changes: 2 additions & 2 deletions benchmarks/mlperf/scripts/launch_microbenchmark.sh
@@ -10,7 +10,7 @@ echo "inference_microbenchmark_prefill_lengths: ${inference_microbenchmark_prefi
 cd /maxtext
 export run_dir=${base_output_dir}/microbenchmark/${run_name}/${experiment_time}/
 echo "run_dir: ${run_dir}"
-gsutil cp ${config_file_path} ${run_dir}/
+gcloud storage cp ${config_file_path} ${run_dir}/

 python3 -m MaxText.inference_microbenchmark \
 ${config_file_path} \

@@ -43,4 +43,4 @@ python3 -m MaxText.inference_microbenchmark
 checkpoint_is_quantized=${checkpoint_is_quantized} \
 compute_axis_order=${compute_axis_order} \
 prefill_cache_axis_order=${prefill_cache_axis_order} \
-ar_cache_axis_order=${ar_cache_axis_order} 2>&1 | tee results.log && gsutil mv results.log ${run_dir}/
+ar_cache_axis_order=${ar_cache_axis_order} 2>&1 | tee results.log && gcloud storage mv results.log ${run_dir}/
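`gsutil mv` maps one-to-one to `gcloud storage mv` (upload, then delete the local file). Separately, a shell caveat that predates this PR: without `pipefail`, a pipeline's exit status is that of its last command (`tee`), so the `&&` upload runs even when the benchmark fails. A stricter variant is sketched below, assuming the same variables the script already sets:

```bash
# Sketch only: with pipefail, a python3 failure fails the whole pipeline,
# so results.log is moved to ${run_dir} only after a successful run.
set -o pipefail
python3 -m MaxText.inference_microbenchmark "${config_file_path}" 2>&1 | tee results.log \
  && gcloud storage mv results.log "${run_dir}/"
```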
4 changes: 2 additions & 2 deletions docs/online-inference-with-maxtext-engine.md
@@ -52,7 +52,7 @@ You can run the JetStream MaxText Server with Gemma and Llama2 models. This sect

 * You can download a [Gemma checkpoint from Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText/variations/7b).
 * After downloading orbax Gemma checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively.
-  * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
+  * `gcloud storage cp --recursive ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
 * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
 * Then, using the following command to convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint.

@@ -70,7 +70,7 @@ Note: For more information about the Gemma model and checkpoints, see [About Gem

 * You can use a Llama2 checkpoint you have generated or one from [the open source community](https://llama.meta.com/llama-downloads/).
 * After downloading PyTorch checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`. You should also set two more paths `$MAXTEXT_BUCKET_SCANNED` and `$MAXTEXT_BUCKET_UNSCANNED` that point to the locations of the maxtext checkpoints for the scanned and unscanned (inference-optimized) versions, respectively.
-  * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
+  * `gcloud storage cp --recursive ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
 * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
 * Then, using the following command to convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint.
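The recursive copy is the one place in this PR where the flags change rather than just the binary name: `gsutil -m cp -r` becomes `gcloud storage cp --recursive` (the short `-r` also works), and `-m` is dropped because `gcloud storage` parallelizes by default. Side by side, using the doc's own variables:

```bash
# Old form: -m enables multi-threaded transfers, -r recurses into the source.
gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}

# New form: --recursive (or -r) recurses; parallel transfers are the default,
# so there is no -m flag to carry over.
gcloud storage cp --recursive ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
```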
2 changes: 1 addition & 1 deletion jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh
@@ -42,7 +42,7 @@ export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs
 export DATASET_PATH=gs://${USER}-maxtext-dataset

 # Prepare C4 dataset for fine tuning: https://github.com/allenai/allennlp/discussions/5056
-sudo gsutil -u $4 -m cp 'gs://allennlp-tensorflow-datasets/c4/en/3.0.1/*' ${DATASET_PATH}/c4/en/3.0.1/
+sudo gcloud storage cp 'gs://allennlp-tensorflow-datasets/c4/en/3.0.1/*' ${DATASET_PATH}/c4/en/3.0.1/

 # We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
 export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
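One behavioral difference worth flagging in this hunk: the old command's `-u $4` billed the download to a caller-supplied project, which gsutil uses for requester-pays buckets, and the replacement drops the flag. If `gs://allennlp-tensorflow-datasets` still requires requester-pays billing, the `gcloud storage` counterpart is `--billing-project`; a sketch, assuming `$4` remains the billing project ID:

```bash
# Sketch: carry the requester-pays billing project over from `gsutil -u $4`.
sudo gcloud storage cp --billing-project=$4 \
  'gs://allennlp-tensorflow-datasets/c4/en/3.0.1/*' \
  "${DATASET_PATH}/c4/en/3.0.1/"
```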