
[JAX] grouped_gemm() uses variadic arguments #1658

Merged
phu0ngng merged 14 commits into NVIDIA:main from huanghua1994:JAX-GroupedGEMM-VariadicArgs on Apr 14, 2025

@huanghua1994
Collaborator

@huanghua1994 huanghua1994 commented Apr 8, 2025

Description

This PR optimizes the grouped_gemm() implementation in JAX. The original implementation manually flattens all input matrices before lowering to the C++ function and then manually splits the output into a list of tensors. Using variadic arguments lets the code pass each input and output tensor through individually, which avoids the extra copies of the inputs and outputs.
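To make the difference concrete, here is a minimal sketch in plain JAX. It is not the code from this PR and does not use Transformer Engine's API: `grouped_gemm_flattened` and `grouped_gemm_variadic` are hypothetical names, and `jnp` matmuls stand in for the C++ function. The first function mimics the pack/unpack pattern described above, with its extra concatenation copies; the second takes the per-group operands as variadic arguments and returns one output per group.

```python
import jax
import jax.numpy as jnp


def grouped_gemm_flattened(lhs_list, rhs_list):
    """Hypothetical 'before' pattern: pack operands flat, then unpack outputs."""
    dims = [(l.shape[0], l.shape[1], r.shape[1]) for l, r in zip(lhs_list, rhs_list)]
    # Extra copies: every input is concatenated into one contiguous 1-D buffer.
    lhs_flat = jnp.concatenate([l.ravel() for l in lhs_list])
    rhs_flat = jnp.concatenate([r.ravel() for r in rhs_list])

    out_chunks, lhs_off, rhs_off = [], 0, 0
    for m, k, n in dims:
        a = lhs_flat[lhs_off:lhs_off + m * k].reshape(m, k)
        b = rhs_flat[rhs_off:rhs_off + k * n].reshape(k, n)
        out_chunks.append((a @ b).ravel())
        lhs_off += m * k
        rhs_off += k * n

    # Extra copy: outputs come back as one flat buffer that must be split again.
    out_flat = jnp.concatenate(out_chunks)
    outs, off = [], 0
    for m, _, n in dims:
        outs.append(out_flat[off:off + m * n].reshape(m, n))
        off += m * n
    return outs


def grouped_gemm_variadic(*operands):
    """Hypothetical 'after' pattern: operands are lhs_0..lhs_{g-1}, rhs_0..rhs_{g-1}."""
    g = len(operands) // 2
    lhs, rhs = operands[:g], operands[g:]
    # Each group keeps its own buffer; no packing or splitting is needed.
    return tuple(l @ r for l, r in zip(lhs, rhs))


# Three groups with different (m, k, n) sizes, chosen only for illustration.
shapes = [(4, 8, 3), (2, 5, 7), (6, 6, 6)]
key = jax.random.PRNGKey(0)
lhs_list, rhs_list = [], []
for m, k, n in shapes:
    key, k1, k2 = jax.random.split(key, 3)
    lhs_list.append(jax.random.normal(k1, (m, k)))
    rhs_list.append(jax.random.normal(k2, (k, n)))

ref = grouped_gemm_flattened(lhs_list, rhs_list)
out = jax.jit(grouped_gemm_variadic)(*lhs_list, *rhs_list)
```

Both functions compute the same per-group products; the variadic form simply lets JAX track each tensor separately instead of routing everything through one packed buffer.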

This PR was marked as a draft because PR #1545 broke both the original implementation and this implementation for MXFP8 on Blackwell. That breakage is fixed by PR #1652.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Rewrite grouped_gemm() and GroupedGemmPrimitive using variadic arguments.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@huanghua1994 huanghua1994 requested a review from phu0ngng April 8, 2025 20:08
@huanghua1994 huanghua1994 self-assigned this Apr 8, 2025
@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from d9b9aef to 600029a Compare April 8, 2025 21:28
@huanghua1994 huanghua1994 changed the title [Draft][JAX] grouped_gemm() uses variadic arguments [JAX] grouped_gemm() uses variadic arguments Apr 8, 2025
@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from 600029a to 3af8804 Compare April 10, 2025 16:53
@phu0ngng
Collaborator

/te-ci jax L0

@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from 5d581ee to 292f7a9 Compare April 10, 2025 22:41
huanghua1994 and others added 13 commits April 11, 2025 09:27
Signed-off-by: Hua Huang <huah@nvidia.com>
@huanghua1994 huanghua1994 force-pushed the JAX-GroupedGEMM-VariadicArgs branch from 292f7a9 to c850c00 Compare April 11, 2025 16:27
@phu0ngng
Collaborator

/te-ci jax L0

@phu0ngng
Collaborator

Pipeline #26823920 passed.
Ready to merge.

@phu0ngng phu0ngng merged commit 98b4c0d into NVIDIA:main Apr 14, 2025
12 checks passed
@huanghua1994 huanghua1994 deleted the JAX-GroupedGEMM-VariadicArgs branch April 24, 2025 16:14