2 changes: 1 addition & 1 deletion nemo/collections/asr/models/asr_model.py
@@ -21,7 +21,7 @@
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin
from nemo.utils import logging, model_utils
from nemo.utils.export_utils import cast_all
from nemo.utils.cast_utils import cast_all

__all__ = ['ASRModel']

nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
@@ -42,6 +42,7 @@
from nemo.collections.nlp.modules.common.megatron.megatron_export import DecEmb, EncEmb, TokensHeadEmb
from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.core.classes import Exportable
from nemo.utils import AppState, logging, timers

try:
@@ -56,7 +57,7 @@
__all__ = ["MegatronNMTModel"]


class MegatronNMTModel(MegatronLMEncoderDecoderModel):
class MegatronNMTModel(MegatronLMEncoderDecoderModel, Exportable):
"""
Megatron NMT training
"""
@@ -750,5 +751,12 @@ def decoder(self):
return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device)

@property
def classifier(self):
def log_softmax(self):
return TokensHeadEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.tokens_head, self.device)

@property
def input_module(self):
return self.encoder

def list_export_subnets(self):
return ['encoder', 'log_softmax', 'decoder']
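For context, a minimal sketch of how the three subnets declared by `list_export_subnets()` could be exported one at a time. This is an illustration, not part of the PR: `nmt_model` is assumed to be an already-restored `MegatronNMTModel`, and the ONNX file names are made up. Each subnet property resolves to one of the wrappers in `megatron_export.py` below, which supply their own `input_example()` and `freeze()`.

```python
import torch

for name in nmt_model.list_export_subnets():      # ['encoder', 'log_softmax', 'decoder']
    subnet = getattr(nmt_model, name)             # EncEmb / TokensHeadEmb / DecEmb wrapper
    subnet.freeze()                               # freeze() disables gradients before export
    example = subnet.input_example()              # example inputs defined by each wrapper
    torch.onnx.export(subnet, tuple(example), f"nmt_{name}.onnx")
```

NeMo's `Exportable.export()` is the intended entry point; the loop above only illustrates what the subnet list is used for.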
50 changes: 33 additions & 17 deletions nemo/collections/nlp/modules/common/megatron/megatron_export.py
@@ -49,21 +49,23 @@ def forward(self, dec_output):
if isinstance(dec_output, list):
dec_output = dec_output[0]

dec_output = torch.permute(dec_output, (1, 0, 2))

if self.tokens_head_bias is not None:
return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight, self.tokens_head_bias)
return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight)

def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
def input_example(self, max_batch=1, max_dim=768, seq_len=6):
return [
torch.randint(low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32)
torch.randint(low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32)
]

def freeze(self):
for param in self.parameters():
param.requires_grad = False

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"hidden_states": NeuralType(('T', 'B', 'D'), ChannelType()),
"hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()),
}

@property
@@ -107,18 +109,28 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec
# dec_input, dec_attn_mask, enc_output, enc_attn_mask | dec_input, dec_attn_mask, enc_output, enc_attn_mask
_ = dec_mems

return self.decoder(dec_input, decoder_mask, encoder_embeddings, encoder_mask).float()
return (
self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask)
.float()
.permute(1, 0, 2)
)

def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
def freeze(self):
for param in self.parameters():
param.requires_grad = False

def input_example(self, max_batch=1, max_dim=768, seq_len=6):
enc_output = torch.randint(
low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32
low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32
)
enc_attn_mask = torch.tensor([[1 for _ in range(seq_len)]]).to(self.device)

dec_len = random.randint(10, 128)
dec_input = torch.randint(low=0, high=1000, size=(max_batch, dec_len), device=self.device)
dec_attn_mask = torch.tensor([[1 for _ in range(dec_len)]]).to(self.device)
decoder_mems = torch.zeros([8, 6, 1024], dtype=torch.float32).to(self.device)

# constant decoder_mems as placeholder for now
decoder_mems = torch.zeros([8, 6, max_dim], dtype=torch.float32).to(self.device)

# input_ids, decoder_mask, encoder_mask, encoder_embeddings
return (dec_input, dec_attn_mask, enc_attn_mask, enc_output, decoder_mems)
@@ -128,14 +140,14 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"input_ids": NeuralType(('B', 'T', 'D'), ChannelType()),
"decoder_mask": NeuralType(('B', 'T'), MaskType()),
"encoder_mask": NeuralType(('T', 'B', 'D'), ChannelType()),
"encoder_mask": NeuralType(('B', 'T', 'D'), ChannelType()),
"encoder_embeddings": NeuralType(('B', 'T'), MaskType()),
"decoder_mems": NeuralType(('T', 'B', 'D'), ChannelType()),
"decoder_mems": NeuralType(('B', 'T', 'D'), ChannelType()),
}

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

@property
def input_names(self) -> List[str]:
@@ -172,15 +184,19 @@ def forward(self, input_ids, encoder_mask):
enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None)

# pass input through the encoder
return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32)
return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2)

def input_example(self):
def input_example(self, max_batch=1, max_dim=30000, seq_len=6):
seq_len = random.randint(0, 128)
return (
torch.randint(0, 30000, (1, seq_len)).to(self.device),
torch.ones((1, seq_len), dtype=int).to(self.device),
torch.randint(0, max_dim, (max_batch, seq_len)).to(self.device),
torch.ones((max_batch, seq_len), dtype=int).to(self.device),
)

def freeze(self):
for param in self.parameters():
param.requires_grad = False

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
@@ -190,7 +206,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

@property
def input_names(self) -> List[str]:
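The common thread in the changes above is that the exported interfaces of `TokensHeadEmb`, `DecEmb`, and `EncEmb` switch from time-major `(T, B, D)` tensors to batch-major `(B, T, D)`, with `permute(1, 0, 2)` bridging to Megatron's internal time-major layout. A plain-PyTorch illustration (not from the PR):

```python
import torch

t, b, d = 6, 1, 768                          # seq_len, batch, hidden size
time_major = torch.randn(t, b, d)            # Megatron-internal layout: (T, B, D)
batch_major = time_major.permute(1, 0, 2)    # exported interface layout: (B, T, D)

assert tuple(batch_major.shape) == (b, t, d)
# Swapping the first two dims twice restores the original layout.
assert torch.equal(batch_major.permute(1, 0, 2), time_major)
```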
7 changes: 7 additions & 0 deletions nemo/utils/__init__.py
@@ -14,6 +14,13 @@


from nemo.utils.app_state import AppState
from nemo.utils.cast_utils import (
CastToFloat,
avoid_bfloat16_autocast_context,
avoid_float16_autocast_context,
cast_all,
cast_tensor,
)
from nemo.utils.nemo_logging import Logger as _Logger
from nemo.utils.nemo_logging import LogMode as logging_mode

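With these re-exports, the new cast helpers are importable straight from the `nemo.utils` namespace, a minimal usage sketch:

```python
# Both import paths resolve to the same objects after this change.
from nemo.utils import CastToFloat, cast_all, cast_tensor
from nemo.utils.cast_utils import avoid_bfloat16_autocast_context, avoid_float16_autocast_context
```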
75 changes: 75 additions & 0 deletions nemo/utils/cast_utils.py
@@ -0,0 +1,75 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext

import torch


def avoid_bfloat16_autocast_context():
"""
If the current autocast context is bfloat16,
cast it to float32
"""

if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
return torch.cuda.amp.autocast(dtype=torch.float32)
else:
return nullcontext()


def avoid_float16_autocast_context():
"""
If the current autocast context is float16, cast it to bfloat16
if available (unless we're in jit) or float32
"""

if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
if torch.jit.is_scripting() or torch.jit.is_tracing():
return torch.cuda.amp.autocast(dtype=torch.float32)

if torch.cuda.is_bf16_supported():
return torch.cuda.amp.autocast(dtype=torch.bfloat16)
else:
return torch.cuda.amp.autocast(dtype=torch.float32)
else:
return nullcontext()
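A usage sketch for the two context managers above (illustration only, not part of this file); it assumes a CUDA device and an active float16 autocast region:

```python
import torch
from nemo.utils.cast_utils import avoid_float16_autocast_context

x = torch.randn(8, 8, device="cuda")
with torch.cuda.amp.autocast(dtype=torch.float16):
    y_half = x @ x                      # runs under float16 autocast
    with avoid_float16_autocast_context():
        y_safe = x @ x                  # re-runs under bfloat16 (or float32) autocast instead
```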


def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
return x.to(dtype=to_dtype) if x.dtype == from_dtype else x


def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
if isinstance(x, torch.Tensor):
return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
else:
if isinstance(x, dict):
new_dict = {}
for k in x.keys():
new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
return new_dict
elif isinstance(x, tuple):
return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
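For illustration (not part of the file): `cast_all` recurses through dicts and tuples and only converts tensors whose dtype matches `from_dtype`:

```python
import torch
from nemo.utils.cast_utils import cast_all

batch = {
    "logits": torch.zeros(2, 4, dtype=torch.float16),
    "extras": (torch.zeros(3, dtype=torch.float16), torch.arange(3)),   # int64 tensor inside
}
out = cast_all(batch, from_dtype=torch.float16, to_dtype=torch.float32)

assert out["logits"].dtype == torch.float32
assert out["extras"][0].dtype == torch.float32
assert out["extras"][1].dtype == torch.int64    # non-matching dtypes pass through unchanged
```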


class CastToFloat(torch.nn.Module):
def __init__(self, mod):
super(CastToFloat, self).__init__()
self.mod = mod

def forward(self, x):
with avoid_float16_autocast_context():
ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
return ret
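A sketch of the wrapper in use, with an arbitrary example module (illustration only): `CastToFloat` runs the wrapped module in float32 and casts the result back to the input dtype.

```python
import torch
from nemo.utils.cast_utils import CastToFloat

wrapped = CastToFloat(torch.nn.BatchNorm1d(16))   # BatchNorm1d is just an example module
x = torch.randn(8, 16, dtype=torch.float16)
y = wrapped(x)                                    # computed in float32 internally
assert y.dtype == torch.float16                   # output cast back to the input dtype
```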