Merged

Changes from all commits (33 commits)
e6b0db1
AMMO integration with Llama2 PTQ example and tests
janekl Feb 16, 2024
41b3f6d
Jenkins megatron_llama_quantization.py test setup
janekl Feb 16, 2024
71d9529
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2024
9c2d7f4
License headers
janekl Feb 19, 2024
ae88e47
Add AMMO to requirements_nlp.txt with --extra-index-url for pip install
janekl Feb 19, 2024
6ca03d4
Bump AMMO version to latest
janekl Feb 22, 2024
5170db5
Guards workaround on spec definition
janekl Feb 22, 2024
543dea1
Save artifacts and tokenizer config at once
janekl Feb 29, 2024
785167f
Extend nemo.utils package with new tools
janekl Feb 29, 2024
0f29ac4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 29, 2024
acfb441
Reorganize & reformat
janekl Mar 4, 2024
e03cb87
Tests for FP8 and INT4 AWQ
janekl Mar 4, 2024
4b49ec6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
90988b1
Add load_config helper function
janekl Mar 4, 2024
d115bef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
6fcbcd0
Unused import removal
janekl Mar 4, 2024
a5e818f
Fix FP8 Jenkins test
janekl Mar 4, 2024
12f3717
Fix TP=2 test cont'd: no need to use mpirun
janekl Mar 5, 2024
a96be0f
Allow for patches in AMMO versioning
janekl Mar 5, 2024
c99b992
Drop AWQ test for now (need to debug)
janekl Mar 5, 2024
2eff82f
Allow for patches in AMMO versioning cont'd
janekl Mar 5, 2024
f4f6f37
Merge branch 'main' into jlasek/ammo_integration
janekl Mar 6, 2024
739fe30
Use AMMO spec from MCore as it has been published
janekl Mar 6, 2024
ae0498d
Make AMMO optional dependency and properly import guard it
janekl Mar 8, 2024
b56ff60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
01f215d
Add Llama2 AWQ test and update some paths
janekl Mar 12, 2024
fe1eeba
Enable specifying quantization.algorithm=null for baseline accuracy c…
janekl Mar 12, 2024
3a7f07e
Enable exporting qnemo tarball or just to a directory
janekl Mar 12, 2024
ac52816
Drop AWQ testing for now
janekl Mar 12, 2024
81e8e07
Test case for export.inference_tensor_parallel=2
janekl Mar 12, 2024
bf03390
Flag to export TRT-LLM config.json
janekl Mar 12, 2024
d38af45
Merge branch 'main' into jlasek/ammo_integration
janekl Mar 12, 2024
d05f8f5
Merge branch 'main' into jlasek/ammo_integration
janekl Mar 13, 2024
4 changes: 3 additions & 1 deletion Dockerfile
@@ -66,7 +66,7 @@ WORKDIR /workspace/
# We leave it here in case we need to work off of a specific commit in main
RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout ad53b1e38689a0ceed75ade7821f4e6c7554abb4 && \
git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \
pip install .

# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771
@@ -132,6 +132,8 @@ RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-chec
RUN pip install flash-attn
# install numba for latest containers
RUN pip install numba>=0.57.1
# install ammo
RUN pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir

# copy nemo source into a scratch image
FROM scratch as nemo-src
65 changes: 62 additions & 3 deletions Jenkinsfile
@@ -91,11 +91,17 @@ pipeline {
steps {
sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout 5f9c870f9f24b482509699d206a9dbb00958f6fc && \
git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \
pip install .'
}
}

stage('AMMO installation') {
steps {
sh 'pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir'
}
}

stage('PyTorch Lightning version') {
steps {
sh 'python -c "import pytorch_lightning; print(pytorch_lightning.__version__)"'
@@ -390,6 +396,12 @@ pipeline {
}
}

stage('Setup test data and models') {
steps {
sh 'python -m tests.setup --save_dir /home/TestData/nlp'
}
}

// TODO: this requires TE >= v0.11 which is not available in 23.06.
// please uncomment this test once mcore CI is ready.
stage('L2: Community LLM Checkpoints tests') {
@@ -405,9 +417,8 @@
steps {
sh 'CUDA_VISIBLE_DEVICES=0 python scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py \
--in-file=/home/TestData/nlp/megatron_llama/llama-ci-hf \
--out-file=/home/TestData/nlp/megatron_llama/ci.nemo \
--out-file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
--precision=16'
sh 'rm -f /home/TestData/nlp/megatron_llama/ci.nemo'
}
}
stage('StarCoder') {
@@ -439,6 +450,54 @@
}
}

stage('L2: Nemo PTQ') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
parallel {
stage('Llama2 - Export Only') {
steps {
sh 'python examples/nlp/language_modeling/megatron_llama_quantization.py \
model_file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
quantization.algorithm=null \
model_save=/home/TestData/nlp/megatron_llama/ci_baseline'
sh 'rm -rf /home/TestData/nlp/megatron_llama/ci_baseline'
}
}
stage('Llama2 - INT8 SQ') {
steps {
sh 'python examples/nlp/language_modeling/megatron_llama_quantization.py \
model_file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \
quantization.algorithm=int8_sq \
quantization.num_calib_size=8 \
inference.batch_size=2 \
model_save=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo'
sh 'rm -f /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo'
}
}
stage('Llama2 - FP8') {
steps {
sh 'python examples/nlp/language_modeling/megatron_llama_quantization.py \
model_file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
tensor_model_parallel_size=2 \
trainer.devices=2 \
quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \
quantization.algorithm=fp8 \
quantization.num_calib_size=8 \
inference.batch_size=2 \
export.inference_tensor_parallel=2 \
model_save=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo'
sh 'rm -f /home/TestData/nlp/megatron_llama/ci_fp8.qnemo'
}
}
}
}

stage('L2: ASR dev run') {
when {
anyOf {
38 changes: 38 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml
@@ -0,0 +1,38 @@
inference:
  greedy: false # Whether or not to use sampling; use greedy decoding otherwise
  top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  temperature: 1.0 # sampling temperature
  add_BOS: true # add the BOS token at the beginning of the prompt
  tokens_to_generate: 30 # The maximum length of the sequence to be generated.
  all_probs: false # whether to return the log prob for all the tokens in the vocab
  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
  compute_logprob: false # a flag to compute the logprob of all the input text; a very special case of running inference, default False
batch_size: 64 # batch size for inference
max_context_length: 512 # max length of the context, input sequence will be truncated if it is longer than this

trainer:
devices: 1
num_nodes: 1
accelerator: gpu
logger: false # logger provided by exp_manager
precision: bf16 # 16, 32, or bf16
enable_checkpointing: false

quantization:
quantize_bmm1: false
algorithm: fp8 # int8_sq, fp8, int8, int4_awq, null
calib_dataset: cnn_dailymail # pileval, wikitext, cnn_dailymail
num_calib_size: 512 # number of samples used for calibration

export:
decoder_type: llama # gptnext, gpt2, llama
inference_tensor_parallel: 1 # Default using 1 TP for inference
dtype: 16 # Default precision data type
export_tensorrt_llm_config: true # export config to build TRT-LLM engine directly

model_file: llama2-7b-fp16.nemo # Nemo file path
model_save: llama2-7b-fp8.qnemo # Path where the quantized model will be saved
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
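
For orientation, a minimal sketch (not part of the diff) of loading this new config and overriding a field programmatically, mirroring the Hydra CLI overrides used in the Jenkins stages above:

```python
from omegaconf import OmegaConf

# Sketch only: load the PTQ config added above and override the algorithm,
# just as a CLI override like quantization.algorithm=int8_sq would.
cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml")
cfg.quantization.algorithm = "int8_sq"  # or "fp8", "int4_awq", or None for a baseline export
print(OmegaConf.to_yaml(cfg.quantization))
```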
93 changes: 93 additions & 0 deletions examples/nlp/language_modeling/megatron_llama_quantization.py
@@ -0,0 +1,93 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.multiprocessing as mp
from datasets import load_dataset

from nemo.core.config import hydra_runner
from nemo.export.quantize import Quantizer

mp.set_start_method("spawn", force=True)

"""
NeMo quantization example script.

Please consult the nemo.export.quantize.Quantizer class and the
examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml config for the available quantization
methods and supported models, as well as for how to set up data and inference for calibration (the defaults are recommended).

Example usage:
```
python examples/nlp/language_modeling/megatron_llama_quantization.py \
model_file=llama2-7b-fp16.nemo \
model_save=llama2-7b-fp8.qnemo \
quantization.algorithm=fp8 \
export.decoder_type=llama \
export.inference_tensor_parallel=1
```
"""


def get_calib_dataloader(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512):
if data == "pileval":
dataset = load_dataset("json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train")
Collaborator comment:

This link doesn't work. This one should be okay: https://huggingface.co/datasets/monology/pile-uncopyrighted
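
A hedged sketch of the suggested replacement (not part of the diff); it assumes the mirrored dataset exposes the Pile's "text" column like the original:

```python
from datasets import load_dataset

# Hypothetical drop-in for the broken the-eye.eu URL, per the suggestion above.
dataset = load_dataset("monology/pile-uncopyrighted", split="train")
```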

text_column = "text"
elif data == "wikitext":
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
text_column = "text"
elif data == "cnn_dailymail":
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
text_column = "article"
else:
# Assume a local JSON dataset with a column named "text"
dataset = load_dataset("json", data_files=data, split="train")
text_column = "text"
calib_size = max(min(len(dataset), calib_size), batch_size)
for i in range(calib_size // batch_size):
batch = dataset[i * batch_size : (i + 1) * batch_size][text_column]
for j in range(len(batch)):
batch[j] = batch[j][:max_sequence_length]
yield batch


@hydra_runner(config_path="conf", config_name="megatron_llama_quantization")
def main(cfg) -> None:
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for the inference.")

quantizer = Quantizer(cfg.quantization, cfg.inference, cfg.export, cfg.trainer)

    # The quantization algorithm can be set to None. This is useful for baseline-precision
    # accuracy validation. In this case, only the weight export step is performed:
if cfg.quantization.algorithm is not None:
dataloader = get_calib_dataloader(
cfg.quantization.calib_dataset,
cfg.inference.batch_size,
cfg.quantization.num_calib_size,
cfg.inference.max_context_length,
)
        dataloader = list(dataloader)  # materialize all calibration batches up front
else:
dataloader = None

model = quantizer.quantize(
cfg.model_file, dataloader, cfg.tensor_model_parallel_size, cfg.pipeline_model_parallel_size
)

quantizer.export(model, cfg.model_save)


if __name__ == '__main__':
main()
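
As a rough illustration (not part of the PR), the calibration dataloader defined in this script yields plain lists of strings truncated to max_sequence_length characters; the dataset and sizes below are illustrative only:

```python
# Sketch: draw two small calibration batches from the default dataset.
for batch in get_calib_dataloader(data="cnn_dailymail", batch_size=2, calib_size=4):
    assert len(batch) == 2
    assert all(len(text) <= 512 for text in batch)
```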
nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -91,6 +91,7 @@
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.deploy.gpt.model_specs import get_gpt_layer_ammo_spec
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
@@ -140,6 +141,7 @@ def get_specs(spec_name, num_experts=None):
"": get_gpt_layer_with_transformer_engine_spec(num_experts),
"megatron_falcon_gpt": get_falcon_layer_spec(),
"megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
"ammo": get_gpt_layer_ammo_spec(),
}
if spec_name not in name_spec_dict:
raise ValueError(f"Spec name '{spec_name}' is not recognized.")
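
For context, a minimal usage sketch (not in the diff) of how the new registry key resolves, assuming a megatron-core build that ships get_gpt_layer_ammo_spec:

```python
# Hypothetical call: requesting the AMMO-compatible GPT layer spec through
# the registry extended above dispatches to get_gpt_layer_ammo_spec().
spec = get_specs("ammo")
```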
13 changes: 13 additions & 0 deletions nemo/export/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 changes: 15 additions & 0 deletions nemo/export/quantize/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .quantizer import Quantizer