diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index f838e07a1aa..df0ad505649 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -542,6 +542,9 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { + if (!out_id->isSymbolic()) { + continue; + } auto concretized_out_id = IterDomainBuilder(out_id).iter_type(iter_type).build(); registerConcretization(out_id, concretized_out_id); diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index dc0201d0833..78e23e41040 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -1001,4 +1001,35 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) { reductionDynamicPadAddFusion(invocations); } +// Test that a Symbolic root/Broadcast rfactor is not concretized to +// Iteration/Iteration +TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + // tv0[:2] introduces symbolic IterDomain + auto tv1 = slice( + tv0, {{fusion.zeroVal(), IrBuilder::create(2), fusion.oneVal()}}); + // tv1 has Broadcast rfactor, Iteration root + auto tv2 = slice(tv1, {{fusion.zeroVal(), fusion.oneVal(), fusion.oneVal()}}); + // tv2 has a Symbolic root related to a Broadcast rfactor through a Resize op + fusion.addOutput(tv2); + + // At concretization, tv1's rfactor will be set to Iteration, which will + // propagate to tv2s root. This test will test that when tv2 root is + // concretized to Iteration, it does not wind up overwriting the Broadcast + // rfactor. + + FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr)); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({5}, options); + std::vector aten_inputs = {at0}; + auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs); + auto at1 = at::slice(at0, 0, 0, 2); + auto at2 = at::slice(at1, 0, 0, 1); + testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__); +} + } // namespace nvfuser