From 31e7a5320d1f3bd904873211460c29ca0ac665db Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Mon, 7 Nov 2022 12:24:38 -0800 Subject: [PATCH 01/10] export update for Megatron + change ORT optimization Signed-off-by: David Mosallanezhad --- .../machine_translation/megatron_nmt_model.py | 2 +- .../common/megatron/megatron_export.py | 46 +++++++---- nemo/utils/export_utils.py | 80 ++++++++++++++----- 3 files changed, 92 insertions(+), 36 deletions(-) diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 0b9bb966e3e8..6eb3a62de195 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -750,5 +750,5 @@ def decoder(self): return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device) @property - def classifier(self): + def log_softmax(self): return TokensHeadEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.tokens_head, self.device) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_export.py b/nemo/collections/nlp/modules/common/megatron/megatron_export.py index 6fd9a239380c..d6f1593c36bf 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_export.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_export.py @@ -49,21 +49,23 @@ def forward(self, dec_output): if isinstance(dec_output, list): dec_output = dec_output[0] - dec_output = torch.permute(dec_output, (1, 0, 2)) - if self.tokens_head_bias is not None: return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight, self.tokens_head_bias) return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight) - def input_example(self, max_batch=1, max_dim=1024, seq_len=6): + def input_example(self, max_batch=1, max_dim=768, seq_len=6): return [ - torch.randint(low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32) + torch.randint(low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32) ] + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + @property def input_types(self) -> Optional[Dict[str, NeuralType]]: return { - "hidden_states": NeuralType(('T', 'B', 'D'), ChannelType()), + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), } @property @@ -107,18 +109,24 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec # dec_input, dec_attn_mask, enc_output, enc_attn_mask | dec_input, dec_attn_mask, enc_output, enc_attn_mask _ = dec_mems - return self.decoder(dec_input, decoder_mask, encoder_embeddings, encoder_mask).float() + return self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask).float().permute(1, 0, 2) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False - def input_example(self, max_batch=1, max_dim=1024, seq_len=6): + def input_example(self, max_batch=1, max_dim=768, seq_len=6): enc_output = torch.randint( - low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32 + low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32 ) enc_attn_mask = torch.tensor([[1 for _ in range(seq_len)]]).to(self.device) dec_len = random.randint(10, 128) dec_input = torch.randint(low=0, high=1000, size=(max_batch, dec_len), device=self.device) 
dec_attn_mask = torch.tensor([[1 for _ in range(dec_len)]]).to(self.device) - decoder_mems = torch.zeros([8, 6, 1024], dtype=torch.float32).to(self.device) + + # constant decoder_mems as placeholder for now + decoder_mems = torch.zeros([8, 6, max_dim], dtype=torch.float32).to(self.device) # input_ids, decoder_mask, encoder_mask, encoder_embeddings return (dec_input, dec_attn_mask, enc_attn_mask, enc_output, decoder_mems) @@ -128,14 +136,14 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: return { "input_ids": NeuralType(('B', 'T', 'D'), ChannelType()), "decoder_mask": NeuralType(('B', 'T'), MaskType()), - "encoder_mask": NeuralType(('T', 'B', 'D'), ChannelType()), + "encoder_mask": NeuralType(('B', 'T', 'D'), ChannelType()), "encoder_embeddings": NeuralType(('B', 'T'), MaskType()), - "decoder_mems": NeuralType(('T', 'B', 'D'), ChannelType()), + "decoder_mems": NeuralType(('B', 'T', 'D'), ChannelType()), } @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())} + return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} @property def input_names(self) -> List[str]: @@ -172,15 +180,19 @@ def forward(self, input_ids, encoder_mask): enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None) # pass input through the encoder - return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32) + return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32).permute(1, 0, 2) - def input_example(self): + def input_example(self, max_batch=1, max_dim=30000, seq_len=6): seq_len = random.randint(0, 128) return ( - torch.randint(0, 30000, (1, seq_len)).to(self.device), - torch.ones((1, seq_len), dtype=int).to(self.device), + torch.randint(0, max_dim, (max_batch, seq_len)).to(self.device), + torch.ones((max_batch, seq_len), dtype=int).to(self.device), ) + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + @property def input_types(self) -> Optional[Dict[str, NeuralType]]: return { @@ -190,7 +202,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())} + return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} @property def input_names(self) -> List[str]: diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 3d042dd070e2..b4e48b21d7fa 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -87,22 +87,49 @@ def forward(self, x): return F.linear(x, self.weight), self.bias return F.linear(x, self.weight, self.bias), None - -# ScaledMaskedSoftmax replacement -def mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - -def exportable_ScaledMaskedSoftmax(input, mask, scale): - if scale is not None: - input = input * scale - - mask_output = mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - probs = probs.half() - return probs +class ExportableMatchedScaleMaskSoftmax(nn.Module): + def __init__(self, mod): + super(ExportableMatchedScaleMaskSoftmax, self).__init__() + self.init_module( + mod.input_in_fp16, + mod.input_in_bf16, + mod.mask_func, + mod.softmax_in_fp32, + mod.scale + ) + + def init_module(self, input_in_fp16, + input_in_bf16, + mask_func, + 
softmax_in_fp32, + scale, + ): + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + self.softmax_in_fp32 = softmax_in_fp32 + self.mask_func = mask_func + self.scale = scale + + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + + def forward(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + all_k_masked = mask.all(axis=-1) + zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None] + probs = probs * zero_attention_mask + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + return probs def get_export_format(filename: str): @@ -188,7 +215,7 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 return onnx_session_opt = onnxruntime.SessionOptions() - onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC sess = onnxruntime.InferenceSession( onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider'] ) @@ -263,7 +290,7 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]: if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): mod = nn.LayerNorm(n.normalized_shape, eps=n.eps, elementwise_affine=n.elementwise_affine,).to(dev) elif isinstance(n, FastLayerNorm): - mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True, dtype=torch.float16,).to(dev) + mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True, dtype=torch.float,).to(dev) n_state = n.state_dict() mod.load_state_dict(n_state) @@ -357,6 +384,22 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: return expansion_fn + +def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: + """ + Replaces + Args: + n: + Returns: + + """ + + mod = ExportableMatchedScaleMaskSoftmax( + n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale + ) + + return mod + def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ Generic function generator to replace BaseT module with DestT wrapper. 
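For context, the exportable softmax introduced above is meant to be swapped in for Megatron's MatchedScaleMaskSoftmax before tracing, through the same module-replacement machinery that wrap_module feeds. NeMo's actual replace_modules is not part of this patch; the snippet below is only a generic sketch of how such a class-name-to-replacement table could be applied to a model, and swap_modules_by_classname is a hypothetical helper, not NeMo code.

import torch.nn as nn

def swap_modules_by_classname(model: nn.Module, swap_table: dict) -> nn.Module:
    # Recursively replace children whose class name appears in swap_table.
    # swap_table maps a class name to a callable that takes the old module and
    # returns its exportable stand-in (returning None leaves the child untouched).
    for name, child in model.named_children():
        swap_modules_by_classname(child, swap_table)
        factory = swap_table.get(type(child).__name__)
        if factory is not None:
            replacement = factory(child)
            if replacement is not None:
                setattr(model, name, replacement)
    return model

# Hypothetical usage before ONNX export:
#   swap_modules_by_classname(megatron_module, {
#       "MatchedScaleMaskSoftmax": ExportableMatchedScaleMaskSoftmax,
#   })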
@@ -421,6 +464,7 @@ def replace_modules( "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), + "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax) } From 21a535141f1aba060b46d159168b5a11dbda4c36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Nov 2022 20:28:15 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/megatron/megatron_export.py | 6 ++++- nemo/utils/export_utils.py | 26 ++++++------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_export.py b/nemo/collections/nlp/modules/common/megatron/megatron_export.py index d6f1593c36bf..fe9bd34c4804 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_export.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_export.py @@ -109,7 +109,11 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec # dec_input, dec_attn_mask, enc_output, enc_attn_mask | dec_input, dec_attn_mask, enc_output, enc_attn_mask _ = dec_mems - return self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask).float().permute(1, 0, 2) + return ( + self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask) + .float() + .permute(1, 0, 2) + ) def freeze(self): for param in self.parameters(): diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index b4e48b21d7fa..d28dadd02ac3 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -87,22 +87,14 @@ def forward(self, x): return F.linear(x, self.weight), self.bias return F.linear(x, self.weight, self.bias), None + class ExportableMatchedScaleMaskSoftmax(nn.Module): def __init__(self, mod): super(ExportableMatchedScaleMaskSoftmax, self).__init__() - self.init_module( - mod.input_in_fp16, - mod.input_in_bf16, - mod.mask_func, - mod.softmax_in_fp32, - mod.scale - ) - - def init_module(self, input_in_fp16, - input_in_bf16, - mask_func, - softmax_in_fp32, - scale, + self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale) + + def init_module( + self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale, ): self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 @@ -384,7 +376,6 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: return expansion_fn - def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: """ Replaces @@ -394,12 +385,11 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: """ - mod = ExportableMatchedScaleMaskSoftmax( - n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale - ) + mod = ExportableMatchedScaleMaskSoftmax(n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale) return mod + def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ Generic function generator to replace BaseT module with DestT wrapper. 
@@ -464,7 +454,7 @@ def replace_modules( "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), - "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax) + "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax), } From 836d519900f313ee935dfc92f14727816c2f34fd Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 8 Nov 2022 08:51:19 -0800 Subject: [PATCH 03/10] updated export_utils to use autocast instead of manually casting >:/ Signed-off-by: David Mosallanezhad --- .../nlp/modules/common/megatron/megatron_export.py | 2 +- nemo/utils/export_utils.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_export.py b/nemo/collections/nlp/modules/common/megatron/megatron_export.py index d6f1593c36bf..cebfe50c7c8a 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_export.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_export.py @@ -180,7 +180,7 @@ def forward(self, input_ids, encoder_mask): enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None) # pass input through the encoder - return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32).permute(1, 0, 2) + return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2) def input_example(self, max_batch=1, max_dim=30000, seq_len=6): seq_len = random.randint(0, 128) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index b4e48b21d7fa..e5202c2af4ee 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -68,9 +68,7 @@ def __init__(self, mod): self.mod = mod def forward(self, x): - if torch.is_autocast_enabled(): - ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) - else: + with torch.autocast(device_type='cuda', dtype=x.dtype): ret = self.mod.forward(x) return ret From 8c1a1d48c58b16cc4db10691df1d847a2683c41a Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 8 Nov 2022 08:54:37 -0800 Subject: [PATCH 04/10] removed dtype from LayerNorm Signed-off-by: David Mosallanezhad --- nemo/utils/export_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 2d094bf20496..c2803f55de2e 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -280,7 +280,7 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]: if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): mod = nn.LayerNorm(n.normalized_shape, eps=n.eps, elementwise_affine=n.elementwise_affine,).to(dev) elif isinstance(n, FastLayerNorm): - mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True, dtype=torch.float,).to(dev) + mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True,).to(dev) n_state = n.state_dict() mod.load_state_dict(n_state) From c6686d864a6a85b08d556a1c3c81bb3408ec64b8 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 8 Nov 2022 08:59:19 -0800 Subject: [PATCH 05/10] added comment Signed-off-by: David Mosallanezhad --- nemo/utils/export_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index c2803f55de2e..03f4fec1af87 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -376,11 +376,11 
@@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]: def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: """ - Replaces + Replaces MatchedScaleMaskSoftmax with exportable softmax layer Args: - n: + n: module to replace Returns: - + exportable module """ mod = ExportableMatchedScaleMaskSoftmax(n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale) From c8f9eae45cec8068301707301bcd3ce0eb77135e Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Mon, 14 Nov 2022 15:53:23 -0800 Subject: [PATCH 06/10] reverting changes on FloatCast Signed-off-by: David Mosallanezhad --- nemo/utils/export_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 03f4fec1af87..f3e0a826d977 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -68,11 +68,12 @@ def __init__(self, mod): self.mod = mod def forward(self, x): - with torch.autocast(device_type='cuda', dtype=x.dtype): + if torch.is_autocast_enabled(): + ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) + else: ret = self.mod.forward(x) return ret - class LinearWithBiasSkip(nn.Module): def __init__(self, weight, bias, skip_bias_add): super(LinearWithBiasSkip, self).__init__() @@ -205,7 +206,7 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 return onnx_session_opt = onnxruntime.SessionOptions() - onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC + onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess = onnxruntime.InferenceSession( onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider'] ) From 18417e4b216bbdace8d55060f07c425b3f3e0d93 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 14 Nov 2022 18:14:15 -0800 Subject: [PATCH 07/10] Cherry-picked changes from megatron-norm Signed-off-by: Boris Fomitchev --- .../machine_translation/megatron_nmt_model.py | 10 ++- nemo/utils/__init__.py | 7 ++ nemo/utils/cast_utils.py | 75 +++++++++++++++++++ nemo/utils/export_utils.py | 55 +++----------- scripts/export.py | 25 +++---- 5 files changed, 115 insertions(+), 57 deletions(-) create mode 100644 nemo/utils/cast_utils.py diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 6eb3a62de195..a44b560fbb2d 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -42,6 +42,7 @@ from nemo.collections.nlp.modules.common.megatron.megatron_export import DecEmb, EncEmb, TokensHeadEmb from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.core.classes import Exportable from nemo.utils import AppState, logging, timers try: @@ -56,7 +57,7 @@ __all__ = ["MegatronNMTModel"] -class MegatronNMTModel(MegatronLMEncoderDecoderModel): +class MegatronNMTModel(MegatronLMEncoderDecoderModel, Exportable): """ Megatron NMT training """ @@ -752,3 +753,10 @@ def decoder(self): @property def log_softmax(self): return TokensHeadEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.tokens_head, self.device) + + @property + def input_module(self): + return self.encoder + + def list_export_subnets(self): + return ['encoder', 'log_softmax', 'decoder'] 
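With input_module and list_export_subnets in place, MegatronNMTModel plugs into NeMo's Exportable flow as three separately exported pieces. The function below is only an illustrative driver, not scripts/export.py: it assumes model is an already-restored MegatronNMTModel on a CUDA device and that the EncEmb, DecEmb and TokensHeadEmb wrappers above are themselves Exportable, so each exposes input_example() and export().

def export_nmt_subnets(model, out_dir="."):
    # Mirrors the freeze/eval preparation used in scripts/export.py.
    model.to("cuda").freeze()
    model.eval()
    for name in model.list_export_subnets():   # ['encoder', 'log_softmax', 'decoder']
        subnet = getattr(model, name)           # properties added in this patch
        example = subnet.input_example()        # shapes defined in megatron_export.py
        subnet.export(f"{out_dir}/{name}.onnx", input_example=example)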
diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py index fd3cbd2699d1..0b2748e580af 100644 --- a/nemo/utils/__init__.py +++ b/nemo/utils/__init__.py @@ -14,6 +14,13 @@ from nemo.utils.app_state import AppState +from nemo.utils.cast_utils import ( + CastToFloat, + avoid_bfloat16_autocast_context, + avoid_float16_autocast_context, + cast_all, + cast_tensor, +) from nemo.utils.nemo_logging import Logger as _Logger from nemo.utils.nemo_logging import LogMode as logging_mode diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py new file mode 100644 index 000000000000..9eb064936ea5 --- /dev/null +++ b/nemo/utils/cast_utils.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext + +import torch + + +def avoid_bfloat16_autocast_context(): + """ + If the current autocast context is bfloat16, + cast it to float32 + """ + + if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16: + return torch.cuda.amp.autocast(dtype=torch.float32) + else: + return nullcontext() + + +def avoid_float16_autocast_context(): + """ + If the current autocast context is float16, cast it to bfloat16 + if available (unless we're in jit) or float32 + """ + + if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return torch.cuda.amp.autocast(dtype=torch.float32) + + if torch.cuda.is_bf16_supported(): + return torch.cuda.amp.autocast(dtype=torch.bfloat16) + else: + return torch.cuda.amp.autocast(dtype=torch.float32) + else: + return nullcontext() + + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + +class CastToFloat(torch.nn.Module): + def __init__(self, mod): + super(CastToFloat, self).__init__() + self.mod = mod + + def forward(self, x): + with avoid_float16_autocast_context(): + ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) + return ret diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index f3e0a826d977..b4992cf88896 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -13,6 +13,7 @@ # limitations under the License. 
import os +from contextlib import nullcontext from enum import Enum from typing import Callable, Dict, Optional, Type @@ -21,7 +22,7 @@ import torch.nn as nn import torch.nn.functional as F -from nemo.utils import logging +from nemo.utils import CastToFloat, logging try: import onnxruntime @@ -45,35 +46,6 @@ class ExportFormat(Enum): } -def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): - return x.to(dtype=to_dtype) if x.dtype == from_dtype else x - - -def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): - if isinstance(x, torch.Tensor): - return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) - else: - if isinstance(x, dict): - new_dict = {} - for k in x.keys(): - new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) - return new_dict - elif isinstance(x, tuple): - return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) - - -class CastToFloat(nn.Module): - def __init__(self, mod): - super(CastToFloat, self).__init__() - self.mod = mod - - def forward(self, x): - if torch.is_autocast_enabled(): - ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) - else: - ret = self.mod.forward(x) - return ret - class LinearWithBiasSkip(nn.Module): def __init__(self, weight, bias, skip_bias_add): super(LinearWithBiasSkip, self).__init__() @@ -126,7 +98,7 @@ def forward(self, input, mask): def get_export_format(filename: str): _, ext = os.path.splitext(filename) try: - return _EXT_DICT[ext] + return _EXT_DICT[ext.lower()] except KeyError: raise ValueError(f"Export file {filename} extension does not correspond to any export format!") @@ -204,7 +176,7 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n") onnx.checker.check_model(onnx_model, full_check=True) return - + del onnx_model onnx_session_opt = onnxruntime.SessionOptions() onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess = onnxruntime.InferenceSession( @@ -262,7 +234,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): from apex.transformer.tensor_parallel.layers import RowParallelLinear, ColumnParallelLinear from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]: + def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: """ Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export. 
Args: @@ -270,19 +242,16 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]: Returns: Equivalent LayerNorm module """ - if ( - not isinstance(n, FusedLayerNorm) - and not isinstance(n, FastLayerNorm) - and not isinstance(n, MixedFusedLayerNorm) - ): - return None - dev = next(n.parameters()).device + p = next(n.parameters()) if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm): - mod = nn.LayerNorm(n.normalized_shape, eps=n.eps, elementwise_affine=n.elementwise_affine,).to(dev) + shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine elif isinstance(n, FastLayerNorm): - mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True,).to(dev) + shape, eps, affine = n.weight.shape, n.epsilon, True + else: + return None + mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype) n_state = n.state_dict() mod.load_state_dict(n_state) return mod @@ -299,7 +268,7 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]: raise ValueError("This function can only change the RowParallelLinear module.") dev = next(n.parameters()).device - mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev) + mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev) n_state = n.state_dict() mod.load_state_dict(n_state) diff --git a/scripts/export.py b/scripts/export.py index ca9c4fed64aa..b3d6317e936c 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -133,22 +133,21 @@ def nemo_export(argv): logging.info("Caching support is enabled.") autocast = nullcontext - model.to(device=args.device).freeze() - model.eval() - with torch.inference_mode(): - input_example = model.input_module.input_example(**in_args) - if check_trace: - check_trace = [input_example] - if max_dim: - in_args["max_dim"] = (max_dim + 1) // 2 - in_args["max_batch"] = (max_batch + 1) // 2 - input_example2 = model.input_module.input_example(**in_args) - check_trace.append(input_example2) - if args.autocast: autocast = torch.cuda.amp.autocast try: - with autocast(), torch.inference_mode(): + with autocast(), torch.no_grad(), torch.inference_mode(): + model.to(device=args.device).freeze() + model.eval() + input_example = None + if check_trace and len(in_args) > 0: + input_example = model.input_module.input_example(**in_args) + check_trace = [input_example] + for key, arg in in_args: + in_args[key] = (arg + 1) // 2 + input_example2 = model.input_module.input_example(**in_args) + check_trace.append(input_example2) + _, descriptions = model.export( out, input_example=input_example, From 6116f10c1e697e61dbdafc963f3179b9a59de441 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 15 Nov 2022 09:18:48 -0800 Subject: [PATCH 08/10] updated asr_model import to cast_utils Signed-off-by: David Mosallanezhad --- nemo/collections/asr/models/asr_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 23f5e60d9489..7c8a7ab8aa78 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -21,7 +21,7 @@ from nemo.core.classes.exportable import Exportable from nemo.core.classes.mixins import AccessMixin from nemo.utils import logging, model_utils -from nemo.utils.export_utils import cast_all +from nemo.utils.cast_utils import cast_all __all__ = ['ASRModel'] From dbbdab2e0bea608ad65e7d49f9b3ea87070b3391 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 15 
Nov 2022 09:25:19 -0800 Subject: [PATCH 09/10] updated del onnx_model place Signed-off-by: David Mosallanezhad --- nemo/utils/export_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index b4992cf88896..9eeeece556db 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -176,12 +176,12 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n") onnx.checker.check_model(onnx_model, full_check=True) return - del onnx_model onnx_session_opt = onnxruntime.SessionOptions() onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess = onnxruntime.InferenceSession( onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider'] ) + del onnx_model all_good = True for input_example in input_examples: input_list, input_dict = parse_input_example(input_example) From e334703a0561400336bac2f9a11309c6b25d5897 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 15 Nov 2022 09:28:26 -0800 Subject: [PATCH 10/10] changed ort optimization to basic -> temp fix Signed-off-by: David Mosallanezhad --- nemo/utils/export_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 9eeeece556db..a5c4e5b3d24f 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -177,7 +177,7 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 onnx.checker.check_model(onnx_model, full_check=True) return onnx_session_opt = onnxruntime.SessionOptions() - onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC sess = onnxruntime.InferenceSession( onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider'] )
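The last patch lowers ONNX Runtime's graph optimization level from ORT_ENABLE_ALL to ORT_ENABLE_BASIC during verification as a temporary fix. A minimal sketch of the verification-side setup follows; the model path "encoder.onnx" and the toy input feed are hypothetical placeholders, with the input names taken from EncEmb's forward signature.

import numpy as np
import onnxruntime

# Build a session with basic (rather than full) graph optimization, matching the patch.
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess = onnxruntime.InferenceSession(
    "encoder.onnx", sess_options=sess_options, providers=["CUDAExecutionProvider"]
)

# Feed a toy batch through the exported encoder and collect all outputs.
input_ids = np.random.randint(0, 30000, size=(1, 6), dtype=np.int64)
encoder_mask = np.ones((1, 6), dtype=np.int64)
outputs = sess.run(None, {"input_ids": input_ids, "encoder_mask": encoder_mask})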