Skip to content
Closed
22 changes: 18 additions & 4 deletions csrc/scheduler/vectorize_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,16 @@ ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper(
if (std::find(reference_ids.begin(), reference_ids.end(), id) !=
reference_ids.end()) {
reordered_rfactor.push_back(id);
// Initiailze the extent for the mapped iter domain
// Initialize the extent for the mapped iter domain
ProjectedExtent pe;
pe.multiplyNumeratorValue(commonOrConstExtent(ca_map_, id));
auto ext = commonOrConstExtent(ca_map_, id);
if (ext->isConstInt() && ext->evaluateInt() == 0) {
// A size-zero extent ID will be predicated out always, so it should
// not affect the projected extent calculation
continue;
} else {
pe.multiplyNumeratorValue(ext);
}
addProjectedExtent(id, pe);
} else if (!id->isBroadcast()) {
// Ignore broadcasts in the reference. Otherwise, remove non-contiguous
Expand All @@ -239,9 +246,16 @@ ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper(
if (std::find(reference_ids.begin(), reference_ids.end(), id) !=
reference_ids.end()) {
reordered_root.push_back(id);
// Initiailze the extent for the mapped iter domain
// Initialize the extent for the mapped iter domain
ProjectedExtent pe;
pe.multiplyNumeratorValue(commonOrConstExtent(ca_map_, id));
auto ext = commonOrConstExtent(ca_map_, id);
if (ext->isConstInt() && ext->evaluateInt() == 0) {
// A size-zero extent ID will be predicated out always, so it should
// not affect the projected extent calculation
continue;
Comment on lines +252 to +255
Copy link
Collaborator Author

@jacobhinkle jacobhinkle May 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zasdfgbnm I know you've been looking at this code lately also, in relation to #393 . I am trying to avoid bad behavior when an empty TensorView with size zero extent is encountered. I currently am just pretending it doesn't exist here for the purposes of computing the projected extent. But I'm a little uncertain of whether that is enough, or if it's even safe. I think I have an idea of how this class is supposed to work, but I'd appreciate you having a look if you have some time.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should clean up this entire thing with expr simplifier, and I am thinking about doing so after #393.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the case of zero extent, as long as this does not throw an error, whatever value it returns does not matter, because we will not do any memory read/write anyway. Am I correct?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a tensor has zero elements, should we just stop propagation and assume that tensor is OK with arbitrary vectorization size?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am refactoring this part of code for allocation domain support. So I don't want you to invest too much time digging into the codebase on how to do it best. The workaround here makes sense if it can let the test pass.

} else {
pe.multiplyNumeratorValue(ext);
}
addProjectedExtent(id, pe);
} else if (!id->isBroadcast()) {
// Ignore broadcasts in the reference. Otherwise, remove non-contiguous
Expand Down
11 changes: 6 additions & 5 deletions csrc/scheduler/vectorize_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ class TORCH_CUDA_CU_API ProjectedExtent {
// Multiply numerator by provided value, or if currently zero set numerator to
// provided value.
void multiplyNumeratorValue(Val* new_numerator_val) {
TORCH_INTERNAL_ASSERT(
!new_numerator_val->isZeroInt() &&
(!new_numerator_val->isConstInt() ||
new_numerator_val->evaluateInt() > 0),
"Adding numerator value of zero not supported in ProjectedExtent.");
if (new_numerator_val->isZeroInt() ||
(new_numerator_val->isConstInt() &&
new_numerator_val->evaluateInt() == 0)) {
// If the value is known to be zero, ignore it: return early so that
// zero_ is not cleared (set to false) below.
return;
}

zero_ = false;
if (new_numerator_val->isConstInt()) {
Expand Down
58 changes: 58 additions & 0 deletions python_tests/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,64 @@ def fusion_func(fd: FusionDefinition, inps, matmul_fn) -> None:
fc = FusionCache.get()
fc.reset()

def test_multiple_empty_slices(self):
    """Slice one tensor into a non-empty output and an empty output."""
    x = torch.testing.make_tensor((4,), dtype=torch.float32, device="cuda")
    inputs = [x]

    def build(fd: FusionDefinition):
        inp = fd.define_tensor(
            symbolic_sizes=[-1],
            contiguity=[True],
            dtype=DataType.Float,
            is_cpu=False,
        )
        # Take one size-1 slice ([2:3]) and one size-0 slice ([1:1]) of the
        # same tensor. Either slice alone produces the correct result; the
        # failure only appears when a zero-size slice is combined with a
        # non-zero-size slice of the same tensor. The slice order does not
        # matter, and the non-empty slice may be any size greater than one.
        # Multiple size-0 slices alone also do not trigger the error.
        out_nonempty = fd.ops.slice(
            inp, start_indices=[2], end_indices=[3], strides=[1]
        )
        out_empty = fd.ops.slice(
            inp, start_indices=[1], end_indices=[1], strides=[1]
        )
        fd.add_output(out_nonempty)
        fd.add_output(out_empty)

    nvf_out, _ = self.exec_nvfuser(build, inputs)

    self.assertEqual(inputs[0][2:3], nvf_out[0])
    self.assertEqual(inputs[0][1:1], nvf_out[1])

def test_th420(self):
    """Regression test mixing a full-size slice with a size-0 slice.

    Fixes in this version:
    - The input was a 3-D (4, 6, 7) tensor, but ``define_tensor`` declares
      a 1-D tensor (``symbolic_sizes=[-1]``); use a 1-D size-4 input so the
      definition matches the input.
    - The assertions checked ``inputs[0][2:3]`` although the first slice is
      ``[0:4]``; assert against the slices actually performed.
    """
    inputs = [
        torch.testing.make_tensor((4,), dtype=torch.float32, device="cuda"),
    ]

    def fusion_func(fd: FusionDefinition):
        T2 = fd.define_tensor(
            symbolic_sizes=[-1],
            contiguity=[True],
            dtype=DataType.Float,
            is_cpu=False,
        )
        # Perform a size-4 slice and a size-0 slice on T2. The order does
        # not matter, and performing only one of these slices does not
        # trigger the error. The error only appears when both a size-0 and
        # a non-zero-size slice of the same tensor are present.
        T3 = fd.ops.slice(T2, start_indices=[0], end_indices=[4], strides=[1])
        T4 = fd.ops.slice(T2, start_indices=[1], end_indices=[1], strides=[1])
        fd.add_output(T3)
        fd.add_output(T4)

    nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)

    self.assertEqual(inputs[0][0:4], nvf_out[0])
    self.assertEqual(inputs[0][1:1], nvf_out[1])

def test_integer_division(self):
inputs = [
torch.testing.make_tensor(1024, device="cuda", dtype=torch.long),
Expand Down
50 changes: 50 additions & 0 deletions test/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2122,4 +2122,54 @@ TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) {
TORCH_CHECK(ref0.equal(cg_outputs[0]));
}

// Regression test: multiple slices of the same tensor where one slice
// has size 0.
TEST_F(NVFuserTest, FusionResizeMultiSliceEmpty_CUDA) {
  auto fusion_ptr = std::make_unique<Fusion>();
  FusionGuard fg(fusion_ptr.get());

  // Concrete shape, so the Fusion is not dynamic
  const std::vector<int64_t> shape({9});

  auto in_tv = makeConcreteTensor(shape);
  fusion_ptr->addInput(in_tv);

  // Take one size-1 slice and one size-0 slice of in_tv. Either slice
  // alone is handled correctly, and the non-empty slice could be any
  // size greater than one with no change in the error. The slice order
  // does not matter, and multiple size-0 slices alone do not trigger the
  // error; it appears only when size-0 and non-zero-size slices of the
  // same tensor are combined.
  auto slice_nonempty = slice(
      in_tv,
      {{IrBuilder::create<Int>(0),
        IrBuilder::create<Int>(1),
        IrBuilder::create<Int>(1)}});
  fusion_ptr->addOutput(slice_nonempty);
  auto slice_empty = slice(
      in_tv,
      {{IrBuilder::create<Int>(0),
        IrBuilder::create<Int>(0),
        IrBuilder::create<Int>(1)}});
  fusion_ptr->addOutput(slice_empty);

  at::manual_seed(0);
  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn(shape, options);
  std::vector<c10::IValue> aten_inputs({t0});

  // Schedule with the pointwise scheduler, then compile and run
  auto params = *getPointwiseHeuristics(fusion_ptr.get(), aten_inputs);
  schedulePointwise(fusion_ptr.get(), params);
  FusionExecutor fe;
  fe.compileFusion(fusion_ptr.get(), aten_inputs);
  auto cg_outputs = fe.runFusion(aten_inputs);

  auto expect_nonempty = t0.index({at::indexing::Slice(0, 1)});
  auto expect_empty = t0.index({at::indexing::Slice(0, 0)});

  TORCH_CHECK(expect_nonempty.equal(cg_outputs[0]));
  TORCH_CHECK(expect_empty.equal(cg_outputs[1]));
}

} // namespace nvfuser