nemo/collections/nlp/modules/common/megatron/transformer.py (26 changes: 21 additions & 5 deletions)
@@ -1041,7 +1041,10 @@ def __init__(
reduce_amax=reduce_amax,
)

self.is_first_microbatch = True
self.is_first_train_microbatch = (
True # Is the current micro-batch the first micro-batch in a global-batch in training
)
self.is_prev_microbatch_training = True # Is the previous micro-batch in training mode
self.microbatch_count = 0 # transformer engine forward needs to know if it is working on the first microbatch
self.checkpoint_core_attention = (
activations_checkpoint_granularity == 'selective'
@@ -1247,6 +1250,12 @@ def custom_forward(*inputs):
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
# Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch
# in training, (2) the first micro-batch in each validation and test routine.
# The caching happens in TransformerEngine when passing `is_first_microbatch=True`.
is_first_microbatch = (self.is_first_train_microbatch and self.training) or (
self.is_prev_microbatch_training and not self.training
)
for index in range(start, end):
layer = self._get_layer(index)
hidden_states = layer(
@@ -1255,7 +1264,7 @@ def custom_forward(*inputs):
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=None,
is_first_microbatch=self.is_first_microbatch,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=False,
)

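The condition introduced in `custom_forward` derives the per-forward `is_first_microbatch` value from the two new flags plus the module's training mode, so TransformerEngine rebuilds its FP8 weight/transpose cache only when needed. A minimal sketch of that expression and the cases it is meant to cover (the helper name `fp8_cache_refresh` is illustrative, not part of NeMo or TransformerEngine):

```python
def fp8_cache_refresh(is_first_train_microbatch, is_prev_microbatch_training, training):
    """Mirror of the expression above: True when TransformerEngine is told to
    re-cache the FP8 weights and their transposes for this micro-batch."""
    return (is_first_train_microbatch and training) or (
        is_prev_microbatch_training and not training
    )


# (1) First training micro-batch of a global-batch: refresh the cache.
assert fp8_cache_refresh(True, True, training=True)
# Later training micro-batches of the same global-batch: reuse the cache.
assert not fp8_cache_refresh(False, True, training=True)
# (2) First validation/test micro-batch right after training: refresh again,
# since the optimizer may have updated the weights.
assert fp8_cache_refresh(True, True, training=False)
# Subsequent validation/test micro-batches: reuse the cache.
assert not fp8_cache_refresh(False, False, training=False)
```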
@@ -1531,14 +1540,20 @@ def forward(
else:
checkpoint_core_attention = False

# Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch
# in training, (2) the first micro-batch in each validation and test routine.
# The caching happens in TransformerEngine when passing `is_first_microbatch=True`.
is_first_microbatch = (self.is_first_train_microbatch and self.training) or (
self.is_prev_microbatch_training and not self.training
)
if self.transformer_engine:
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=self.inference_params,
is_first_microbatch=self.is_first_microbatch,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
)
else:
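In the non-checkpointed path the same flag is passed straight into the per-layer call. A sketch of that call pattern with a stub standing in for the real TransformerEngine layer; only the keyword arguments used in the diff are modeled, and the stub merely records whether it was asked to refresh its FP8 cache:

```python
class _StubTELayer:
    """Stand-in for a TransformerEngine transformer layer; it only records how
    often it was asked to rebuild its FP8 weight/transpose cache."""

    def __init__(self):
        self.cache_refreshes = 0

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        inference_params=None,
        is_first_microbatch=None,
        checkpoint_core_attention=False,
    ):
        if is_first_microbatch:
            self.cache_refreshes += 1  # TE would re-cache FP8 weights here
        return hidden_states


layers = [_StubTELayer() for _ in range(4)]


def run_layers(hidden_states, attention_mask, is_first_microbatch):
    # The flag is computed once per forward pass and the same value is passed
    # to every layer, matching the loop in the diff.
    for layer in layers:
        hidden_states = layer(
            hidden_states,
            attention_mask,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=False,
        )
    return hidden_states
```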
@@ -1565,9 +1580,10 @@ def forward(
self.microbatch_count += 1
if self.microbatch_count % num_micro_batches == 0:
self.microbatch_count = 0
self.is_first_microbatch = True
self.is_first_train_microbatch = True
else:
self.is_first_microbatch = False
self.is_first_train_microbatch = False
self.is_prev_microbatch_training = self.training

output = hidden_states

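The counter bookkeeping at the end of `forward` is what drives both flags. A toy walk-through under the assumption of one training global-batch of four micro-batches followed by two validation micro-batches (a self-contained re-implementation for illustration, not NeMo code):

```python
class _ToyCacheTracker:
    """Self-contained re-implementation of the bookkeeping above."""

    def __init__(self, num_micro_batches):
        self.num_micro_batches = num_micro_batches
        self.is_first_train_microbatch = True
        self.is_prev_microbatch_training = True
        self.microbatch_count = 0

    def step(self, training):
        # Refresh the FP8 weight cache on the first training micro-batch of a
        # global-batch, or on the first eval micro-batch that follows training.
        refresh = (self.is_first_train_microbatch and training) or (
            self.is_prev_microbatch_training and not training
        )
        self.microbatch_count += 1
        if self.microbatch_count % self.num_micro_batches == 0:
            self.microbatch_count = 0
            self.is_first_train_microbatch = True
        else:
            self.is_first_train_microbatch = False
        self.is_prev_microbatch_training = training
        return refresh


tracker = _ToyCacheTracker(num_micro_batches=4)
print([tracker.step(training=True) for _ in range(4)])   # [True, False, False, False]
print([tracker.step(training=False) for _ in range(2)])  # [True, False]
```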