diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index ca2fcf27b86..62753c1f6c8 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -218,9 +218,16 @@ ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper( if (std::find(reference_ids.begin(), reference_ids.end(), id) != reference_ids.end()) { reordered_rfactor.push_back(id); - // Initiailze the extent for the mapped iter domain + // Initialize the extent for the mapped iter domain ProjectedExtent pe; - pe.multiplyNumeratorValue(commonOrConstExtent(ca_map_, id)); + auto ext = commonOrConstExtent(ca_map_, id); + if (ext->isConstInt() && ext->evaluateInt() == 0) { + // A size-zero extent ID will be predicated out always, so it should + // not affect the projected extent calculation + continue; + } else { + pe.multiplyNumeratorValue(ext); + } addProjectedExtent(id, pe); } else if (!id->isBroadcast()) { // Ignore broadcasts in the reference. Otherwise, remove non-contiguous @@ -239,9 +246,16 @@ ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper( if (std::find(reference_ids.begin(), reference_ids.end(), id) != reference_ids.end()) { reordered_root.push_back(id); - // Initiailze the extent for the mapped iter domain + // Initialize the extent for the mapped iter domain ProjectedExtent pe; - pe.multiplyNumeratorValue(commonOrConstExtent(ca_map_, id)); + auto ext = commonOrConstExtent(ca_map_, id); + if (ext->isConstInt() && ext->evaluateInt() == 0) { + // A size-zero extent ID will be predicated out always, so it should + // not affect the projected extent calculation + continue; + } else { + pe.multiplyNumeratorValue(ext); + } addProjectedExtent(id, pe); } else if (!id->isBroadcast()) { // Ignore broadcasts in the reference. 
Otherwise, remove non-contiguous diff --git a/csrc/scheduler/vectorize_helper.h b/csrc/scheduler/vectorize_helper.h index 0c8b6536195..374d5307fe3 100644 --- a/csrc/scheduler/vectorize_helper.h +++ b/csrc/scheduler/vectorize_helper.h @@ -71,11 +71,12 @@ class TORCH_CUDA_CU_API ProjectedExtent { // Multiply numerator by provided value, or if currently zero set numerator to // provided value. void multiplyNumeratorValue(Val* new_numerator_val) { - TORCH_INTERNAL_ASSERT( - !new_numerator_val->isZeroInt() && - (!new_numerator_val->isConstInt() || - new_numerator_val->evaluateInt() > 0), - "Adding numerator value of zero not supported in ProjectedExtent."); + if (new_numerator_val->isZeroInt() || + (new_numerator_val->isConstInt() && + new_numerator_val->evaluateInt() == 0)) { + // The value is known to be zero: return early, skipping the update below so that zero_ remains true + return; + } zero_ = false; if (new_numerator_val->isConstInt()) { diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 3fa7063a7c1..e51f62a6c9a 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -2182,6 +2182,64 @@ def fusion_func(fd: FusionDefinition, inps, matmul_fn) -> None: fc = FusionCache.get() fc.reset() + def test_multiple_empty_slices(self): + inputs = [ + torch.testing.make_tensor((4,), dtype=torch.float32, device="cuda"), + ] + + def fusion_func(fd: FusionDefinition): + T2 = fd.define_tensor( + symbolic_sizes=[-1], + contiguity=[True], + dtype=DataType.Float, + is_cpu=False, + ) + # Perform a size-1 slice and a size-0 slice on T2. The size-1 slice + # could be size >1 with no change in the error. The order does not + # matter. Performing only one of these slices does not trigger the + # error and the output is correct in that case. If there are + # multiple size-0 slices the error is not triggered. It only seems + # to appear when there are both size-0 and size non-zero slices of + # the same tensor. 
+ T3 = fd.ops.slice(T2, start_indices=[2], end_indices=[3], strides=[1]) + T4 = fd.ops.slice(T2, start_indices=[1], end_indices=[1], strides=[1]) + fd.add_output(T3) + fd.add_output(T4) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + + self.assertEqual(inputs[0][2:3], nvf_out[0]) + self.assertEqual(inputs[0][1:1], nvf_out[1]) + + def test_th420(self): + inputs = [ + torch.testing.make_tensor((4,), dtype=torch.float32, device="cuda"), + ] + + def fusion_func(fd: FusionDefinition): + T2 = fd.define_tensor( + symbolic_sizes=[-1], + contiguity=[True], + dtype=DataType.Float, + is_cpu=False, + ) + # Perform a size-4 slice and a size-0 slice on T2. The non-empty slice + # could be any size >0 with no change in the error. The order does not + # matter. Performing only one of these slices does not trigger the + # error and the output is correct in that case. If there are + # multiple size-0 slices the error is not triggered. It only seems + # to appear when there are both size-0 and size non-zero slices of + # the same tensor. 
+ T3 = fd.ops.slice(T2, start_indices=[0], end_indices=[4], strides=[1]) + T4 = fd.ops.slice(T2, start_indices=[1], end_indices=[1], strides=[1]) + fd.add_output(T3) + fd.add_output(T4) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + + self.assertEqual(inputs[0][0:4], nvf_out[0]) + self.assertEqual(inputs[0][1:1], nvf_out[1]) + def test_integer_division(self): inputs = [ torch.testing.make_tensor(1024, device="cuda", dtype=torch.long), diff --git a/test/test_resize.cpp b/test/test_resize.cpp index e119738dfda..9e53ae161ca 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -2122,4 +2122,54 @@ TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) { TORCH_CHECK(ref0.equal(cg_outputs[0])); } +// Test issue with multiple slices, one of which has size 0 +TEST_F(NVFuserTest, FusionResizeMultiSliceEmpty_CUDA) { + auto fusion = std::make_unique<Fusion>(); + FusionGuard fg(fusion.get()); + + std::vector<int64_t> shape({9}); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeConcreteTensor(shape); + fusion->addInput(tv0); + + // Perform a size-1 slice and a size-0 slice on tv0. The size-1 slice + // could be size >1 with no change in the error. The order does not + // matter. Performing only one of these slices does not trigger the + // error and the output is correct in that case. If there are + // multiple size-0 slices the error is not triggered. It only seems + // to appear when there are both size-0 and size non-zero slices of + // the same tensor. 
+ auto tv1 = slice( + tv0, + {{IrBuilder::create<Int>(0), + IrBuilder::create<Int>(1), + IrBuilder::create<Int>(1)}}); + fusion->addOutput(tv1); + auto tv2 = slice( + tv0, + {{IrBuilder::create<Int>(0), + IrBuilder::create<Int>(0), + IrBuilder::create<Int>(1)}}); + fusion->addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn(shape, options); + std::vector<c10::IValue> aten_inputs({t0}); + + auto params = *getPointwiseHeuristics(fusion.get(), aten_inputs); + schedulePointwise(fusion.get(), params); + FusionExecutor fe; + fe.compileFusion(fusion.get(), aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref0 = t0.index({at::indexing::Slice(0, 1)}); + auto ref1 = t0.index({at::indexing::Slice(0, 0)}); + + TORCH_CHECK(ref0.equal(cg_outputs[0])); + TORCH_CHECK(ref1.equal(cg_outputs[1])); +} + } // namespace nvfuser