pad propagation and replay issues. #2374

@jjsjann123

Mostly just an issue for myself to record progress while working on #1597:

On the preseg transformation branch, I'm applying a transformation that pushes the pad ops out to the beginning of the fusion (ahead of the ops feeding the cat) and replaces the CatOp with a binary add.
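Here is a minimal pure-Python sketch of why the rewrite should preserve values. It needs no nvFuser or torch: plain lists stand in for a single row along the cat axis, and `pad1d`, `relu` and `cat_as_padded_add` are hypothetical helpers, not nvFuser APIs. The cat can become an add because each zero-padded operand is zero outside its own slice. The pad can be moved above relu only because relu(0) == 0.

```python
def pad1d(xs, left, right, value=0.0):
    """Pad a 1-D list with `value` on both sides (hypothetical helper)."""
    return [value] * left + list(xs) + [value] * right

def relu(xs):
    return [max(x, 0.0) for x in xs]

def cat_as_padded_add(a, b):
    """cat([a, b]) rewritten as pad(a) + pad(b) along the cat axis."""
    total = len(a) + len(b)
    pa = pad1d(a, 0, total - len(a))  # a occupies [0, len(a))
    pb = pad1d(b, len(a), 0)           # b occupies [len(a), total)
    return [x + y for x, y in zip(pa, pb)]

a = [0.5, -1.0, 2.0]   # one row of T0 along the cat axis
b = [-3.0, 4.0]        # one row of T1 along the cat axis

# Replacing the cat with an add of zero-padded operands.
assert a + relu(b) == cat_as_padded_add(a, relu(b))

# Pushing the pad above relu: relu(pad(b)) == pad(relu(b)) since relu(0) == 0.
assert relu(pad1d(b, len(a), 0)) == pad1d(relu(b), len(a), 0)
```

The same argument covers the commented-out `fd.ops.mul(T1, S2)` in the repro, since 0 * S2 == 0. An op that does not map 0 to 0 would break the second assertion.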

Repro branch: #2373

Repro script:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    S2 = fd.define_scalar(0.297302, dtype=DataType.Double)
    #T3 = fd.ops.mul(T1, S2)
    T3 = fd.ops.relu(T1)
    T4 = fd.ops.cat([T0, T3], -1)
    fd.add_output(T4)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((100,), dtype=torch.float32, device='cuda:0').as_strided((2, 5, 10), (50, 10, 1)),
    torch.randn((30,), dtype=torch.float32, device='cuda:0').as_strided((2, 5, 3), (15, 3, 1)),
]
fd.execute(inputs)
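The two `as_strided` calls in the script produce plain row-major views covering their whole buffers. That matches the `contiguity=[True, True, True]` used in the tensor definitions. Here is a quick pure-Python check; `contiguous_strides` is a hypothetical helper.

```python
import math

def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for `shape`."""
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))

# Shapes/strides passed to as_strided in the repro script, plus buffer sizes.
views = [((2, 5, 10), (50, 10, 1), 100), ((2, 5, 3), (15, 3, 1), 30)]
for shape, strides, buffer_len in views:
    assert contiguous_strides(shape) == strides  # fully contiguous view
    assert math.prod(shape) == buffer_len        # view covers the whole buffer
```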

Here's the fusion IR before the transformation:

Inputs:
  T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
  T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], float
Outputs:
  T5_g[ iS17{i0}, iS18{i1}, iS22{( i2 + i6 )} ], float

%kernel_math {
i15 = i2 + i6;
i17 = -i2;
i19 = i15 + i17;
T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
   = pad( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], {0, 0, 0, 0, 0, i19} )
T2_l[ iS6{i4}, iS7{i5}, iS8{i6} ]
   = relu(T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ]);
i21 = 0 + i2;
T4_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ]
   = pad( T2_l[ iS6{i4}, iS7{i5}, iS8{i6} ], {0, 0, 0, 0, i21, 0} )
T5_g[ iS17{i0}, iS18{i1}, iS22{( i2 + i6 )} ]
   = cat( T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ], T4_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ], 2 )
}
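In the IR above, the cat has already been expressed through one pad per input. I read the printed widths as (left, right) pairs per axis, outermost first; that is my interpretation of the printout. Under that reading, each input is left-padded by the total extent of the inputs before it and right-padded by the rest. So i19 = (i2 + i6) + (-i2) simplifies to i6, and i21 = 0 + i2 is just i2. Here is a small sketch of that bookkeeping with the concrete extents from the repro; `cat_pad_widths` is a hypothetical helper.

```python
def cat_pad_widths(extents):
    """(left, right) pad widths that place each input in its cat slice."""
    total = sum(extents)
    widths = []
    offset = 0
    for extent in extents:
        widths.append((offset, total - offset - extent))
        offset += extent
    return widths

# Concrete extents along the cat axis in the repro: i2 = 10 (T0), i6 = 3 (T1).
i2, i6 = 10, 3
i19 = (i2 + i6) + (-i2)  # right pad of T0, as computed in the IR
i21 = 0 + i2             # left pad of T2 (relu output), as computed in the IR
assert cat_pad_widths([i2, i6]) == [(0, i19), (i21, 0)]
assert (i19, i21) == (i6, i2)
```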

Here's the fusion IR after the transformation. I have:

  1. moved the pads to the beginning of the fusion;
  2. replayed relu on the padded input;
  3. replaced the cat with an add, and replaced the cat input that came from padding the old relu output with the new relu output.

Inputs:
  T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
  T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], float
Outputs:
  T8_g[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ], float

%kernel_math {
i15 = i2 + i6;
i17 = -i2;
i19 = i15 + i17;
T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
   = pad( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], {0, 0, 0, 0, 0, i19} )
i21 = 0 + i2;
T6_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ]
   = pad( T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], {0, 0, 0, 0, i21, 0} )
T7_l[ iS23{i4}, iS24{i5}, iS25{( i6 + ( 0 + i2 ) )} ]
   = relu(T6_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ]);
T8_g[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
   = T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
   + T7_l[ iS23{i4}, iS24{i5}, iS25{( i6 + ( 0 + i2 ) )} ];
}
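The three transformation steps can be sketched on a toy expression tree. The node classes, `evaluate` and `rewrite_cat` below are hypothetical stand-ins, not nvFuser's IR classes or the actual preseg pass. Evaluation works on a single 1-D row along the cat axis. In the IR above, `padded_x`, `padded_y`, `new_relu` and the returned `Add` play the roles of T3, T6, T7 and T8.

```python
from dataclasses import dataclass

@dataclass
class Input:
    name: str

@dataclass
class Relu:
    arg: object

@dataclass
class Pad:
    arg: object
    left: int
    right: int

@dataclass
class Add:
    lhs: object
    rhs: object

@dataclass
class Cat:
    lhs: object
    rhs: object

def evaluate(node, env):
    """Evaluate a toy node on 1-D rows; env maps input names to lists."""
    if isinstance(node, Input):
        return list(env[node.name])
    if isinstance(node, Relu):
        return [max(x, 0.0) for x in evaluate(node.arg, env)]
    if isinstance(node, Pad):
        return [0.0] * node.left + evaluate(node.arg, env) + [0.0] * node.right
    if isinstance(node, Add):
        lhs, rhs = evaluate(node.lhs, env), evaluate(node.rhs, env)
        return [x + y for x, y in zip(lhs, rhs)]
    if isinstance(node, Cat):
        return evaluate(node.lhs, env) + evaluate(node.rhs, env)
    raise TypeError(f"unknown node: {node!r}")

def rewrite_cat(node, lhs_extent, rhs_extent):
    """Rewrite cat(x, relu(y)) into pad(x) + relu(pad(y))."""
    assert isinstance(node, Cat) and isinstance(node.rhs, Relu)
    padded_x = Pad(node.lhs, 0, rhs_extent)      # step 1: pad x up front (T3)
    padded_y = Pad(node.rhs.arg, lhs_extent, 0)  # step 1: pad y up front (T6)
    new_relu = Relu(padded_y)                    # step 2: replay relu (T7)
    return Add(padded_x, new_relu)               # step 3: cat -> add (T8)

env = {"T0": [0.5, -1.0, 2.0], "T1": [-3.0, 4.0]}
before = Cat(Input("T0"), Relu(Input("T1")))
after = rewrite_cat(before, lhs_extent=3, rhs_extent=2)
assert evaluate(before, env) == evaluate(after, env)
```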

Executing this fusion hits an internal assert in compute_at_map:

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
    result = self._execute(
RuntimeError: logical_id_uses.find(logical_inp_id) == logical_id_uses.end() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/compute_at_map.cpp":621, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Was expecting iter domains to only have one active transformation but found id iS11{i2}rf used in
Resize: iS11{i2}rf by 0 and ( ( i2 + i6 ) + ( -i2 ) ) -> iS20{( i2 + i6 )}rf

and
Resize: iS11{i2}rf by 0 and ( ( i2 + i6 ) + ( -i2 ) ) -> iS20{( i2 + i6 )}rf

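Both uses reported by the assert are the same Resize expression (iS11{i2}rf -> iS20{( i2 + i6 )}rf), so it looks like we are ending up with redundant expressions. Maybe something in the replay I was using creates a duplicate.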
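The condition the assert enforces can be pictured with a toy uses map: each iter domain may feed at most one transformation. Under that rule, two distinct Resize exprs consuming iS11{i2}rf trip the check even though they print identically. `Expr` and `check_single_use` are a hypothetical simplification, not the actual logic in compute_at_map.cpp.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List

@dataclass(eq=False)  # identity semantics: two exprs are distinct objects
class Expr:
    text: str
    inputs: List[str] = field(default_factory=list)

def check_single_use(exprs):
    """Hypothetical simplification: each iter domain may have only one use."""
    uses = defaultdict(list)
    for expr in exprs:
        for id_name in expr.inputs:
            uses[id_name].append(expr)
    for id_name, users in uses.items():
        if len(users) > 1:
            raise RuntimeError(
                f"found id {id_name} used in\n"
                + "\n\nand\n".join(u.text for u in users))

RESIZE = "Resize: iS11{i2}rf by 0 and ( ( i2 + i6 ) + ( -i2 ) ) -> iS20{( i2 + i6 )}rf"

# A single Resize consuming iS11{i2}rf passes.
check_single_use([Expr(RESIZE, ["iS11{i2}rf"])])

# Two distinct Resize exprs that print identically trip the check,
# mirroring the shape of the message in the traceback above.
try:
    check_single_use([Expr(RESIZE, ["iS11{i2}rf"]), Expr(RESIZE, ["iS11{i2}rf"])])
except RuntimeError as err:
    print(err)
```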

More Context

To check that the transformation itself is legit, I wrote a C++ test that builds a similar fusion IR directly, and it runs and validates fine:

TEST_F(NVFuserTest, FusionPadDynamicShape) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* tv0 = makeContigTensor(3);
  fusion->addInput(tv0);
  TensorView* tv1 = makeContigTensor(3);
  fusion->addInput(tv1);

  Val* i15 = add(tv0->axis(2)->extent(), tv1->axis(2)->extent());
  Val* i17 = neg(tv0->axis(2)->extent());
  Val* i19 = add(i15, i17);
  Val* zero = IrBuilder::create<Val>(0);

  TensorView* tv3 = pad(tv0, {zero, i19, zero, zero, zero, zero});

  Val* i21 = add(zero, tv0->axis(2)->extent());
  TensorView* tv6 = pad(tv1, {i21, zero, zero, zero, zero, zero});
  TensorView* tv7 = relu(tv6);
  TensorView* tv8 = add(tv3, tv7);
  fusion->addOutput(tv8);

  fusion->printMath();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({2, 5, 10}, options);
  at::Tensor t1 = at::randn({2, 5, 3}, options);
  std::vector<c10::IValue> aten_inputs({t0, t1});

  FusionExecutorCache executor_cache(std::move(fusion));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

  testValidate(
      executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__);
}

Running the test with NVFUSER_DUMP=segmented_fusion gives:

Segmented_Fusion Dump: -- Re-written complete fusion:{
Inputs:
  T0_g[ iS0{i0}, iS1{i2}, iS2{i3} ], float
  T1_g[ iS24{i0}, iS25{i2}, iS5{i6} ], float
Outputs:
  T5_g[ iS17{i0}, iS18{i2}, iS23{( i3 + i6 )} ], float

%kernel_math {
i7 = i3 + i6;
i9 = -i3;
i11 = i7 + i9;
T2_l[ iS6{i0}, iS7{i2}, iS20{( i3 + i6 )}rf ]
   = pad( T0_g[ iS0{i0}, iS1{i2}, iS2{i3} ], {0, 0, 0, 0, 0, i11} )
i20 = 0 + i3;
T3_l[ iS26{i0}, iS27{i2}, iS21{( i6 + ( 0 + i3 ) )}rf ]
   = pad( T1_g[ iS24{i0}, iS25{i2}, iS5{i6} ], {0, 0, 0, 0, i20, 0} )
T4_l[ iS28{i0}, iS29{i2}, iS22{( i6 + ( 0 + i3 ) )} ]
   = relu(T3_l[ iS26{i0}, iS27{i2}, iS21{( i6 + ( 0 + i3 ) )}rf ]);
T5_g[ iS17{i0}, iS18{i2}, iS23{( i3 + i6 )} ]
   = T2_l[ iS6{i0}, iS7{i2}, iS20{( i3 + i6 )}rf ]
   + T4_l[ iS28{i0}, iS29{i2}, iS22{( i6 + ( 0 + i3 ) )} ];
}

} // {Re-written complete fusion}
