// NOTE(review): this text is a unified diff covering three files; its line
// breaks have been collapsed, so it is patch text rather than a compilable
// translation unit. Summary of what the hunks below change:
//
// csrc/dynamic_transform.cpp, DynamicTransformConcretizer::mutate(TensorView*):
//   * The loop over input IDs becomes index-based over a materialized vector.
//   * A TORCH_CHECK is added that fails when a *mutated* input ID (one where
//     updated_id != inp_id) is still symbolic.
//   * iter_type is seeded from the (possibly mutated) input at index 0; every
//     later input is folded in with ops::promoteIterType.
//   * The "Failed to concretize an output IterType" assert moves from directly
//     after the input loop into the output loop, after the
//     `!out_id->isSymbolic()` skip, so it is only evaluated for outputs that
//     are still Symbolic.
//   NOTE(review): the added comment says Symbolic is overwritten with "the
//   first concrete IterType we encounter", but the code overwrites only at
//   i == 0 and does not test whether that first IterType is concrete. If the
//   first input is an unmutated Symbolic ID, later inputs go through
//   promoteIterType starting from Symbolic. Confirm this is intended.
//
// csrc/ir/nodes.cpp, IterDomain::resize:
//   * New early exit: when both expansions are constant ints that evaluate to
//     0, `in` is returned unchanged, after checking that an explicitly passed
//     iter_type_opt (if any) equals in->getIterType(). This exit sits before
//     the pre-existing TORCH_CHECK on in->getIterType().
//   * In the IterType inference branch, the two evaluateInt() calls are
//     hoisted into locals `left` / `right` and the condition becomes
//     `left + right > 1`.
//
// test/test_resize.cpp:
//   * Adds ResizeTest.FusionResizePadBroadcastInput: builds a {1, -1} concrete
//     tensor, pads it with widths {1, 0, 0, 0}, runs it through
//     FusionExecutorCache with a {1, 2} input, and validates against
//     at::pad(t0, {1, 0, 0, 0}).
diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index c82520e8a11..404506ac3f8 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -627,20 +627,40 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Determine the output IterType IterType iter_type = IterType::Symbolic; - for (auto inp_id : ir_utils::filterByType(expr->inputs())) { + const auto input_ids = + ir_utils::filterByType(expr->inputs()).vector(); + for (auto i : c10::irange(input_ids.size())) { + auto inp_id = input_ids.at(i); auto updated_id = maybeMutated(inp_id)->as(); - iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); + TORCH_CHECK( + updated_id == inp_id || !updated_id->isSymbolic(), + "Mutated IterDomains between root and rfactor should not be symbolic"); + if (i == 0) { + // ops::promoteIterType will favor Symbolic if it encounters it + // alongside Broadcast. This is preferable at fusion definition, but + // here we are propagating, and if we only see Broadcast in some + // dimension, then we should not retain Symbolic. To work around this, + // we always overwrite Symbolic with the first concrete IterType we + // encounter. + iter_type = updated_id->getIterType(); + } else { + iter_type = + ops::promoteIterType(iter_type, updated_id->getIterType()); + } } - TORCH_INTERNAL_ASSERT( - iter_type != IterType::Symbolic, - "Failed to concretize an output IterType for expression: ", - expr->toString()); - // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { if (!out_id->isSymbolic()) { continue; } + + // If out_id is Symbolic, we need to concretize it here. If we did not + // yet determine its IterType, then we've missed our chance. 
+ TORCH_INTERNAL_ASSERT( + iter_type != IterType::Symbolic, + "Failed to concretize an output IterType for expression: ", + expr->toString()); + auto concretized_out_id = IterDomainBuilder(maybeMutated(out_id)->as()) .iter_type(iter_type) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index fa96cff28ba..db38a2a7737 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2783,6 +2783,19 @@ IterDomain* IterDomain::resize( "Expansion factor must be an integer scalar: ", right_expansion->toString()); + if (left_expansion->isConstInt() && right_expansion->isConstInt()) { + auto left = left_expansion->evaluateInt(); + auto right = right_expansion->evaluateInt(); + if (left == 0 && right == 0) { + // This is a trivial resize. Check that we are not changing the IterType, + // then return the input. + TORCH_CHECK( + !iter_type_opt.has_value() || + iter_type_opt.value() == in->getIterType(), + "If IterType is specified in pad with zero expansion then it must match input"); + return in; + } + } TORCH_CHECK( in->getIterType() == IterType::Iteration || in->getIterType() == IterType::Broadcast || @@ -2823,12 +2836,13 @@ IterDomain* IterDomain::resize( if (iter_type_opt.has_value()) { iter_type = iter_type_opt.value(); } else if (left_expansion->isConstInt() && right_expansion->isConstInt()) { + auto left = left_expansion->evaluateInt(); + auto right = right_expansion->evaluateInt(); if (resized_id_size->isConstInt()) { // Means input extent is also known auto out_extent = resized_id_size->evaluateInt(); iter_type = out_extent == 1 ? 
IterType::Broadcast : IterType::Iteration; - } else if ( - left_expansion->evaluateInt() + right_expansion->evaluateInt() > 1) { + } else if (left + right > 1) { // Input extent is non-negative, so we know out_extent > 1 iter_type = IterType::Iteration; } diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 9baaa8d95f1..b3c58077746 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -483,6 +483,46 @@ TEST_F(ResizeTest, FusionResizePadScheduler4) { __FILE__); } +// Pad a broadcast +// See https://github.com/NVIDIA/Fuser/issues/798 +TEST_F(ResizeTest, FusionResizePadBroadcastInput) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // IterTypes are {Broadcast, Iteration} + auto tv0 = makeConcreteTensor({1, -1}); + fusion->addInput(tv0); + + // trivial pad of broadcast dimension + auto tv1 = + pad(tv0, + {fusion->oneVal(), + fusion->zeroVal(), + fusion->zeroVal(), + fusion->zeroVal()}); + fusion->addOutput(tv1); + + std::vector shape({1, 2}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = at::pad(t0, {1, 0, 0, 0}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t1}, + __LINE__, + __FILE__); +} + // Trivial cat TEST_F(ResizeTest, FusionResizeCat1) { Fusion fusion;