diff --git a/Dockerfile b/Dockerfile index 90c84ea07627..970c34a690d4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -66,7 +66,7 @@ WORKDIR /workspace/ # We leave it here in case we need to work off of a specific commit in main RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout ad53b1e38689a0ceed75ade7821f4e6c7554abb4 && \ + git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \ pip install . # Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 @@ -132,6 +132,8 @@ RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-chec RUN pip install flash-attn # install numba for latest containers RUN pip install numba>=0.57.1 +# install ammo +RUN pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir # copy nemo source into a scratch image FROM scratch as nemo-src diff --git a/Jenkinsfile b/Jenkinsfile index cfd5853a6882..100a0bd4a6ad 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -91,11 +91,17 @@ pipeline { steps { sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout 5f9c870f9f24b482509699d206a9dbb00958f6fc && \ + git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \ pip install .' } } + stage('AMMO installation') { + steps { + sh 'pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir' + } + } + stage('PyTorch Lightning version') { steps { sh 'python -c "import pytorch_lightning; print(pytorch_lightning.__version__)"' @@ -390,6 +396,12 @@ pipeline { } } + stage('Setup test data and models') { + steps { + sh 'python -m tests.setup --save_dir /home/TestData/nlp' + } + } + // TODO: this requires TE >= v0.11 which is not available in 23.06. // please uncomment this test once mcore CI is ready. 
      stage('L2: Community LLM Checkpoints tests') {
@@ -405,9 +417,8 @@ pipeline {
        steps {
          sh 'CUDA_VISIBLE_DEVICES=0 python scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py \
            --in-file=/home/TestData/nlp/megatron_llama/llama-ci-hf \
-            --out-file=/home/TestData/nlp/megatron_llama/ci.nemo \
+            --out-file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
            --precision=16'
-          sh 'rm -f /home/TestData/nlp/megatron_llama/ci.nemo'
        }
      }
      stage('StarCoder') {
@@ -439,6 +450,54 @@
      }
    }

+    stage('L2: Nemo PTQ') {
+      when {
+        anyOf {
+          branch 'main'
+          changeRequest target: 'main'
+        }
+      }
+      failFast true
+      parallel {
+        stage('Llama2 - Export Only') {
+          steps {
+            sh 'python examples/nlp/language_modeling/megatron_llama_quantization.py \
+            model_file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
+            quantization.algorithm=null \
+            model_save=/home/TestData/nlp/megatron_llama/ci_baseline'
+            sh 'rm -rf /home/TestData/nlp/megatron_llama/ci_baseline'
+          }
+        }
+        stage('Llama2 - INT8 SQ') {
+          steps {
+            sh 'python examples/nlp/language_modeling/megatron_llama_quantization.py \
+            model_file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
+            quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \
+            quantization.algorithm=int8_sq \
+            quantization.num_calib_size=8 \
+            inference.batch_size=2 \
+            model_save=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo'
+            sh 'rm -f /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo'
+          }
+        }
+        stage('Llama2 - FP8') {
+          steps {
+            sh 'python examples/nlp/language_modeling/megatron_llama_quantization.py \
+            model_file=/home/TestData/nlp/megatron_llama/llama-ci-hf/llama_ci.nemo \
+            tensor_model_parallel_size=2 \
+            trainer.devices=2 \
+            quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \
+            quantization.algorithm=fp8 \
+            quantization.num_calib_size=8 \
+            inference.batch_size=2 \
+            export.inference_tensor_parallel=2 \
+            model_save=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo'
+            sh 'rm -f /home/TestData/nlp/megatron_llama/ci_fp8.qnemo'
+          }
+        }
+      }
+    }
+
    stage('L2: ASR dev run') {
      when {
        anyOf {
diff --git a/examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml b/examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml
new file mode 100644
index 000000000000..f3803dc4e69c
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml
@@ -0,0 +1,38 @@
+inference:
+  greedy: false # Whether to use greedy decoding instead of sampling
+  top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k filtering
+  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation
+  temperature: 1.0 # sampling temperature
+  add_BOS: true # add the bos token at the beginning of the prompt
+  tokens_to_generate: 30 # The maximum number of tokens to generate
+  all_probs: false # whether to return the log probabilities for all tokens in the vocabulary
+  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
+  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
+  compute_logprob: false # whether to compute the log probability of the full input text (a special inference mode); false by default
+  batch_size: 64 # batch size for inference
+  max_context_length: 512 # max length of the context; input sequences longer than this will be truncated
+
+trainer:
+  devices: 1
+  num_nodes: 1
+  accelerator: gpu
+  logger: false # logger provided by exp_manager
+  precision: bf16 # 16, 32, or bf16
+  enable_checkpointing: false
+
+quantization:
+  quantize_bmm1: false
+  algorithm: fp8 # int8_sq, fp8, int8, int4_awq, null
+  calib_dataset: cnn_dailymail # pileval, wikitext, cnn_dailymail
+  num_calib_size: 512 # number of samples used for calibration
+
+export:
+  decoder_type: llama # gptnext, gpt2, llama
+  inference_tensor_parallel: 1 # Default using 1 TP for inference
+  dtype: 16 # Default precision data type
+  export_tensorrt_llm_config: true # export config to build TRT-LLM engine directly
+
+model_file: llama2-7b-fp16.nemo # Nemo file path
+model_save: llama2-7b-fp8.qnemo # Path where the quantized model will be saved
+tensor_model_parallel_size: 1
+pipeline_model_parallel_size: 1
diff --git a/examples/nlp/language_modeling/megatron_llama_quantization.py b/examples/nlp/language_modeling/megatron_llama_quantization.py
new file mode 100644
index 000000000000..16fb5ae9c13b
--- /dev/null
+++ b/examples/nlp/language_modeling/megatron_llama_quantization.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.multiprocessing as mp
+from datasets import load_dataset
+
+from nemo.core.config import hydra_runner
+from nemo.export.quantize import Quantizer
+
+mp.set_start_method("spawn", force=True)
+
+"""
+Nemo quantization example script.
+
+Please consult the nemo.export.quantize.Quantizer class and the
+examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml config for the available quantization methods,
+the supported models, and how to set up data and inference for calibration (recommended defaults are provided).
+ +Example usage: +``` +python examples/nlp/language_modeling/megatron_llama_quantization.py \ + model_file=llama2-7b-fp16.nemo \ + model_save=llama2-7b-fp8.qnemo \ + quantization.algorithm=fp8 \ + export.decoder_type=llama \ + export.inference_tensor_parallel=1 +``` +""" + + +def get_calib_dataloader(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512): + if data == "pileval": + dataset = load_dataset("json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train") + text_column = "text" + elif data == "wikitext": + dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") + text_column = "text" + elif data == "cnn_dailymail": + dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") + text_column = "article" + else: + # Assume a local JSON dataset with a column named "text" + dataset = load_dataset("json", data_files=data, split="train") + text_column = "text" + calib_size = max(min(len(dataset), calib_size), batch_size) + for i in range(calib_size // batch_size): + batch = dataset[i * batch_size : (i + 1) * batch_size][text_column] + for j in range(len(batch)): + batch[j] = batch[j][:max_sequence_length] + yield batch + + +@hydra_runner(config_path="conf", config_name="megatron_llama_quantization") +def main(cfg) -> None: + if not torch.cuda.is_available(): + raise EnvironmentError("GPU is required for the inference.") + + quantizer = Quantizer(cfg.quantization, cfg.inference, cfg.export, cfg.trainer) + + # Quantization algorithm can be set to None. This is useful for baseline precision + # accuracy validation. In this case only weights export step will be performed: + if cfg.quantization.algorithm is not None: + dataloader = get_calib_dataloader( + cfg.quantization.calib_dataset, + cfg.inference.batch_size, + cfg.quantization.num_calib_size, + cfg.inference.max_context_length, + ) + dataloader = [data for data in dataloader] + else: + dataloader = None + + model = quantizer.quantize( + cfg.model_file, dataloader, cfg.tensor_model_parallel_size, cfg.pipeline_model_parallel_size + ) + + quantizer.export(model, cfg.model_save) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index ac35af38de64..f883f1c1fc7c 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -91,6 +91,7 @@ from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset + from megatron.core.deploy.gpt.model_specs import get_gpt_layer_ammo_spec from megatron.core.models.gpt import GPTModel as MCoreGPTModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -140,6 +141,7 @@ def get_specs(spec_name, num_experts=None): "": get_gpt_layer_with_transformer_engine_spec(num_experts), "megatron_falcon_gpt": get_falcon_layer_spec(), "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(), + "ammo": get_gpt_layer_ammo_spec(), } if spec_name not in name_spec_dict: raise ValueError(f"Spec name '{spec_name}' is not recognized.") diff --git a/nemo/export/__init__.py 
b/nemo/export/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/export/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/export/quantize/__init__.py b/nemo/export/quantize/__init__.py new file mode 100644 index 000000000000..87812e621bb6 --- /dev/null +++ b/nemo/export/quantize/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .quantizer import Quantizer diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py new file mode 100644 index 000000000000..1ae375e6cfe7 --- /dev/null +++ b/nemo/export/quantize/quantizer.py @@ -0,0 +1,218 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import tarfile +from contextlib import nullcontext +from typing import List, Optional + +import torch.distributed as dist +from megatron.core import parallel_state +from omegaconf import OmegaConf +from omegaconf.omegaconf import DictConfig, open_dict +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging +from nemo.utils.distributed import temporary_directory +from nemo.utils.get_rank import is_global_rank_zero +from nemo.utils.model_utils import load_config, save_artifacts + +try: + import ammo.torch.quantization as atq + from ammo.torch.export import export_model_config + + HAVE_AMMO = True + +except (ImportError, ModuleNotFoundError) as e: + HAVE_AMMO = False + HAVE_AMMO_ERROR = e + + +class Quantizer: + + """ + Post-training quantization of Nemo checkpoints. 
+
+    PTQ converts selected model layers to a low-precision format (e.g., INT4, FP8) for efficient serving.
+    The process consists of several steps:
+
+        1. Loading a Nemo model from disk using an appropriate parallelism strategy
+        2. Calibrating the model to obtain algorithm-specific scaling factors
+        3. Producing an output directory or a .qnemo tarball with the model config (json),
+           quantized weights (safetensors) and tokenizer config (yaml).
+
+    The resulting output directory (or .qnemo file) is intended to be consumed by the TensorRT-LLM toolbox
+    for efficient inference. This can be achieved using Nemo inference containers.
+
+    The currently supported and tested model family is Llama2. The model type needs to be specified in
+    the quantization command with the decoder_type parameter on export (see below). Quantizing other
+    model families is experimental and might not be fully supported.
+
+    The available quantization methods are listed in the QUANT_CFG_CHOICES dictionary below.
+    Please consult the AMMO documentation for details. You can also inspect
+    examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml for the available quantization
+    algorithms and calibration datasets as well as recommended settings.
+
+    The quantization algorithm can also be conveniently set to 'null' to perform only the weight export step
+    for TensorRT-LLM deployment. This is useful for obtaining baseline results for a full-precision model.
+    """
+
+    def __init__(
+        self,
+        quantization_config: DictConfig,
+        inference_config: DictConfig,
+        export_config: DictConfig,
+        trainer_config: DictConfig,
+    ):
+        if not HAVE_AMMO:
+            raise RuntimeError("nvidia-ammo>=0.7 is needed to use Quantizer") from HAVE_AMMO_ERROR
+        QUANT_CFG_CHOICES = {
+            "int8": atq.INT8_DEFAULT_CFG,
+            "int8_sq": atq.INT8_SMOOTHQUANT_CFG,
+            "fp8": atq.FP8_DEFAULT_CFG,
+            "int4_awq": atq.INT4_AWQ_CFG,
+            "w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
+        }
+        SUPPORTED_DTYPE = [16, "16", "bf16"]  # Default precision for non-quantized layers
+        assert export_config.dtype in SUPPORTED_DTYPE
+        assert quantization_config.algorithm is None or quantization_config.algorithm in QUANT_CFG_CHOICES
+        self.quantization_config = quantization_config
+        self.inference_config = inference_config
+        self.export_config = export_config
+        self.trainer_config = trainer_config
+        if quantization_config.algorithm is not None:
+            atq_config = QUANT_CFG_CHOICES[quantization_config.algorithm]
+            if quantization_config.algorithm != "fp8":
+                # disable quantization for the last output layer
+                atq_config = copy.deepcopy(atq_config)
+                atq_config["quant_cfg"]["*.output_layer.*"] = {"enable": False}
+            self.atq_config = atq_config
+        else:
+            self.atq_config = None
+
+    def _load_model(
+        self,
+        model_file: str,
+        tensor_model_parallel_size: Optional[int] = None,
+        pipeline_model_parallel_size: Optional[int] = None,
+    ):
+        """Load the model using the AMMO layer spec for quantization."""
+        model_cfg = self._load_and_modify_config(model_file, tensor_model_parallel_size, pipeline_model_parallel_size)
+
+        trainer = Trainer(strategy=NLPDDPStrategy(), **self.trainer_config)
+        connector = NLPSaveRestoreConnector()
+
+        model = MegatronGPTModel.restore_from(
+            restore_path=model_file, trainer=trainer, override_config_path=model_cfg, save_restore_connector=connector,
+        )
+        model.freeze()
+
+        try:
+            model.model.module.language_model.encoder.activations_checkpoint_method = None
+        except AttributeError:
+            pass
+
+        self._check_ddp_initialized(model)
+
+        if is_global_rank_zero():
+            print(model)
+
+        return model
+
+    def _check_ddp_initialized(self,
model): + if parallel_state.is_unitialized(): + + def dummy(): + return + + if model.trainer.strategy.launcher is not None: + model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) + model.trainer.strategy.setup_environment() + + def _load_and_modify_config( + self, + model_file: str, + tensor_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_size: Optional[int] = None, + ): + model_cfg = load_config(model_file) + + with open_dict(model_cfg): + model_cfg.activations_checkpoint_method = None + model_cfg.activations_checkpoint_granularity = None + if tensor_model_parallel_size is not None: + model_cfg.tensor_model_parallel_size = tensor_model_parallel_size + if pipeline_model_parallel_size is not None: + model_cfg.pipeline_model_parallel_size = pipeline_model_parallel_size + # Only custom AMMO spec is supported for PTQ: this custom spec is largely based on local Megatron-LM + # layer definitions to avoid Transformer Engine implementations that are currently not supported. + model_cfg.name = "ammo" + + return model_cfg + + def quantize( + self, + model_file: str, + dataloader: Optional[List[List[str]]], + tensor_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_size: Optional[int] = None, + ): + """Quantize model checkpoint using given dataloader and optional custom parallelism settings.""" + model = self._load_model(model_file, tensor_model_parallel_size, pipeline_model_parallel_size) + + if self.quantization_config.algorithm is None: + return model + + model.set_inference_config(OmegaConf.to_container(self.inference_config)) + + def forward_loop(): + for i, batch in enumerate(dataloader): + if is_global_rank_zero(): + print(f"Calibrating batch {i}") + model.predict_step(batch, i) + + model = atq.quantize(model, self.atq_config, forward_loop) + return model + + def export(self, model, model_save: str): + """Export model to '.qnemo' format for TensorRT-LLM engine build.""" + torch_dtype = torch_dtype_from_precision(self.export_config.dtype) + + # Setup model export handling: temporary directory for + # '.qnemo' tarball or directly write to model_save + save_qnemo = model_save.endswith(".qnemo") + if save_qnemo: + export_handler = temporary_directory() + else: + export_handler = nullcontext(enter_result=model_save) + + with export_handler as export_dir: + export_model_config( + model=model, + decoder_type=self.export_config.decoder_type, + dtype=torch_dtype, + export_dir=export_dir, + inference_tensor_parallel=self.export_config.inference_tensor_parallel, + export_tensorrt_llm_config=self.export_config.export_tensorrt_llm_config, + ) + dist.barrier() # Wait until all ranks complete export_model_config step + if is_global_rank_zero(): + logging.info(f"Exporting quantized weights, model artifacts, and tokenizer config to {model_save}...") + save_artifacts(model, export_dir) + if save_qnemo: + with tarfile.open(model_save, "w:gz") as tar: + tar.add(export_dir, arcname="./") diff --git a/nemo/utils/distributed.py b/nemo/utils/distributed.py index b0d24de3e5b4..ee6c107b1d85 100644 --- a/nemo/utils/distributed.py +++ b/nemo/utils/distributed.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib import os +import tempfile import torch +import torch.distributed as dist from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero try: from megatron.core import parallel_state @@ -100,3 +104,22 @@ def gather_objects(partial_results_list, main_rank=None): results_list.extend(r) return results_list + + +@contextlib.contextmanager +def temporary_directory(): + """Create a shared temporary directory across ranks in distributed setup. + + This function assumes that the distributed setup has been already + correctly initialized. It is intended to be used only in single-node + setup so that all ranks can access the directory created.""" + + if is_global_rank_zero(): + tmp_dir = [tempfile.TemporaryDirectory()] + else: + tmp_dir = [None] + dist.broadcast_object_list(tmp_dir) + yield tmp_dir[0].name + # We use barrier below to make sure that rank zero won't exit + # and delete tmp_dir while other ranks may still use it + dist.barrier() diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index b2a6abbf54aa..8889f13d5b98 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import copy import importlib import os +import shutil +import tarfile +import tempfile from dataclasses import dataclass, is_dataclass from enum import Enum from functools import lru_cache @@ -61,6 +65,18 @@ class ArtifactItem: hashed_path: Optional[str] = None +def load_config(model_file: str) -> DictConfig: + """Load model config from extracted directory or '.nemo' tarball.""" + if os.path.isfile(model_file): + with tempfile.TemporaryDirectory() as tmp, tarfile.open(model_file, "r:") as tar: + tar.extract("./model_config.yaml", path=tmp) + model_config = OmegaConf.load(os.path.join(tmp, "model_config.yaml")) + else: + model_config = OmegaConf.load(os.path.join(model_file, "model_config.yaml")) + + return model_config + + def resolve_dataset_name_from_cfg(cfg: 'DictConfig') -> Optional[str]: """ Parses items of the provided sub-config to find the first potential key that @@ -636,3 +652,36 @@ def ckpt_to_dir(filepath: Union[str, Path]) -> Path: checkpoint_dir = filepath.with_name(filepath.stem) return checkpoint_dir + + +def save_artifacts(model, output_dir: str, use_abspath: bool = False) -> None: + """Save all model artifacts and tokenizer config to a given output directory.""" + app_state = AppState() + model_file = app_state.model_restore_path + model_cfg = copy.deepcopy(model.cfg) + + # Setup model file handling context: directory or tarball + if os.path.isfile(model_file): + model_file_handler = tarfile.open + kwargs = {"name": model_file, "mode": "r:"} + elif os.path.isdir(model_file): + model_file_handler = contextlib.nullcontext + kwargs = {} + else: + raise FileNotFoundError(model_file) + + # Copy or extract artifacts depending on the context + with model_file_handler(**kwargs) as maybe_tar: + for arti_name, arti_item in model.artifacts.items(): + _, arti_file = arti_item.path.split("nemo:") + arti_path = os.path.join(output_dir, arti_name) + if maybe_tar is not None: + maybe_tar.extract(f"./{arti_file}", path=output_dir) + os.rename(os.path.join(output_dir, arti_file), arti_path) + else: + shutil.copy(os.path.join(model_file, arti_file), arti_path) + # Store artifact path as basename by default. 
Otherwise save absolute path but bear in mind + # that in this case output directory should be permanent for correct artifact recovery later + arti_path = os.path.abspath(arti_path) if use_abspath else os.path.basename(arti_path) + OmegaConf.update(model_cfg, arti_name, arti_path) + OmegaConf.save(model_cfg.tokenizer, os.path.join(output_dir, "tokenizer_config.yaml")) diff --git a/tests/setup/__main__.py b/tests/setup/__main__.py new file mode 100644 index 000000000000..289a2537e2f2 --- /dev/null +++ b/tests/setup/__main__.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from .data.create_sample_jsonl import create_sample_jsonl +from .models.create_hf_model import create_hf_model + +print("Setup test data and models...") + +parser = argparse.ArgumentParser("Setup test data and models.") +parser.add_argument("--save_dir", required=True, help="Root save directory for artifacts") +parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files and directories") +args = parser.parse_args() + +print(f"Arguments are: {vars(args)}") + +os.makedirs(args.save_dir, exist_ok=True) + +create_sample_jsonl( + output_file=os.path.join(args.save_dir, "test_quantization", "test.json"), overwrite=args.overwrite, +) + +create_hf_model( + model_name_or_path="/home/TestData/nlp/meta-llama/Llama-2-7b-hf", + output_dir=os.path.join(args.save_dir, "megatron_llama/llama-ci-hf"), + config_updates={"hidden_size": 256, "num_attention_heads": 4, "num_hidden_layers": 2, "num_key_value_heads": 4}, + overwrite=args.overwrite, +) +print("Setup done.") diff --git a/tests/setup/data/create_sample_jsonl.py b/tests/setup/data/create_sample_jsonl.py new file mode 100644 index 000000000000..00f789548f81 --- /dev/null +++ b/tests/setup/data/create_sample_jsonl.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os + +""" +Create sample JSONL file for functional testing. Each line contains a dictionary +with a single element "text" for storing data. 
+""" + + +def create_sample_jsonl(output_file: str, overwrite: bool = False): + """Create sample JSONL.""" + if os.path.isfile(output_file) and not overwrite: + print(f"File {output_file} exists and overwrite flag is not set so exiting.") + return + + texts = [ + "Sample data for functional tests", + "Once upon a time, in the middle of a dense forest, there was a small house, where lived a pretty little girl " + "named Little Red Riding Hood.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore " + "magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea " + "commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat " + "nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit " + "anim id est laborum...", + "Next please!", + "¡H E L L O W O R L D!", + "Korzystając z okazji chciałbym pozdrowić całą moją rodzinę i przyjaciół", + ] + print(f"Writing {len(texts)} line(s) to {output_file}...") + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, mode="w", encoding="utf-8") as f: + for text in texts: + json.dump({"text": text}, f) + f.write("\n") + print("OK.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Create sample JSONL file.") + parser.add_argument("--output_file", help="Output file name") + parser.add_argument("--overwrite", action="store_true", help="Overwrite file if it exists") + args = parser.parse_args() + create_sample_jsonl(args.output_file) diff --git a/tests/setup/models/create_hf_model.py b/tests/setup/models/create_hf_model.py new file mode 100644 index 000000000000..9f57d5996dfc --- /dev/null +++ b/tests/setup/models/create_hf_model.py @@ -0,0 +1,94 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os + +from typing import Any, Dict, Optional + +import transformers + +""" +Create a randomly initialized HuggingFace model for testing purposes. + +Model can be specified by name or path for creating its config and tokenizer using +HuggingFace transformers AutoConfig and AutoTokenizer functions. + +Parameter config_updates can be used to override specific model config fields to make +it smaller, for example, by changing number of layers or hidden layers dimensionality, +making it adequate for testing purposes. This parameter should be specified as +a dictionary that can be parsed using json.loads method. 
+
+Example usage for Llama2 model (requires HF login):
+```
+python tests/setup/models/create_hf_model.py \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --output_dir tiny_llama2_hf \
+    --config_updates '{"hidden_size": 128, "num_attention_heads": 4, "num_hidden_layers": 2, "num_key_value_heads": 4}'
+```
+"""
+
+
+def get_hf_model_class(hf_config):
+    """Get HuggingFace model class from config."""
+    if len(hf_config.architectures) > 1:
+        print(f"More than one model architecture available, choosing the first: {hf_config.architectures}")
+    model_name = hf_config.architectures[0]
+    model_class = getattr(transformers, model_name)
+    return model_class
+
+
+def create_hf_model(
+    model_name_or_path: str, output_dir: str, config_updates: Optional[Dict[str, Any]] = None, overwrite: bool = False
+):
+    """Create a HuggingFace model with optional config updates."""
+    if os.path.isdir(output_dir) and not overwrite:
+        print(f"Output directory {output_dir} exists and the overwrite flag is not set, so exiting.")
+        return
+
+    hf_config = transformers.AutoConfig.from_pretrained(model_name_or_path)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
+    model_class = get_hf_model_class(hf_config)
+
+    if config_updates is not None:
+        hf_config.update(config_updates)
+    print(hf_config)
+
+    model = model_class(hf_config)
+    print(model)
+
+    os.makedirs(output_dir, exist_ok=True)
+    print(f"Saving model to {output_dir}...")
+    tokenizer.save_pretrained(output_dir)
+    model.save_pretrained(output_dir)
+    print("OK.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Create a HuggingFace model (random initialization) for testing purposes.")
+    parser.add_argument(
+        "--model_name_or_path", required=True, help="Model name or local path with model config and tokenizer",
+    )
+    parser.add_argument(
+        "--output_dir", required=True, help="Output directory",
+    )
+    parser.add_argument(
+        "--config_updates", type=json.loads, help="JSON string with config fields to override in the model config",
+    )
+    parser.add_argument(
+        "--overwrite", action="store_true", help="Overwrite the output directory if it exists",
+    )
+    args = parser.parse_args()
+    create_hf_model(args.model_name_or_path, args.output_dir, args.config_updates, args.overwrite)
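
Reviewer note: below is a minimal sketch (not part of the diff) of how the Quantizer API introduced in this change can be driven programmatically, mirroring what examples/nlp/language_modeling/megatron_llama_quantization.py does. The checkpoint path and calibration texts are illustrative assumptions, not files added by this PR.

```python
# Minimal sketch, assuming a local llama2-7b-fp16.nemo checkpoint and the config added by this PR.
from omegaconf import OmegaConf

from nemo.export.quantize import Quantizer

cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_llama_quantization.yaml")

quantizer = Quantizer(cfg.quantization, cfg.inference, cfg.export, cfg.trainer)

# Calibration data is a list of batches, where each batch is a list of strings,
# matching what get_calib_dataloader() in the example script yields.
dataloader = [["Sample calibration text one.", "Sample calibration text two."]]

model = quantizer.quantize(
    "llama2-7b-fp16.nemo",  # hypothetical input checkpoint
    dataloader,
    cfg.tensor_model_parallel_size,
    cfg.pipeline_model_parallel_size,
)
quantizer.export(model, "llama2-7b-fp8.qnemo")  # writes a .qnemo tarball consumable by TensorRT-LLM
```

If quantization.algorithm is set to null in the config, calibration is skipped, the dataloader may be None, and only the weight export step runs (the baseline-export path exercised by the new 'Llama2 - Export Only' CI stage).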