diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index 652a3e6f4e3a..98dba5423009 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -1041,7 +1041,10 @@ def __init__(
                 reduce_amax=reduce_amax,
             )
 
-        self.is_first_microbatch = True
+        self.is_first_train_microbatch = (
+            True  # Is the current micro-batch the first micro-batch in a global-batch in training
+        )
+        self.is_prev_microbatch_training = True  # Is the previous micro-batch in training mode
         self.microbatch_count = 0  # transformer engine forward needs to know if it is working on the first microbatch
         self.checkpoint_core_attention = (
             activations_checkpoint_granularity == 'selective'
@@ -1247,6 +1250,12 @@ def custom_forward(*inputs):
                     attention_mask = inputs[1]
                     encoder_output = inputs[2]
                     enc_dec_attn_mask = inputs[3]
+                    # Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch
+                    # in training, (2) the first micro-batch in each validation and test routine.
+                    # The caching happens in TransformerEngine when passing `is_first_microbatch=True`.
+                    is_first_microbatch = (self.is_first_train_microbatch and self.training) or (
+                        self.is_prev_microbatch_training and not self.training
+                    )
                     for index in range(start, end):
                         layer = self._get_layer(index)
                         hidden_states = layer(
@@ -1255,7 +1264,7 @@ def custom_forward(*inputs):
                             encoder_output=encoder_output,
                             enc_dec_attn_mask=enc_dec_attn_mask,
                             inference_params=None,
-                            is_first_microbatch=self.is_first_microbatch,
+                            is_first_microbatch=is_first_microbatch,
                             checkpoint_core_attention=False,
                         )
 
@@ -1531,6 +1540,12 @@ def forward(
                 else:
                     checkpoint_core_attention = False
 
+                # Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch
+                # in training, (2) the first micro-batch in each validation and test routine.
+                # The caching happens in TransformerEngine when passing `is_first_microbatch=True`.
+                is_first_microbatch = (self.is_first_train_microbatch and self.training) or (
+                    self.is_prev_microbatch_training and not self.training
+                )
                 if self.transformer_engine:
                     hidden_states = layer(
                         hidden_states,
@@ -1538,7 +1553,7 @@ def forward(
                         encoder_output=encoder_output,
                         enc_dec_attn_mask=enc_dec_attn_mask,
                         inference_params=self.inference_params,
-                        is_first_microbatch=self.is_first_microbatch,
+                        is_first_microbatch=is_first_microbatch,
                         checkpoint_core_attention=checkpoint_core_attention,
                     )
                 else:
@@ -1565,9 +1580,10 @@ def forward(
                 self.microbatch_count += 1
                 if self.microbatch_count % num_micro_batches == 0:
                     self.microbatch_count = 0
-                    self.is_first_microbatch = True
+                    self.is_first_train_microbatch = True
                 else:
-                    self.is_first_microbatch = False
+                    self.is_first_train_microbatch = False
+            self.is_prev_microbatch_training = self.training
 
         output = hidden_states
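
For context, the sketch below isolates the micro-batch bookkeeping this diff introduces. `MicrobatchTracker` is a hypothetical helper written only for illustration (it is not part of NeMo), and it assumes the counter update runs only on training steps, which is gated elsewhere in `forward`; the decision expression itself mirrors the `is_first_microbatch` logic added in the diff.

```python
# Illustrative sketch of the is_first_microbatch bookkeeping from this diff.
# MicrobatchTracker is a hypothetical stand-in for the state that
# ParallelTransformer keeps on itself; it is not a NeMo class.
class MicrobatchTracker:
    def __init__(self):
        self.is_first_train_microbatch = True    # first micro-batch of a training global-batch?
        self.is_prev_microbatch_training = True  # was the previous micro-batch a training step?
        self.microbatch_count = 0

    def is_first_microbatch(self, training: bool) -> bool:
        # TransformerEngine re-caches the FP8 weight and its transpose when this is True:
        # (1) the first micro-batch of each training global-batch,
        # (2) the first micro-batch after switching from training to validation/test.
        return (self.is_first_train_microbatch and training) or (
            self.is_prev_microbatch_training and not training
        )

    def update(self, training: bool, num_micro_batches: int):
        # Assumption: only training micro-batches advance the global-batch counter
        # (in NeMo this gating happens in forward, outside the diff's context lines).
        if training:
            self.microbatch_count += 1
            if self.microbatch_count % num_micro_batches == 0:
                self.microbatch_count = 0
                self.is_first_train_microbatch = True
            else:
                self.is_first_train_microbatch = False
        self.is_prev_microbatch_training = training


# Example: a global-batch of 2 training micro-batches, two validation
# micro-batches, then training resumes.
tracker = MicrobatchTracker()
for training in (True, True, False, False, True):
    print(training, tracker.is_first_microbatch(training))
    tracker.update(training, num_micro_batches=2)
```

Run as-is, the example prints `True` only for the first training micro-batch of the global-batch, the first micro-batch after switching to evaluation, and the first micro-batch after returning to training, which matches the caching condition described in the diff's comments.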