From 75bf9809750a0bb22ad4e7255f0ddc53fcabb9a2 Mon Sep 17 00:00:00 2001 From: David Mosallanezhad Date: Tue, 22 Nov 2022 15:03:53 -0800 Subject: [PATCH 1/2] updated export_utils Signed-off-by: David Mosallanezhad --- nemo/utils/export_utils.py | 47 +++++++------------------------------- 1 file changed, 8 insertions(+), 39 deletions(-) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 197d3b478167..e1db0e80a525 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -58,43 +58,6 @@ def forward(self, x): return F.linear(x, self.weight), self.bias return F.linear(x, self.weight, self.bias), None - -class ExportableMatchedScaleMaskSoftmax(nn.Module): - def __init__(self, mod): - super(ExportableMatchedScaleMaskSoftmax, self).__init__() - self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale) - - def init_module( - self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale, - ): - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - self.softmax_in_fp32 = softmax_in_fp32 - self.mask_func = mask_func - self.scale = scale - - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - - def forward(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - all_k_masked = mask.all(axis=-1) - zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None] - probs = probs * zero_attention_mask - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - return probs - - def get_export_format(filename: str): _, ext = os.path.splitext(filename) try: @@ -367,7 +330,13 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: exportable module """ - mod = ExportableMatchedScaleMaskSoftmax(n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale) + # including the import here to avoid circular imports + from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax + + # disabling fusion for the MatchedScaleMaskSoftmax + mod = MatchedScaleMaskSoftmax( + n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale + ) return mod @@ -440,7 +409,7 @@ def script_module(m: nn.Module): "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat), - "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax), + "MatchedScaleMaskSoftmax": wrap_module(None, replace_MatchedScaleMaskSoftmax), } script_replacements = { From bad20d4befc13a357e6e00d53453ab09da5f07f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Nov 2022 23:05:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/utils/export_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index e1db0e80a525..e4fda73c181d 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -58,6 +58,7 @@ def forward(self, x): return F.linear(x, self.weight), self.bias return F.linear(x, self.weight, self.bias), None + def get_export_format(filename: str): _, ext = os.path.splitext(filename) try: