Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions csrc/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2250,9 +2250,6 @@ std::vector<IterDomain*> IterDomain::clone(
// domains have valid start and stop, it's not possible to contiguous
// predication.
IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
TORCH_CHECK(
!outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
"Merging IterDomains with ending values that are 0 is not supported at this time.");
TORCH_CHECK(
outer->isReduction() == inner->isReduction(),
"Merging IterDomains requires that their iteration types match. ",
Expand Down Expand Up @@ -2330,10 +2327,6 @@ std::pair<IterDomain*, IterDomain*> IterDomain::split(
bool inner_split,
Val* start_offset,
Val* stop_offset) {
TORCH_CHECK(
!in->extent()->isZeroInt(),
"Splitting IterDomains with ending values that are 0 is not supported at this time.");

TORCH_CHECK(
factor->isIntegralScalar(), "Cannot split by non-integer value ", factor);

Expand Down
107 changes: 107 additions & 0 deletions test/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2053,4 +2053,111 @@ TEST_F(NVFuserTest, ResizePermuteAndSlice_CUDA) {
__FILE__);
}

// When scheduling this test, the pointwise scheduler attempt to replay a Split
// transform on a size-0 dimension, which is not allowed.
TEST_F(NVFuserTest, FusionSizeZeroSliceSplitSchedule_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

std::vector<int64_t> shape({8});

// concrete shapes to avoid dynamic Fusion
auto tv0 = makeContigConcreteTensor(shape);
fusion->addInput(tv0);

auto tv1 = slice(
tv0,
{{IrBuilder::create<Int>(0),
IrBuilder::create<Int>(2),
IrBuilder::create<Int>(1)}});
auto tv2 = slice(
tv0,
{{IrBuilder::create<Int>(2),
IrBuilder::create<Int>(4),
IrBuilder::create<Int>(1)}});
auto tv3 = slice(
tv0,
{{IrBuilder::create<Int>(4),
IrBuilder::create<Int>(6),
IrBuilder::create<Int>(1)}});
auto tv4 = slice(
tv0,
{{IrBuilder::create<Int>(6),
IrBuilder::create<Int>(6),
IrBuilder::create<Int>(1)}});
auto tv5 = slice(
tv0,
{{IrBuilder::create<Int>(6),
IrBuilder::create<Int>(6),
IrBuilder::create<Int>(1)}});
auto tv6 = slice(
tv0,
{{IrBuilder::create<Int>(6),
IrBuilder::create<Int>(8),
IrBuilder::create<Int>(1)}});
fusion->addOutput(tv1);
fusion->addOutput(tv2);
fusion->addOutput(tv3);
fusion->addOutput(tv4);
fusion->addOutput(tv5);
fusion->addOutput(tv6);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::manual_seed(0);

auto t0 = at::randn(shape, options);
std::vector<c10::IValue> aten_inputs({t0});

FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
FusionExecutor fe;

auto ref0 = t0.index({at::indexing::Slice(0, 2)});
auto ref1 = t0.index({at::indexing::Slice(0, 4)});

TORCH_CHECK(ref0.equal(cg_outputs[0]));
TORCH_CHECK(ref1.equal(cg_outputs[1]));
}

// In this test, we split and merge with size-zero dimensions directly.
TEST_F(NVFuserTest, FusionSizeZeroSliceSplit_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

std::vector<int64_t> shape({4, 5});

// concrete shapes to avoid dynamic Fusion
auto tv0 = makeContigConcreteTensor(shape);
fusion->addInput(tv0);

auto tv1 = slice(
tv0,
{{IrBuilder::create<Int>(2),
IrBuilder::create<Int>(2),
IrBuilder::create<Int>(1)},
{IrBuilder::create<Int>(0),
IrBuilder::create<Int>(5),
IrBuilder::create<Int>(1)}});
// tv1 is of shape {0, 5}
fusion->addOutput(tv1);

tv1->merge(0, 1); // size 0*5 = 0
tv1->split(0, 4); // sizes (0, 4)

FusionExecutor fe;
fe.compileFusion(fusion.get());

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::manual_seed(0);

auto t0 = at::randn(shape, options);
std::vector<c10::IValue> aten_inputs({t0});

auto cg_outputs = fe.runFusion(aten_inputs);

auto ref0 = t0.index({at::indexing::Slice(2, 2), at::indexing::Slice(0, 5)});

TORCH_CHECK(ref0.equal(cg_outputs[0]));
}

} // namespace nvfuser