diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py
index 23f5e60d9489..7c8a7ab8aa78 100644
--- a/nemo/collections/asr/models/asr_model.py
+++ b/nemo/collections/asr/models/asr_model.py
@@ -21,7 +21,7 @@
 from nemo.core.classes.exportable import Exportable
 from nemo.core.classes.mixins import AccessMixin
 from nemo.utils import logging, model_utils
-from nemo.utils.export_utils import cast_all
+from nemo.utils.cast_utils import cast_all
 
 
 __all__ = ['ASRModel']
diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
index 0b9bb966e3e8..a44b560fbb2d 100644
--- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
+++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
@@ -42,6 +42,7 @@
 from nemo.collections.nlp.modules.common.megatron.megatron_export import DecEmb, EncEmb, TokensHeadEmb
 from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
+from nemo.core.classes import Exportable
 from nemo.utils import AppState, logging, timers
 
 try:
@@ -56,7 +57,7 @@
 __all__ = ["MegatronNMTModel"]
 
 
-class MegatronNMTModel(MegatronLMEncoderDecoderModel):
+class MegatronNMTModel(MegatronLMEncoderDecoderModel, Exportable):
     """
     Megatron NMT training
     """
@@ -750,5 +751,12 @@ def decoder(self):
         return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device)
 
     @property
-    def classifier(self):
+    def log_softmax(self):
         return TokensHeadEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.tokens_head, self.device)
+
+    @property
+    def input_module(self):
+        return self.encoder
+
+    def list_export_subnets(self):
+        return ['encoder', 'log_softmax', 'decoder']
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_export.py b/nemo/collections/nlp/modules/common/megatron/megatron_export.py
index 6fd9a239380c..8b9a5fff9e88 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_export.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_export.py
@@ -49,21 +49,23 @@ def forward(self, dec_output):
         if isinstance(dec_output, list):
             dec_output = dec_output[0]
 
-        dec_output = torch.permute(dec_output, (1, 0, 2))
-
         if self.tokens_head_bias is not None:
             return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight, self.tokens_head_bias)
         return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight)
 
-    def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
+    def input_example(self, max_batch=1, max_dim=768, seq_len=6):
         return [
-            torch.randint(low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32)
+            torch.randint(low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32)
         ]
 
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
     @property
     def input_types(self) -> Optional[Dict[str, NeuralType]]:
         return {
-            "hidden_states": NeuralType(('T', 'B', 'D'), ChannelType()),
+            "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()),
         }
 
     @property
@@ -107,18 +109,28 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec
         # dec_input, dec_attn_mask, enc_output, enc_attn_mask | dec_input, dec_attn_mask, enc_output, enc_attn_mask
         _ = dec_mems
 
-        return self.decoder(dec_input, decoder_mask, encoder_embeddings, encoder_mask).float()
+        return (
+            self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask)
+            .float()
+            .permute(1, 0, 2)
+        )
 
-    def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def input_example(self, max_batch=1, max_dim=768, seq_len=6):
         enc_output = torch.randint(
-            low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32
+            low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32
         )
         enc_attn_mask = torch.tensor([[1 for _ in range(seq_len)]]).to(self.device)
         dec_len = random.randint(10, 128)
         dec_input = torch.randint(low=0, high=1000, size=(max_batch, dec_len), device=self.device)
         dec_attn_mask = torch.tensor([[1 for _ in range(dec_len)]]).to(self.device)
-        decoder_mems = torch.zeros([8, 6, 1024], dtype=torch.float32).to(self.device)
+
+        # constant decoder_mems as placeholder for now
+        decoder_mems = torch.zeros([8, 6, max_dim], dtype=torch.float32).to(self.device)
 
         # input_ids, decoder_mask, encoder_mask, encoder_embeddings
         return (dec_input, dec_attn_mask, enc_attn_mask, enc_output, decoder_mems)
@@ -128,14 +140,14 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
         return {
             "input_ids": NeuralType(('B', 'T', 'D'), ChannelType()),
             "decoder_mask": NeuralType(('B', 'T'), MaskType()),
-            "encoder_mask": NeuralType(('T', 'B', 'D'), ChannelType()),
+            "encoder_mask": NeuralType(('B', 'T', 'D'), ChannelType()),
             "encoder_embeddings": NeuralType(('B', 'T'), MaskType()),
-            "decoder_mems": NeuralType(('T', 'B', 'D'), ChannelType()),
+            "decoder_mems": NeuralType(('B', 'T', 'D'), ChannelType()),
         }
 
     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
+        return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}
 
     @property
     def input_names(self) -> List[str]:
@@ -172,15 +184,19 @@ def forward(self, input_ids, encoder_mask):
         enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None)
 
         # pass input through the encoder
-        return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32)
+        return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2)
 
-    def input_example(self):
+    def input_example(self, max_batch=1, max_dim=30000, seq_len=6):
         seq_len = random.randint(0, 128)
         return (
-            torch.randint(0, 30000, (1, seq_len)).to(self.device),
-            torch.ones((1, seq_len), dtype=int).to(self.device),
+            torch.randint(0, max_dim, (max_batch, seq_len)).to(self.device),
+            torch.ones((max_batch, seq_len), dtype=int).to(self.device),
         )
 
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
     @property
     def input_types(self) -> Optional[Dict[str, NeuralType]]:
         return {
@@ -190,7 +206,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
 
     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
+        return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}
 
     @property
     def input_names(self) -> List[str]:
diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py
index fd3cbd2699d1..0b2748e580af 100644
--- a/nemo/utils/__init__.py
+++ b/nemo/utils/__init__.py
@@ -14,6 +14,13 @@
 
 from nemo.utils.app_state import AppState
+from nemo.utils.cast_utils import (
+    CastToFloat,
+    avoid_bfloat16_autocast_context,
+    avoid_float16_autocast_context,
+    cast_all,
+    cast_tensor,
+)
 from nemo.utils.nemo_logging import Logger as _Logger
 from nemo.utils.nemo_logging import LogMode as logging_mode
diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py
new file mode 100644
index 000000000000..9eb064936ea5
--- /dev/null
+++ b/nemo/utils/cast_utils.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+
+import torch
+
+
+def avoid_bfloat16_autocast_context():
+    """
+    If the current autocast context is bfloat16,
+    cast it to float32
+    """
+
+    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
+        return torch.cuda.amp.autocast(dtype=torch.float32)
+    else:
+        return nullcontext()
+
+
+def avoid_float16_autocast_context():
+    """
+    If the current autocast context is float16, cast it to bfloat16
+    if available (unless we're in jit) or float32
+    """
+
+    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return torch.cuda.amp.autocast(dtype=torch.float32)
+
+        if torch.cuda.is_bf16_supported():
+            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
+        else:
+            return torch.cuda.amp.autocast(dtype=torch.float32)
+    else:
+        return nullcontext()
+
+
+def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
+
+
+def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    if isinstance(x, torch.Tensor):
+        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
+    else:
+        if isinstance(x, dict):
+            new_dict = {}
+            for k in x.keys():
+                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
+            return new_dict
+        elif isinstance(x, tuple):
+            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+
+
+class CastToFloat(torch.nn.Module):
+    def __init__(self, mod):
+        super(CastToFloat, self).__init__()
+        self.mod = mod
+
+    def forward(self, x):
+        with avoid_float16_autocast_context():
+            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
+        return ret
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index 3d042dd070e2..a5c4e5b3d24f 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+from contextlib import nullcontext
 from enum import Enum
 from typing import Callable, Dict, Optional, Type
 
@@ -21,7 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nemo.utils import logging
+from nemo.utils import CastToFloat, logging
 
 try:
     import onnxruntime
@@ -45,36 +46,6 @@ class ExportFormat(Enum):
 }
 
 
-def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
-    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
-
-
-def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
-    if isinstance(x, torch.Tensor):
-        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
-    else:
-        if isinstance(x, dict):
-            new_dict = {}
-            for k in x.keys():
-                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
-            return new_dict
-        elif isinstance(x, tuple):
-            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
-
-
-class CastToFloat(nn.Module):
-    def __init__(self, mod):
-        super(CastToFloat, self).__init__()
-        self.mod = mod
-
-    def forward(self, x):
-        if torch.is_autocast_enabled():
-            ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
-        else:
-            ret = self.mod.forward(x)
-        return ret
-
-
 class LinearWithBiasSkip(nn.Module):
     def __init__(self, weight, bias, skip_bias_add):
         super(LinearWithBiasSkip, self).__init__()
@@ -88,27 +59,46 @@ def forward(self, x):
         return F.linear(x, self.weight, self.bias), None
 
 
-# ScaledMaskedSoftmax replacement
-def mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores
-
-
-def exportable_ScaledMaskedSoftmax(input, mask, scale):
-    if scale is not None:
-        input = input * scale
-
-    mask_output = mask_func(input, mask) if mask is not None else input
-    probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-    probs = probs.half()
-    return probs
+class ExportableMatchedScaleMaskSoftmax(nn.Module):
+    def __init__(self, mod):
+        super(ExportableMatchedScaleMaskSoftmax, self).__init__()
+        self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale)
+
+    def init_module(
+        self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale,
+    ):
+        self.input_in_fp16 = input_in_fp16
+        self.input_in_bf16 = input_in_bf16
+        self.softmax_in_fp32 = softmax_in_fp32
+        self.mask_func = mask_func
+        self.scale = scale
+
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
+
+    def forward(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+        all_k_masked = mask.all(axis=-1)
+        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
+        probs = probs * zero_attention_mask
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+        return probs
 
 
 def get_export_format(filename: str):
     _, ext = os.path.splitext(filename)
     try:
-        return _EXT_DICT[ext]
+        return _EXT_DICT[ext.lower()]
     except KeyError:
         raise ValueError(f"Export file {filename} extension does not correspond to any export format!")
@@ -186,12 +176,12 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0
         logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n")
         onnx.checker.check_model(onnx_model, full_check=True)
         return
-
     onnx_session_opt = onnxruntime.SessionOptions()
-    onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
     sess = onnxruntime.InferenceSession(
         onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider']
     )
+    del onnx_model
     all_good = True
     for input_example in input_examples:
         input_list, input_dict = parse_input_example(input_example)
@@ -244,7 +234,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
     from apex.transformer.tensor_parallel.layers import RowParallelLinear, ColumnParallelLinear
     from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
 
-    def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]:
+    def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
        """
        Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export.
        Args:
@@ -252,19 +242,16 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]:
        Returns:
           Equivalent LayerNorm module
        """
-        if (
-            not isinstance(n, FusedLayerNorm)
-            and not isinstance(n, FastLayerNorm)
-            and not isinstance(n, MixedFusedLayerNorm)
-        ):
-            return None
-
-        dev = next(n.parameters()).device
+        p = next(n.parameters())
 
         if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
-            mod = nn.LayerNorm(n.normalized_shape, eps=n.eps, elementwise_affine=n.elementwise_affine,).to(dev)
+            shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
         elif isinstance(n, FastLayerNorm):
-            mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True, dtype=torch.float16,).to(dev)
+            shape, eps, affine = n.weight.shape, n.epsilon, True
+        else:
+            return None
+
+        mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
         return mod
@@ -281,7 +268,7 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
             raise ValueError("This function can only change the RowParallelLinear module.")
 
         dev = next(n.parameters()).device
-        mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
+        mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)
 
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
@@ -357,6 +344,20 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
     return expansion_fn
 
 
+def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
+    """
+    Replaces MatchedScaleMaskSoftmax with exportable softmax layer
+    Args:
+        n: module to replace
+    Returns:
+        exportable module
+    """
+
+    mod = ExportableMatchedScaleMaskSoftmax(n)
+
+    return mod
+
+
 def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
     Generic function generator to replace BaseT module with DestT wrapper.
@@ -421,6 +422,7 @@ def replace_modules(
     "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
     "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
     "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+    "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
 }
 
 
diff --git a/scripts/export.py b/scripts/export.py
index ca9c4fed64aa..b3d6317e936c 100644
--- a/scripts/export.py
+++ b/scripts/export.py
@@ -133,22 +133,21 @@ def nemo_export(argv):
         logging.info("Caching support is enabled.")
 
     autocast = nullcontext
-    model.to(device=args.device).freeze()
-    model.eval()
-    with torch.inference_mode():
-        input_example = model.input_module.input_example(**in_args)
-        if check_trace:
-            check_trace = [input_example]
-            if max_dim:
-                in_args["max_dim"] = (max_dim + 1) // 2
-                in_args["max_batch"] = (max_batch + 1) // 2
-                input_example2 = model.input_module.input_example(**in_args)
-                check_trace.append(input_example2)
-
     if args.autocast:
         autocast = torch.cuda.amp.autocast
     try:
-        with autocast(), torch.inference_mode():
+        with autocast(), torch.no_grad(), torch.inference_mode():
+            model.to(device=args.device).freeze()
+            model.eval()
+            input_example = None
+            if check_trace and len(in_args) > 0:
+                input_example = model.input_module.input_example(**in_args)
+                check_trace = [input_example]
+                for key, arg in in_args.items():
+                    in_args[key] = (arg + 1) // 2
+                input_example2 = model.input_module.input_example(**in_args)
+                check_trace.append(input_example2)
+
             _, descriptions = model.export(
                 out,
                 input_example=input_example,