
Cannot define an MmaOp where batch dimension is Broadcast #2273

@jacobhinkle

Description

The following test fails when trying to create an MmaOp:

// Single batch dimension which is broadcast
TEST_F(GPUTTensorCoreTest, FusionAmpereBroadcastBatchMatmul_CUDA) {
  auto layout = MmaLayout::TN;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

  auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
  auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  auto tv2 = fusedMultiplySum(
      broadcast(tv0, {true, false, false, false}),
      broadcast(tv1, {true, false, false, false}),
      {-1});
/*
C++ exception with description "details.bcasts.empty() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/ir/utils.cpp":1268, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. MmaOp output: has broadcast domains.
Exception raised from operator() at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1268 (most recent call first):
*/

  fusion.addOutput(tv2);
}

This is the cause of the failure in

TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail2)

which is why that test currently checks that we cannot translate this case. However, I think this case should be supported; instead, we should fix the MmaOp constructor so that it does not balk at broadcast batch dimensions.
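For context, the computation the failing test expresses is a batch matmul whose batch dimension is a size-1 broadcast, which is semantically just an unbatched matmul. A minimal NumPy sketch of that semantics (illustrative only; shapes and the TN layout convention here are assumptions, not nvFuser code):

```python
import numpy as np

# TN-layout operands: A is [M, K], B is [N, K] (assumed shapes for illustration).
M, N, K = 4, 5, 3
rng = np.random.default_rng(0)
a = rng.random((M, K), dtype=np.float32)
b = rng.random((N, K), dtype=np.float32)

# Canonicalize to broadcast-compatible 4D [B, M, N, K] with a broadcast
# batch dim of size 1, then "fused multiply-sum" over the K axis -- the
# same elementwise-multiply-then-reduce pattern fusedMultiplySum expresses.
a4 = a[np.newaxis, :, np.newaxis, :]   # [1, M, 1, K]
b4 = b[np.newaxis, np.newaxis, :, :]   # [1, 1, N, K]
out = (a4 * b4).sum(axis=-1)           # [1, M, N]

# The broadcast batch dim is trivial: the result equals a plain matmul,
# so there is no mathematical obstacle to building an MmaOp for it.
assert np.allclose(out[0], a @ b.T)
```

This is why rejecting the broadcast-batch case in the MmaOp ctor seems overly strict: the broadcast dimension changes only the output's iteration domain, not the reduction itself.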
