dsa indexer TP #36

@zzh24zzh

Description

Checklist

  • I have searched existing issues, and this is a new question or discussion topic.

Question Description

Hello. When running glm5.1 I found that the first node OOMs even though it only handles the first two layers. It looks like the indexer does not support TP, so at long context it goes straight to OOM.

(decoder): TransformerBlock(
    (layers): ModuleList(
      (0): TransformerLayer(
        (input_layernorm): RMSNorm()
        (self_attention): MLASelfAttention(
          (core_attention): DSAttention(
            (indexer): NewDSAIndexer(
              (rotary_pos_emb): RotaryEmbedding()
              (linear_wq_b): TELinear(in_features=2048, out_features=4096, bias=False, TP=1)
              (linear_wk): TELinear(in_features=6144, out_features=128, bias=False, TP=1)
              (k_norm): LayerNorm()
              (linear_weights_proj): TELinear(in_features=6144, out_features=32, bias=False, TP=1)
            )
          )
          (linear_proj): TERowParallelLinear(in_features=2048, out_features=6144, bias=False, TP=8)
          (linear_q_down_proj): TELinear(in_features=6144, out_features=2048, bias=False, TP=1)
          (linear_q_up_proj): TEColumnParallelLinear(in_features=2048, out_features=2048, bias=False, TP=8)
          (linear_kv_down_proj): TELinear(in_features=6144, out_features=576, bias=False, TP=1)
          (linear_kv_up_proj): TEColumnParallelLinear(in_features=512, out_features=3584, bias=False, TP=8)
          (q_layernorm): RMSNorm()
          (kv_layernorm): RMSNorm()
        )
        (pre_cross_attn_layernorm): IdentityOp()
        (cross_attention): IdentityOp()
        (cross_attn_bda): IdentityFuncOp()
        (pre_mlp_layernorm): IdentityOp()
        (mlp): MLP(
          (linear_fc1): TELayerNormColumnParallelLinear(in_features=6144, out_features=3072, bias=False, TP=8)
          (linear_fc2): TERowParallelLinear(in_features=1536, out_features=6144, bias=False, TP=8)
        )
      )
    )
  )
  (rotary_pos_emb): RotaryEmbedding()
)
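The printout above shows where the problem comes from: the indexer projections (`linear_wq_b`, `linear_wk`, `linear_weights_proj`) are built with TP=1, so every tensor-parallel rank keeps all of the indexer heads, and the score einsum in `_compute_index_scores` materializes a full `[seq, batch, heads, seq]` float32 tensor per rank. Below is a rough back-of-the-envelope sketch of that allocation; the 32-head count is inferred from `linear_wq_b` (2048 -> 4096, head dim 128 from `linear_wk`), and the sequence lengths are just examples, not values from this run:

```python
# Rough size of the tensor produced by
# torch.einsum('sbhd,tbd->sbht', q.float(), k.float()) in dsa.py.
# Concrete numbers below are illustrative assumptions.

def index_score_bytes(seq_len: int, batch: int, num_index_heads: int) -> int:
    # Output shape is [s, b, h, t] with s == t == seq_len, kept in float32.
    return seq_len * batch * num_index_heads * seq_len * 4

heads = 32  # assumed: 4096 / 128 index heads, all on one rank because TP=1
for s in (8_192, 32_768, 65_536):
    gib = index_score_bytes(s, batch=1, num_index_heads=heads) / 2**30
    print(f"seq_len={s:>6}: ~{gib:,.0f} GiB per rank")
# ~8 GiB at 8k tokens, ~128 GiB at 32k, ~512 GiB at 64k: this quadratic growth
# is what produces the ~238 GiB allocation reported in the traceback below.
```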
[INFO:swift] The training of Epoch 0 starts...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/swift/cli/_megatron/sft.py", line 7, in <module>
[rank0]:     megatron_sft_main()
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/swift/megatron/pipelines/train/sft.py", line 93, in megatron_sft_main
[rank0]:     return MegatronSft(args).main()
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/swift/pipelines/base.py", line 52, in main
[rank0]:     result = self.run()
[rank0]:              ^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/swift/megatron/pipelines/train/sft.py", line 68, in run
[rank0]:     trainer.train(train_dataset, val_dataset)
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/swift/megatron/trainers/base.py", line 636, in train
[rank0]:     metrics, grad_norm, update_successful = self.train_step(train_data_iterator)
[rank0]:                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/swift/megatron/trainers/base.py", line 857, in train_step
[rank0]:     metrics = forward_backward_func(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/pipeline_parallel/schedules.py", line 2209, in forward_backward_pipelining_without_interleaving
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:                                 ^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/pipeline_parallel/schedules.py", line 428, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/swift/megatron/trainers/trainer.py", line 124, in forward_step
[rank0]:     output_tensor = model(**data)
[rank0]:                     ^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/distributed/data_parallel_base.py", line 22, in forward
[rank0]:     return self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/module.py", line 493, in forward
[rank0]:     outputs = self.module(*inputs, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/mcore_bridge/model/gpt_model.py", line 331, in forward
[rank0]:     hidden_states = self.decoder(
[rank0]:                     ^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/transformer_block.py", line 643, in __call__
[rank0]:     return super().__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/module.py", line 356, in __call__
[rank0]:     return super().__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/transformer_block.py", line 784, in forward
[rank0]:     checkpointed_result = self._checkpointed_forward(
[rank0]:                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/transformer_block.py", line 584, in _checkpointed_forward
[rank0]:     hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1))
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/transformer_block.py", line 535, in checkpoint_handler
[rank0]:     return tensor_parallel.checkpoint(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/tensor_parallel/random.py", line 642, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, distribute_saved_activations, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/tensor_parallel/random.py", line 581, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:               ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/transformer_block.py", line 503, in custom_forward
[rank0]:     hidden_states, context = layer(
[rank0]:                              ^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/module.py", line 356, in __call__
[rank0]:     return super().__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/mcore_bridge/patcher.py", line 577, in forward
[rank0]:     hidden_states, context = self._forward_attention(*_args, **kwargs)
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/transformer_layer.py", line 615, in _forward_attention
[rank0]:     attention_output_with_bias = self.self_attention(
[rank0]:                                  ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/mcore_bridge/patcher.py", line 139, in forward
[rank0]:     core_attn_out = self.core_attention(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/experimental_attention_variant/dsa.py", line 1110, in forward
[rank0]:     _, topk_indices = self.indexer.forward_with_scores(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/mcore_bridge/patcher.py", line 833, in forward_with_scores
[rank0]:     index_scores, topk_indices = fused_qk_topk_naive(q, k, weights, self.index_topk, mask)
[rank0]:                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/experimental_attention_variant/dsa.py", line 311, in fused_qk_topk_naive
[rank0]:     index_scores = _compute_index_scores(q, weights, k)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tiger/.local/lib/python3.11/site-packages/megatron/core/transformer/experimental_attention_variant/dsa.py", line 278, in _compute_index_scores
[rank0]:     index_scores = torch.einsum('sbhd,tbd->sbht', q.float(), k.float())
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/functional.py", line 422, in einsum
[rank0]:     return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 238.17 GiB. GPU 0 has a total capacity of 79.11 GiB of which 66.51 GiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 10.46 GiB is allocated by PyTorch, and 136.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
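Until the indexer is sharded with TP, one way to keep the peak allocation bounded is to compute the index scores in chunks along the query dimension instead of materializing the full `[s, b, h, t]` tensor at once. The sketch below is only an illustration of that idea under assumed tensor layouts: it mirrors the `'sbhd,tbd->sbht'` einsum from the traceback, but the head-weighted reduction and the function name are assumptions, the causal mask handled by `fused_qk_topk_naive` is omitted, and this is not a patch to `dsa.py`.

```python
import torch

def chunked_index_scores(q: torch.Tensor, k: torch.Tensor,
                         weights: torch.Tensor, topk: int,
                         chunk: int = 2048) -> torch.Tensor:
    """Hypothetical chunked variant of the index-score / top-k computation.

    Assumed layouts: q [s, b, h, d], k [t, b, d], weights [s, b, h], matching
    the einsum 'sbhd,tbd->sbht' in the traceback. Only a [chunk, b, h, t]
    slice of the score tensor is alive at a time instead of the full
    [s, b, h, t] float32 tensor.
    """
    s = q.shape[0]
    topk_indices = []
    for start in range(0, s, chunk):
        q_chunk = q[start:start + chunk].float()
        w_chunk = weights[start:start + chunk].float()
        # Per-chunk scores: [chunk, b, h, t] instead of [s, b, h, t].
        scores = torch.einsum('sbhd,tbd->sbht', q_chunk, k.float())
        # Assumed reduction: weight the index heads and sum them out, then
        # take the top-k key positions per query. A causal mask would have
        # to be applied to `reduced` before the top-k in the real kernel.
        reduced = torch.einsum('sbht,sbh->sbt', scores, w_chunk)
        topk_indices.append(reduced.topk(topk, dim=-1).indices)
    return torch.cat(topk_indices, dim=0)  # [s, b, topk]
```

The chunk size trades peak memory against the number of einsum launches; sharding the 32 indexer heads across the TP group would shrink the per-rank footprint by a further factor of the TP size.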
