AMMO Integration with Llama2 Post-Training Quantization Example and Tests #8444

Merged

Changes from all commits (33 commits)
e6b0db1  AMMO integration with Llama2 PTQ example and tests (janekl)
41b3f6d  Jenkins megatron_llama_quantization.py test setup (janekl)
71d9529  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9c2d7f4  License headers (janekl)
ae88e47  Add AMMO to requirements_nlp.txt with --extra-index-url for pip install (janekl)
6ca03d4  Bump AMMO version to latest (janekl)
5170db5  Guards workaround on spec definition (janekl)
543dea1  Save artifacts and tokenizer config at once (janekl)
785167f  Extend nemo.utils package with new tools (janekl)
0f29ac4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
acfb441  Reorganize & reformat (janekl)
e03cb87  Tests for FP8 and INT4 AWQ (janekl)
4b49ec6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
90988b1  Add load_config helper function (janekl)
d115bef  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
6fcbcd0  Unused import removal (janekl)
a5e818f  Fix FP8 Jenkins test (janekl)
12f3717  Fix TP=2 test cont'd: no need to use mpirun (janekl)
a96be0f  Allow for patches in AMMO versioning (janekl)
c99b992  Drop AWQ test for now (need to debug) (janekl)
2eff82f  Allow for patches in AMMO versioning cont'd (janekl)
f4f6f37  Merge branch 'main' into jlasek/ammo_integration (janekl)
739fe30  Use AMMO spec from MCore as it has been published (janekl)
ae0498d  Make AMMO optional dependency and properly import guard it (janekl)
b56ff60  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
01f215d  Add Llama2 AWQ test and update some paths (janekl)
fe1eeba  Enable specifying quantization.algorithm=null for baseline accuracy c… (janekl)
3a7f07e  Enable exporting qnemo tarball or just to a directory (janekl)
ac52816  Drop AWQ testing for now (janekl)
81e8e07  Test case for export.inference_tensor_parallel=2 (janekl)
bf03390  Flag to export TRT-LLM config.json (janekl)
d38af45  Merge branch 'main' into jlasek/ammo_integration (janekl)
d05f8f5  Merge branch 'main' into jlasek/ammo_integration (janekl)
examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml (new file, 38 additions)

```yaml
inference:
  greedy: false # Whether or not to use sampling; use greedy decoding otherwise
  top_k: 0 # The number of highest-probability vocabulary tokens to keep for top-k filtering.
  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  temperature: 1.0 # sampling temperature
  add_BOS: true # add the BOS token at the beginning of the prompt
  tokens_to_generate: 30 # The maximum length of the sequence to be generated.
  all_probs: false # whether to return the log prob for all the tokens in the vocab
  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
  compute_logprob: false # a flag used to compute the logprob of all the input text, a very special case of running inference, default false
  batch_size: 64 # batch size for inference
  max_context_length: 512 # max length of the context; the input sequence will be truncated if it is longer than this

trainer:
  devices: 1
  num_nodes: 1
  accelerator: gpu
  logger: false # logger provided by exp_manager
  precision: bf16 # 16, 32, or bf16
  enable_checkpointing: false

quantization:
  quantize_bmm1: false
  algorithm: fp8 # int8_sq, fp8, int8, int4_awq, null
  calib_dataset: cnn_dailymail # pileval, wikitext, cnn_dailymail
  num_calib_size: 512 # number of samples used for calibration

export:
  decoder_type: llama # gptnext, gpt2, llama
  inference_tensor_parallel: 1 # default is 1 TP for inference
  dtype: 16 # default precision data type
  export_tensorrt_llm_config: true # export config to build a TRT-LLM engine directly

model_file: llama2-7b-fp16.nemo # .nemo file path
model_save: llama2-7b-fp8.qnemo # path where the quantized model will be saved
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
```
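Since the script below is driven by Hydra, any value in this file can be overridden with a dotted key on the command line. A minimal sketch of how such overrides compose with the config, assuming only that omegaconf (which Hydra builds on) is installed; the override values are illustrative:

```python
from omegaconf import OmegaConf

# Load the default config shipped with the example.
cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml")

# Apply the same kind of dotted overrides the script accepts on the command line.
overrides = OmegaConf.from_dotlist(
    ["quantization.algorithm=int4_awq", "export.inference_tensor_parallel=2"]
)
cfg = OmegaConf.merge(cfg, overrides)

print(cfg.quantization.algorithm)  # int4_awq
```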
examples/nlp/language_modeling/megatron_llama_quantization.py (new file, 93 additions)
````python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.multiprocessing as mp
from datasets import load_dataset

from nemo.core.config import hydra_runner
from nemo.export.quantize import Quantizer

mp.set_start_method("spawn", force=True)

"""
NeMo quantization example script.

Please consult the nemo.export.quantize.Quantizer class and the
examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml config
for the available quantization methods, the supported models, and how to set up
data and inference for calibration (sticking to the defaults is recommended).

Example usage:
```
python examples/nlp/language_modeling/megatron_llama_quantization.py \
    model_file=llama2-7b-fp16.nemo \
    model_save=llama2-7b-fp8.qnemo \
    quantization.algorithm=fp8 \
    export.decoder_type=llama \
    export.inference_tensor_parallel=1
```
"""


def get_calib_dataloader(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512):
    if data == "pileval":
        dataset = load_dataset("json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train")
        text_column = "text"
    elif data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
        text_column = "text"
    elif data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
        text_column = "article"
    else:
        # Assume a local JSON dataset with a column named "text"
        dataset = load_dataset("json", data_files=data, split="train")
        text_column = "text"
    calib_size = max(min(len(dataset), calib_size), batch_size)
    for i in range(calib_size // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size][text_column]
        for j in range(len(batch)):
            batch[j] = batch[j][:max_sequence_length]
        yield batch


@hydra_runner(config_path="conf", config_name="megatron_llama_quantization")
def main(cfg) -> None:
    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is required for inference.")

    quantizer = Quantizer(cfg.quantization, cfg.inference, cfg.export, cfg.trainer)

    # The quantization algorithm can be set to None. This is useful for baseline-precision
    # accuracy validation. In this case only the weights export step is performed:
    if cfg.quantization.algorithm is not None:
        dataloader = get_calib_dataloader(
            cfg.quantization.calib_dataset,
            cfg.inference.batch_size,
            cfg.quantization.num_calib_size,
            cfg.inference.max_context_length,
        )
        # Materialize the generator so calibration can iterate over it multiple times.
        dataloader = [data for data in dataloader]
    else:
        dataloader = None

    model = quantizer.quantize(
        cfg.model_file, dataloader, cfg.tensor_model_parallel_size, cfg.pipeline_model_parallel_size
    )

    quantizer.export(model, cfg.model_save)


if __name__ == '__main__':
    main()
````
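Commit fe1eeba above enables the baseline path taken when `cfg.quantization.algorithm` is None: calibration is skipped and only the weights export step runs. A sketch of that invocation, mirroring the docstring's usage example (the file names are illustrative):

```
python examples/nlp/language_modeling/megatron_llama_quantization.py \
    model_file=llama2-7b-fp16.nemo \
    model_save=llama2-7b-baseline.qnemo \
    quantization.algorithm=null
```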
New file (13 additions, 0 deletions)

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
New file (15 additions, 0 deletions)

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .quantizer import Quantizer
```
Review comment: This link doesn't work. This one should be okay: https://huggingface.co/datasets/monology/pile-uncopyrighted
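A hedged sketch of how the "pileval" branch of get_calib_dataloader might adopt the suggested mirror; the dataset id comes from the comment, while the split and "text" column are assumptions carried over from the original branch:

```python
from datasets import load_dataset

# Hypothetical replacement for the dead the-eye.eu URL in the "pileval" branch;
# the split and column names are assumed to match the original Pile validation data.
dataset = load_dataset("monology/pile-uncopyrighted", split="train")
text_column = "text"
```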