Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,9 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {

// Update the IterType of each output
for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
if (!out_id->isSymbolic()) {
continue;
}
auto concretized_out_id =
IterDomainBuilder(out_id).iter_type(iter_type).build();
registerConcretization(out_id, concretized_out_id);
Expand Down
31 changes: 31 additions & 0 deletions test/test_dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,4 +1001,35 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) {
reductionDynamicPadAddFusion(invocations);
}

// Test that a Symbolic root/Broadcast rfactor is not concretized to
// Iteration/Iteration
TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) {
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(fusion_ptr.get());
auto tv0 = makeSymbolicTensor(1);
fusion.addInput(tv0);
// tv0[:2] introduces symbolic IterDomain
auto tv1 = slice(
tv0, {{fusion.zeroVal(), IrBuilder::create<Int>(2), fusion.oneVal()}});
// tv1 has Broadcast rfactor, Iteration root
auto tv2 = slice(tv1, {{fusion.zeroVal(), fusion.oneVal(), fusion.oneVal()}});
// tv2 has a Symbolic root related to a Broadcast rfactor through a Resize op
fusion.addOutput(tv2);

// At concretization, tv1's rfactor will be set to Iteration, which will
// propagate to tv2s root. This test will test that when tv2 root is
// concretized to Iteration, it does not wind up overwriting the Broadcast
// rfactor.

FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor at0 = at::randn({5}, options);
std::vector<c10::IValue> aten_inputs = {at0};
auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
auto at1 = at::slice(at0, 0, 0, 2);
auto at2 = at::slice(at1, 0, 0, 1);
testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__);
}

} // namespace nvfuser