From a4c01726ee1ba2ab4b5971cb655fcc82aa2fa4d5 Mon Sep 17 00:00:00 2001
From: NuojCheng
Date: Fri, 18 Jul 2025 19:52:31 +0000
Subject: [PATCH] correct mask flops

---
 MaxText/maxtext_utils.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py
index 7aa1a9df93..4634b31699 100644
--- a/MaxText/maxtext_utils.py
+++ b/MaxText/maxtext_utils.py
@@ -14,7 +14,7 @@
 limitations under the License.
 """
 
-# pylint: disable=bare-except, consider-using-generator
+# pylint: disable=line-too-long, bare-except, consider-using-generator
 """ Utils that are only interesting to MaxText. """
 
 from typing import Optional
@@ -268,7 +268,7 @@ def calculate_tflops_training_per_device(config, log=True):
 
   # Attention flops
   if config.attention_type == "mla":
-    qkv_flops, attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
+    qkv_flops, noncausal_attention_flops, projection_flops = calculate_mla_tflops_per_device(config)
   else:
     qkv_flops = (
         2
@@ -278,7 +278,7 @@
         * (config.num_query_heads + 2 * config.num_kv_heads)
         * config.head_dim
     )
-    attention_flops = (
+    noncausal_attention_flops = (
         4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim
     )
     projection_flops = (
@@ -290,6 +290,12 @@
         * config.head_dim
     )
 
+  # Divide attention flops by 2 due to causal mask
+  # References:
+  # NVIDIA/Megatron-LM (2025 March): https://github.com/NVIDIA/Megatron-LM/blob/250b79415dcc4b660521273c87f15334c804eeae/megatron/training/training.py#L361-L362
+  # NVIDIA/NeMo (2025 April): https://github.com/NVIDIA/NeMo/blob/ba4d6d116463de512ff0cfc14641aa6cf4577a42/nemo/utils/flops_formulas.py#L259-L272
+  causal_attention_flops = noncausal_attention_flops / 2
+
   # Embedding flops
   embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size
 
@@ -302,14 +308,13 @@
     learnable_weight_tflops = (
         (total_ffn_flops + (qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
     )
-    attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12
+    attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
   else:
     # multiply by 3 for both feed forward and back propagation flops
     learnable_weight_tflops = (
         ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12
     )
-    # megatron tflops calculation does not account for causality in attention
-    attention_tflops = attention_flops * config.num_decoder_layers * 3 / 10**12
+    attention_tflops = causal_attention_flops * config.num_decoder_layers * 3 / 10**12
 
   learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps
   attention_tflops = attention_tflops * config.gradient_accumulation_steps
@@ -338,7 +343,7 @@ def calculate_tflops_training_per_device(config, log=True):
 def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, config, log=True):
   """Calculate training TFLOP"""
   learnable_weight_tflops = 2 * num_model_parameters * prefill_length / jax.device_count() / 1e12
-  noncasual_attention_flops = (
+  noncausal_attention_flops = (
       4
       * config.num_query_heads
       * config.num_decoder_layers
@@ -347,7 +352,7 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, co
       / jax.device_count()
       / 1e12
   )
-  causal_attention_tflops = noncasual_attention_flops / 2  # due to causality in attention
+  causal_attention_tflops = noncausal_attention_flops / 2  # due to causality in attention
   total_tflops = learnable_weight_tflops + causal_attention_tflops
 
   if log:
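
Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the accounting this patch corrects. It is not MaxText code; `SimpleConfig`, `attention_tflops`, and the example values are hypothetical stand-ins whose field names mirror the ones used in the diff.

# Standalone sketch of the causal-mask halving applied in this patch.
# `SimpleConfig` is a hypothetical stand-in for MaxText's config object;
# the values are arbitrary and chosen only for illustration.
from dataclasses import dataclass


@dataclass
class SimpleConfig:
  per_device_batch_size: int = 4
  max_target_length: int = 2048
  num_query_heads: int = 16
  head_dim: int = 128
  num_decoder_layers: int = 32


def attention_tflops(config):
  """Returns (noncausal, causal) attention TFLOPs over all layers, fwd+bwd."""
  # 4 = two matmuls (QK^T and scores @ V) x 2 flops per multiply-accumulate.
  noncausal_flops = (
      4 * config.per_device_batch_size * config.max_target_length**2
      * config.num_query_heads * config.head_dim
  )
  # With a causal mask only ~half of the (query, key) pairs contribute,
  # so the effective flops halve -- the Megatron-LM / NeMo convention.
  causal_flops = noncausal_flops / 2
  # Multiply by 3 for forward plus backward, sum over layers, scale to TFLOPs.
  scale = config.num_decoder_layers * 3 / 10**12
  return noncausal_flops * scale, causal_flops * scale


noncausal, causal = attention_tflops(SimpleConfig())
print(f"noncausal: {noncausal:.2f} TFLOPs/step, causal: {causal:.2f} TFLOPs/step")

Keeping both numbers visible makes it easy to compare against MFU figures from tools that do not apply the causal halving.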