
[BUG] Inaccurate FLOPs Calculation for Causal and Specialized Attention #14376

@NuojCheng

Description

This issue highlights inaccuracies in the FLOPs calculation for decoder-based models within nemo/utils/flops_formulas.py. Correcting these formulas is crucial for accurate model comparison and resource planning.

Getting these formulas right has been a recent focus in other major frameworks. For instance, Megatron-LM has updated its FLOPs calculations, and Google's MaxText has also refined its formulas to improve accuracy (see PRs #1988 and #2030).

The problems in NeMo are twofold:

  1. Standard causal attention isn't consistently accounted for, leading to a 2x overestimation of attention FLOPs for models like Llama and Mixtral.
  2. Models with specialized attention mechanisms require unique formulas, which are not currently implemented.

1. Causal Mask Inconsistency

FLOPs formulas should be consistent for all decoder models using standard causal attention.

Correct Implementation: The formulas for the base Transformer and DeepSeekV2 properly account for their respective attention mechanisms. The base Transformer formula correctly divides the FLOPs by two for the causal mask.

Missing Correction: This ÷2 adjustment for causal attention is absent in the formulas for other prominent models, notably Llama 2, Llama 3, and Mixtral. This causes their calculated attention FLOPs to be double what they should be.
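For reference, here is a minimal sketch of the attention-score FLOPs per layer with and without that correction; the function name and arguments are illustrative, not NeMo's actual signatures:

```python
def attention_score_flops(batch_size: int, seq_len: int, hidden_size: int, causal: bool = True) -> float:
    """Forward-pass FLOPs for the two attention matmuls (Q @ K^T and scores @ V).

    Each matmul performs batch * seq^2 * hidden multiply-accumulates, i.e.
    2 * batch * seq^2 * hidden FLOPs; together they give the factor of 4.
    A causal mask only needs the lower triangle of the score matrix, which
    removes roughly half of that work (the /2 correction discussed above).
    """
    flops = 4 * batch_size * seq_len**2 * hidden_size
    if causal:
        flops /= 2
    return flops
```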


2. Inaccuracy for Specialized Architectures

For new architectures, inheriting a generic formula can lead to errors. For example, flops_callback.py dispatches formulas by model type.

If a new model like "Llama4" uses chunked attention, simply applying the Llama 3 formula with a causal correction would still be incorrect: chunked attention limits each query to keys within its own chunk, so its computational cost scales differently and needs its own specific formula.
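To illustrate, here is a hedged sketch of how chunking changes the score-matmul cost; the chunk semantics and names are assumptions for illustration, not taken from NeMo or the Llama 4 implementation:

```python
def chunked_attention_score_flops(batch_size: int, seq_len: int, hidden_size: int, chunk_size: int) -> float:
    """Forward-pass FLOPs for Q @ K^T and scores @ V when each query only
    attends causally within its own chunk.

    The full causal score matrix has ~seq^2 / 2 entries; with chunking it is
    (seq / chunk) blocks of ~chunk^2 / 2 entries each, i.e. ~seq * chunk / 2,
    so the cost grows linearly in seq_len instead of quadratically.
    """
    num_chunks = seq_len // chunk_size  # assumes seq_len is a multiple of chunk_size
    per_chunk = 4 * batch_size * chunk_size**2 * hidden_size / 2  # causal within the chunk
    return num_chunks * per_chunk  # == 2 * batch * seq_len * chunk_size * hidden
```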

Proposed Solution

  1. Standardize Causal Attention: Apply the ÷2 adjustment to the attention FLOPs calculation for all relevant decoder models (Llama, Mixtral, etc.) to align with standard practice and ensure consistency.
  2. Implement Architectural Specificity: For models that do not use a full causal attention mask, such as Llama4, Gemma2, and Gemma3, which employ chunked or local attention, define a custom FLOPs formula that accurately reflects their reduced attention computation (see the sketch after this list).
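
A rough sketch of what point 2 could look like; the config attributes and function name below are hypothetical and only meant to show the shape of a per-architecture formula, not NeMo's actual API:

```python
def llama4_attention_flops(config, batch_size: int, seq_len: int) -> float:
    """Hypothetical attention-only FLOPs formula for a model that mixes
    full causal layers with chunked-attention layers.

    `config.chunked_attention_pattern` is an assumed list of booleans, one per
    layer, marking which layers use chunked attention; `attention_chunk_size`
    and `hidden_size` are likewise assumed attributes.
    """
    h = config.hidden_size
    flops = 0.0
    for uses_chunking in config.chunked_attention_pattern:
        if uses_chunking:
            # chunked causal attention: cost ~ seq * chunk instead of seq^2
            flops += 2 * batch_size * seq_len * config.attention_chunk_size * h
        else:
            # full causal attention: seq^2 term with the /2 correction applied
            flops += 2 * batch_size * seq_len**2 * h
    return flops  # MLP, embedding, and projection terms omitted for brevity
```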

Thank you for your consideration.
