Merged
46 changes: 8 additions & 38 deletions nemo/utils/export_utils.py
@@ -59,42 +59,6 @@ def forward(self, x):
        return F.linear(x, self.weight, self.bias), None


class ExportableMatchedScaleMaskSoftmax(nn.Module):
    def __init__(self, mod):
        super(ExportableMatchedScaleMaskSoftmax, self).__init__()
        self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale)

    def init_module(
        self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale,
    ):
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        self.softmax_in_fp32 = softmax_in_fp32
        self.mask_func = mask_func
        self.scale = scale

        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

    def forward(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)
        all_k_masked = mask.all(axis=-1)
        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
        probs = probs * zero_attention_mask

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()
        return probs


def get_export_format(filename: str):
    _, ext = os.path.splitext(filename)
    try:
@@ -367,7 +331,13 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
        exportable module
    """

    mod = ExportableMatchedScaleMaskSoftmax(n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale)
    # including the import here to avoid circular imports
    from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax

    # disabling fusion for the MatchedScaleMaskSoftmax
    mod = MatchedScaleMaskSoftmax(
        n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
    )

    return mod
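Note: the hand-written ExportableMatchedScaleMaskSoftmax wrapper above is no longer needed; the exporter now re-creates NeMo's own MatchedScaleMaskSoftmax, passing an extra False argument that, per the in-line comment, disables the fused softmax kernel so the plain PyTorch code path is traced during export. To illustrate how an expansion function like replace_MatchedScaleMaskSoftmax is typically applied, here is a minimal sketch of a module-swap pass; the helper name swap_modules_for_export and its traversal logic are assumptions for illustration, not the replacement code in this file.

```python
from typing import Callable, Dict, Optional

import torch.nn as nn


def swap_modules_for_export(
    model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]]
) -> nn.Module:
    """Hypothetical sketch: replace submodules whose class name has a registered expansion."""
    for name, child in model.named_children():
        swap_fn = expansions.get(type(child).__name__)
        new_child = swap_fn(child) if swap_fn is not None else None
        if new_child is not None:
            # e.g. a MatchedScaleMaskSoftmax child becomes the unfused copy built above
            setattr(model, name, new_child)
        else:
            # keep recursing into children that were not replaced
            swap_modules_for_export(child, expansions)
    return model
```

Run over a Megatron attention stack before torch.onnx.export, a pass like this would replace every MatchedScaleMaskSoftmax instance with the unfused module returned by replace_MatchedScaleMaskSoftmax.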

@@ -440,7 +410,7 @@ def script_module(m: nn.Module):
"BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
"BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
"LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
"MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
"MatchedScaleMaskSoftmax": wrap_module(None, replace_MatchedScaleMaskSoftmax),
}

script_replacements = {
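A note on the table change: the "MatchedScaleMaskSoftmax" entry now passes None as the base type and the replace_MatchedScaleMaskSoftmax function as the destination. That only works if wrap_module ignores the base type and simply calls the destination with the original module, so a replacement function and a wrapper-class constructor are interchangeable. Below is a minimal sketch of such a factory, assuming that behaviour; it is not necessarily the exact implementation in export_utils.py.

```python
from typing import Callable, Optional, Type

import torch.nn as nn


def wrap_module(
    BaseT: Optional[Type[nn.Module]],
    DestT: Callable[[nn.Module], Optional[nn.Module]],
) -> Callable[[nn.Module], Optional[nn.Module]]:
    """Assumed sketch: build a swap function that constructs the destination from the original module."""

    def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
        # DestT may be an nn.Module subclass (called as a constructor) or a plain
        # function such as replace_MatchedScaleMaskSoftmax; BaseT can then be None.
        return DestT(mod)

    return expansion_fn
```

Under that assumption, wrap_module(None, replace_MatchedScaleMaskSoftmax)(softmax_module) returns the unfused MatchedScaleMaskSoftmax constructed earlier in this diff.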