feat: add hybridep #1333

Merged
hemildesai merged 14 commits into main from hemil/hybrid-ep-2 on Mar 31, 2026
Conversation

@hemildesai (Contributor) commented Feb 19, 2026

Wandb - https://wandb.ai/Nemo-automodel/automodel-moe-dispatcher

This PR adds HybridEP support for MoE token dispatch and updates dependency/docker wiring needed to run it reliably.

Changelog

  • Add HybridEP backend support to MoE flex token dispatch:
    • Introduce _HybridEPManager in nemo_automodel/components/moe/megatron/token_dispatcher.py.
    • Add moe_flex_dispatcher_backend ("deepep" or "hybridep"), plus backend-specific SM settings.
    • Add HybridEP preprocessing path that converts top-k indices to multihot routing metadata.
  • Extend fused all-to-all utilities in nemo_automodel/components/moe/megatron/fused_a2a.py:
    • Add set_deepep_num_sms.
    • Add HybridEP dispatch/combine autograd wrappers and buffer initialization/reset helpers.
  • Extend backend config and MoE wiring:
    • BackendConfig.dispatcher now accepts "hybridep".
    • Add dispatcher_num_sms.
    • Treat "hybridep" as valid with te/gmm experts.
    • Pass dispatcher backend + SM settings through MoE, GroupedExpertsDeepEP, and GroupedExpertsTE.
  • Add unit tests for HybridEP paths:
    • tests/unit_tests/moe/test_backend_config.py
    • tests/unit_tests/moe/test_experts.py
    • tests/unit_tests/moe/test_layers.py
  • Update dependencies:
    • Bump deep_ep to 7febc6e25660af0f54d95dd781ecdcd62265ecca (v1.2.1+7febc6e metadata).
    • Add Linux-only override: nvidia-cudnn-cu12==9.19.0.56; sys_platform == 'linux'.
    • Update uv.lock accordingly.
  • Fix Docker update behavior:
    • docker/common/update_pyproject_pytorch.sh now removes existing override-dependencies and reinserts docker/common/uv-pytorch.toml under
      [tool.uv] to avoid duplicate/incorrect override blocks.
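For illustration, the backend-selection surface described in the changelog could look roughly like the sketch below. The field names `dispatcher` and `dispatcher_num_sms` come from the bullets above; the default SM count and the validation logic are made up for the example and may not match the actual `BackendConfig` in the repo.

```python
from dataclasses import dataclass

_VALID_DISPATCHERS = ("deepep", "hybridep")

@dataclass
class BackendConfig:
    """Illustrative sketch of the dispatcher knobs named in this PR."""

    dispatcher: str = "deepep"
    dispatcher_num_sms: int = 20  # hypothetical default SM budget

    def __post_init__(self):
        # Reject anything outside the two supported flex-dispatch backends.
        if self.dispatcher not in _VALID_DISPATCHERS:
            raise ValueError(
                f"dispatcher must be one of {_VALID_DISPATCHERS}, got {self.dispatcher!r}"
            )
```

Selecting `dispatcher="hybridep"` with an unsupported expert backend would be rejected elsewhere in the wiring; this sketch only shows the backend-name validation.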

Additional Information

@copy-pr-bot (Bot) commented Feb 19, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai (Contributor Author):

/ok to test 8d0be12

@hemildesai (Contributor Author):

/ok to test 6233747

Signed-off-by: hemildesai <hemild@nvidia.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai (Contributor Author):

/ok to test 0a8d5ee

@hemildesai (Contributor Author):

/claude review

Comment thread docker/Dockerfile Outdated
Comment on lines +94 to +97
TORCH_CUDA_ARCH_LIST="9.0 10.0 12.0" MAX_JOBS=8 pip install --no-build-isolation . && \
apt-get purge -y libnvidia-ml-dev && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/*

The old code cleaned up the DeepEP/ source dir and /opt/deepep.patch after install. The new code doesn't, and also leaves rdma-core/ behind. This can add non-trivial bloat to the Docker image.

Suggested change
-    TORCH_CUDA_ARCH_LIST="9.0 10.0 12.0" MAX_JOBS=8 pip install --no-build-isolation . && \
-    apt-get purge -y libnvidia-ml-dev && \
-    apt-get autoremove -y && \
-    rm -rf /var/lib/apt/lists/*
+    TORCH_CUDA_ARCH_LIST="9.0 10.0 12.0" MAX_JOBS=8 pip install --no-build-isolation . && \
+    apt-get purge -y libnvidia-ml-dev && \
+    apt-get autoremove -y && \
+    rm -rf /var/lib/apt/lists/* && \
+    rm -rf /opt/deepep.patch && \
+    cd / && rm -rf DeepEP rdma-core

Comment on lines +399 to +419
) -> torch.Tensor:
# Reset num_permuted_tokens to None to avoid reusing cached state from a prior dispatch.
# This can happen in non-reentrant activation checkpointing mode.
self.num_permuted_tokens = None
if self.token_probs.dtype != torch.float32:
self.token_probs = self.token_probs.float()
dispatched_hidden, self.dispatched_probs, _, tokens_per_expert, self.handle = hybrid_ep_dispatch(
x=hidden_states,
routing_map=self.routing_map,
probs=self.token_probs,
group=self.group,
num_local_experts=self.num_local_experts,
num_sms_dispatch_api=self.moe_hybridep_num_sms,
num_sms_combine_api=self.moe_hybridep_num_sms,
num_permuted_tokens=self.num_permuted_tokens,
pad_multiple=self.pad_multiple,
)

self.tokens_per_expert = tokens_per_expert
self.num_permuted_tokens = self.tokens_per_expert.sum()


_indices_to_multihot is a pure tensor operation (no multi-GPU or DeepEP dependency needed) and is on the non-permute_fusion code path. Consider adding a unit test for it — edge cases like all -1 indices, duplicate expert assignments, or topk=1 would be good to cover.
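As context for that suggestion, the indices-to-multihot conversion can be sketched as a pure tensor op. This is an illustrative reimplementation assuming PyTorch, not the actual code in token_dispatcher.py:

```python
import torch

def indices_to_multihot(indices: torch.Tensor, probs: torch.Tensor, num_experts: int):
    """Convert [num_tokens, topk] expert indices (-1 = dropped slot) into a
    [num_tokens, num_experts] boolean routing map plus dense probabilities."""
    num_tokens = indices.size(0)
    routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
    dense_probs = torch.zeros(num_tokens, num_experts, dtype=probs.dtype)
    valid = indices >= 0  # mask out -1 (invalid) slots
    # Row index for every (token, slot) pair, so advanced indexing can
    # scatter each valid slot into its token's row.
    token_idx = torch.arange(num_tokens).unsqueeze(1).expand_as(indices)
    routing_map[token_idx[valid], indices[valid]] = True
    dense_probs[token_idx[valid], indices[valid]] = probs[valid]
    return routing_map, dense_probs
```

The edge cases named in the comment (all -1 indices, topk=1, duplicate expert assignments) map directly onto this shape; note that duplicate assignments for the same token would overwrite rather than accumulate here.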

… tests

- Clean up DeepEP/, rdma-core/, and deepep.patch after Docker build to reduce image bloat
- Add unit tests for _HybridEPManager._indices_to_multihot covering basic,
  topk=1, all-invalid, partial-invalid, and single-token edge cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai (Contributor Author):

/claude review

@hemildesai (Contributor Author):

/ok to test 061d6f4

Comment thread docker/Dockerfile
# Install HybridEP
## Dependency: RDMA Core
RUN git clone https://github.com/linux-rdma/rdma-core.git && \
cd rdma-core && git checkout tags/v60.0 && sh build.sh

Bug: The INSTALL_DEEPEP build arg was removed here (DeepEP is now always installed), but it's still referenced in:

  • CONTRIBUTING.md:85 (export INSTALL_DEEPEP=True)
  • CONTRIBUTING.md:91 (--build-arg INSTALL_DEEPEP=$INSTALL_DEEPEP)
  • .github/actions/build-container/action.yml:146 (INSTALL_DEEPEP=True)

The CI action will silently pass a build arg that no longer exists (harmless but confusing), and the CONTRIBUTING docs are now stale/misleading. These should be cleaned up.

claude (Bot) left a comment:

Looks good overall — clean HybridEP integration with solid test coverage for the new paths.

One issue: the INSTALL_DEEPEP build arg was removed from the Dockerfile (DeepEP now builds unconditionally), but stale references remain in CONTRIBUTING.md and .github/actions/build-container/action.yml. See inline comment.

DeepEP is now always installed (no longer behind a build arg), so remove
the leftover INSTALL_DEEPEP references from CONTRIBUTING.md and the
build-container GitHub Action.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai (Contributor Author):

/claude review

@hemildesai (Contributor Author):

/ok to test 07ffc01

self.num_permuted_tokens = None
return hidden_states

def get_dispached_metadata(self) -> torch.Tensor:

Nit: typo — get_dispached_metadata should be get_dispatched_metadata. (Pre-existing in the base class and _DeepepManager too, but this is a good opportunity to fix it across all three since you're adding a new implementation.)

Comment on lines +393 to +397

def dispatch(
self,
hidden_states: torch.Tensor,
async_finish: bool = True,

Bug/silent no-op: async_finish and allocate_on_comm_stream are accepted here but silently ignored. The caller (MoEFlexTokenDispatcher.token_dispatch at L653 and token_combine at L732) passes these through for both backends, but only _DeepepManager honors them.

If HybridEP intentionally doesn't support these, consider either:

  • Removing them from the signature and documenting why, or
  • Wiring them through to hybrid_ep_dispatch / hybrid_ep_combine (e.g. non_blocking in dispatch_with_permute is already derived from num_permuted_tokens, but allocate_on_comm_stream has no equivalent)

As-is, a caller switching from deepep to hybridep will silently lose async behavior.
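If the flags were to be surfaced at runtime rather than only documented, a guard along these lines would make the limitation visible. This is a hypothetical helper, not part of the PR (which opted to document the behavior in the signature instead):

```python
import warnings

def check_hybridep_kwargs(async_finish: bool, allocate_on_comm_stream: bool) -> None:
    """Warn when DeepEP-only flags are passed to the hybridep backend,
    instead of silently ignoring them."""
    ignored = []
    if async_finish:
        ignored.append("async_finish")
    if allocate_on_comm_stream:
        ignored.append("allocate_on_comm_stream")
    if ignored:
        warnings.warn(
            f"hybridep dispatcher backend ignores: {', '.join(ignored)}",
            UserWarning,
        )
```

A call like `check_hybridep_kwargs(async_finish=True, allocate_on_comm_stream=False)` at the top of `dispatch`/`combine` would emit a single `UserWarning`, making the deepep-to-hybridep behavior change observable.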

- Rename get_dispached_metadata -> get_dispatched_metadata across all
  three dispatch manager classes
- Document that async_finish and allocate_on_comm_stream are not
  supported by the HybridEP backend (kept in signature for interface
  compatibility with callers)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai (Contributor Author):

/claude review

@hemildesai (Contributor Author):

/ok to test b117377

claude (Bot) previously approved these changes Mar 30, 2026

claude (Bot) left a comment:

LGTM

Regenerated .secrets.baseline to keep line numbers in sync with the
updated files. All entries are false positives (config keys, test
fixtures, high-entropy example strings).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai (Contributor Author):

/ok to test cf65bff

@hemildesai (Contributor Author):

Thanks @hemildesai, when you get the chance, please share the tulu3 results from the robustness test. I think it's ok to use H100s if resource-constrained. 🙇

https://wandb.ai/Nemo-automodel/tulu3-convergence/runs/z2jpozm9



Development

Successfully merging this pull request may close these issues.

Support HybridEP

3 participants