From 063d5263961d5f981f3ffc9c8bce15f78982272b Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Tue, 19 Sep 2023 22:21:11 -0700 Subject: [PATCH 1/2] Cache weight and transpose only in the first batch in all training, val, and test runs Signed-off-by: Sangkug Lym --- .../modules/common/megatron/transformer.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 652a3e6f4e3a..1cb0013278a1 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -1041,7 +1041,8 @@ def __init__( reduce_amax=reduce_amax, ) - self.is_first_microbatch = True + self.is_first_train_microbatch = True # Is the current micro-batch the first micro-batch in a global-batch in training + self.is_prev_microbatch_training = True # Is the previous micro-batch in training mode self.microbatch_count = 0 # transformer engine forward needs to know if it is working on the first microbatch self.checkpoint_core_attention = ( activations_checkpoint_granularity == 'selective' @@ -1247,6 +1248,11 @@ def custom_forward(*inputs): attention_mask = inputs[1] encoder_output = inputs[2] enc_dec_attn_mask = inputs[3] + # Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch + # in training, (2) the first micro-batch in each validation and test routine. + # The caching happens in TransformerEngine when passing `is_first_microbatch=True`. + is_first_microbatch = (self.is_first_train_microbatch and self.training) or ( + self.is_prev_microbatch_training and not self.training) for index in range(start, end): layer = self._get_layer(index) hidden_states = layer( @@ -1255,7 +1261,7 @@ def custom_forward(*inputs): encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask, inference_params=None, - is_first_microbatch=self.is_first_microbatch, + is_first_microbatch=is_first_microbatch, checkpoint_core_attention=False, ) @@ -1531,6 +1537,11 @@ def forward( else: checkpoint_core_attention = False + # Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch + # in training, (2) the first micro-batch in each validation and test routine. + # The caching happens in TransformerEngine when passing `is_first_microbatch=True`. 
+ is_first_microbatch = (self.is_first_train_microbatch and self.training) or ( + self.is_prev_microbatch_training and not self.training) if self.transformer_engine: hidden_states = layer( hidden_states, @@ -1538,7 +1549,7 @@ def forward( encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask, inference_params=self.inference_params, - is_first_microbatch=self.is_first_microbatch, + is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, ) else: @@ -1565,9 +1576,10 @@ def forward( self.microbatch_count += 1 if self.microbatch_count % num_micro_batches == 0: self.microbatch_count = 0 - self.is_first_microbatch = True + self.is_first_train_microbatch = True else: - self.is_first_microbatch = False + self.is_first_train_microbatch = False + self.is_prev_microbatch_training = self.training output = hidden_states From 9d5620d36a11322f596c72e91923c1b75537040f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Sep 2023 23:24:28 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../nlp/modules/common/megatron/transformer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 1cb0013278a1..98dba5423009 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -1041,8 +1041,10 @@ def __init__( reduce_amax=reduce_amax, ) - self.is_first_train_microbatch = True # Is the current micro-batch the first micro-batch in a global-batch in training - self.is_prev_microbatch_training = True # Is the previous micro-batch in training mode + self.is_first_train_microbatch = ( + True # Is the current micro-batch the first micro-batch in a global-batch in training + ) + self.is_prev_microbatch_training = True # Is the previous micro-batch in training mode self.microbatch_count = 0 # transformer engine forward needs to know if it is working on the first microbatch self.checkpoint_core_attention = ( activations_checkpoint_granularity == 'selective' @@ -1252,7 +1254,8 @@ def custom_forward(*inputs): # in training, (2) the first micro-batch in each validation and test routine. # The caching happens in TransformerEngine when passing `is_first_microbatch=True`. is_first_microbatch = (self.is_first_train_microbatch and self.training) or ( - self.is_prev_microbatch_training and not self.training) + self.is_prev_microbatch_training and not self.training + ) for index in range(start, end): layer = self._get_layer(index) hidden_states = layer( @@ -1541,7 +1544,8 @@ def forward( # in training, (2) the first micro-batch in each validation and test routine. # The caching happens in TransformerEngine when passing `is_first_microbatch=True`. is_first_microbatch = (self.is_first_train_microbatch and self.training) or ( - self.is_prev_microbatch_training and not self.training) + self.is_prev_microbatch_training and not self.training + ) if self.transformer_engine: hidden_states = layer( hidden_states,
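
For reference, the caching decision introduced by these patches can be sketched as a minimal,
self-contained Python example. The names below (CacheFlagTracker, should_cache, step) are
hypothetical illustrations, not NeMo or TransformerEngine APIs; the sketch only mirrors the
flag transitions above and assumes the micro-batch counter advances only during training
micro-batches.

# Hypothetical, standalone sketch of the `is_first_microbatch` caching decision.
# It is not the NeMo implementation; it only mirrors the flag logic in the patch.
class CacheFlagTracker:
    def __init__(self):
        # First micro-batch of a training global-batch (cache FP8 weight + transpose).
        self.is_first_train_microbatch = True
        # Whether the previous micro-batch ran in training mode.
        self.is_prev_microbatch_training = True
        self.microbatch_count = 0

    def should_cache(self, training: bool) -> bool:
        # Cache at (1) the first micro-batch of each training global-batch, or
        # (2) the first micro-batch after switching to validation/test.
        return (self.is_first_train_microbatch and training) or (
            self.is_prev_microbatch_training and not training
        )

    def step(self, training: bool, num_micro_batches: int) -> bool:
        is_first = self.should_cache(training)
        if training:  # assumption: the counter only tracks training micro-batches
            self.microbatch_count += 1
            if self.microbatch_count % num_micro_batches == 0:
                self.microbatch_count = 0
                self.is_first_train_microbatch = True
            else:
                self.is_first_train_microbatch = False
        self.is_prev_microbatch_training = training
        return is_first

# Example with 2 micro-batches per global-batch, followed by a validation micro-batch:
tracker = CacheFlagTracker()
print(tracker.step(training=True, num_micro_batches=2))   # True  -> cache weight/transpose
print(tracker.step(training=True, num_micro_batches=2))   # False -> reuse cached copies
print(tracker.step(training=False, num_micro_batches=2))  # True  -> re-cache for eval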