From 4fb591eb612c9585207f8eb531b350874d4f63bb Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 28 Aug 2023 16:52:54 -0400 Subject: [PATCH 1/4] Add C++ repro --- test/test_resize.cpp | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 9baaa8d95f1..b3c58077746 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -483,6 +483,46 @@ TEST_F(ResizeTest, FusionResizePadScheduler4) { __FILE__); } +// Pad a broadcast +// See https://github.com/NVIDIA/Fuser/issues/798 +TEST_F(ResizeTest, FusionResizePadBroadcastInput) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // IterTypes are {Broadcast, Iteration} + auto tv0 = makeConcreteTensor({1, -1}); + fusion->addInput(tv0); + + // trivial pad of broadcast dimension + auto tv1 = + pad(tv0, + {fusion->oneVal(), + fusion->zeroVal(), + fusion->zeroVal(), + fusion->zeroVal()}); + fusion->addOutput(tv1); + + std::vector shape({1, 2}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = at::pad(t0, {1, 0, 0, 0}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {t1}, + __LINE__, + __FILE__); +} + // Trivial cat TEST_F(ResizeTest, FusionResizeCat1) { Fusion fusion; From 3712607bfe37294b7049ace8dcfd1ead4280cf6e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 28 Aug 2023 16:53:08 -0400 Subject: [PATCH 2/4] Handle resize of broadcast in concretization. Avoids trivial resize --- csrc/dynamic_transform.cpp | 23 +++++++++++++++++------ csrc/ir/nodes.cpp | 10 +++++++--- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index c82520e8a11..80a37f42f47 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -629,18 +629,29 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { IterType iter_type = IterType::Symbolic; for (auto inp_id : ir_utils::filterByType(expr->inputs())) { auto updated_id = maybeMutated(inp_id)->as(); - iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); + // ops::promoteIterType will favor Symbolic if it encounters it + // alongside Broadcast. This is preferable at fusion definition, but + // here we are propagating, and if we only see Broadcast in some + // dimension, then we should not retain Symbolic. To work around this, + // we always overwrite Symbolic with the first concrete IterType we + // encounter. + iter_type = iter_type == IterType::Symbolic + ? updated_id->getIterType() + : ops::promoteIterType(iter_type, updated_id->getIterType()); } - TORCH_INTERNAL_ASSERT( - iter_type != IterType::Symbolic, - "Failed to concretize an output IterType for expression: ", - expr->toString()); - // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { if (!out_id->isSymbolic()) { continue; } + + // If out_id is Symbolic, we need to concretize it here. If we did not + // yet determine its IterType, then we've missed our chance. + TORCH_INTERNAL_ASSERT( + iter_type != IterType::Symbolic, + "Failed to concretize an output IterType for expression: ", + expr->toString()); + auto concretized_out_id = IterDomainBuilder(maybeMutated(out_id)->as()) .iter_type(iter_type) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index fa96cff28ba..1c72649b676 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2823,12 +2823,16 @@ IterDomain* IterDomain::resize( if (iter_type_opt.has_value()) { iter_type = iter_type_opt.value(); } else if (left_expansion->isConstInt() && right_expansion->isConstInt()) { - if (resized_id_size->isConstInt()) { + auto left = left_expansion->evaluateInt(); + auto right = right_expansion->evaluateInt(); + if (left == 0 && right == 0) { + // Trivial Resize + return in; + } else if (resized_id_size->isConstInt()) { // Means input extent is also known auto out_extent = resized_id_size->evaluateInt(); iter_type = out_extent == 1 ? IterType::Broadcast : IterType::Iteration; - } else if ( - left_expansion->evaluateInt() + right_expansion->evaluateInt() > 1) { + } else if (left + right > 1) { // Input extent is non-negative, so we know out_extent > 1 iter_type = IterType::Iteration; } From 328e0611d64b8d14689bc337b5f5db3f5de168c9 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 12:44:30 -0400 Subject: [PATCH 3/4] Move trivial resize check to beginning of resize() --- csrc/ir/nodes.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 1c72649b676..db38a2a7737 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2783,6 +2783,19 @@ IterDomain* IterDomain::resize( "Expansion factor must be an integer scalar: ", right_expansion->toString()); + if (left_expansion->isConstInt() && right_expansion->isConstInt()) { + auto left = left_expansion->evaluateInt(); + auto right = right_expansion->evaluateInt(); + if (left == 0 && right == 0) { + // This is a trivial resize. Check that we are not changing the IterType, + // then return the input. + TORCH_CHECK( + !iter_type_opt.has_value() || + iter_type_opt.value() == in->getIterType(), + "If IterType is specified in pad with zero expansion then it must match input"); + return in; + } + } TORCH_CHECK( in->getIterType() == IterType::Iteration || in->getIterType() == IterType::Broadcast || @@ -2825,10 +2838,7 @@ IterDomain* IterDomain::resize( } else if (left_expansion->isConstInt() && right_expansion->isConstInt()) { auto left = left_expansion->evaluateInt(); auto right = right_expansion->evaluateInt(); - if (left == 0 && right == 0) { - // Trivial Resize - return in; - } else if (resized_id_size->isConstInt()) { + if (resized_id_size->isConstInt()) { // Means input extent is also known auto out_extent = resized_id_size->evaluateInt(); iter_type = out_extent == 1 ? IterType::Broadcast : IterType::Iteration; From 1425ea81bbe2eaf3f3b1d5c730a694c0ba7a8612 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 29 Aug 2023 13:13:40 -0400 Subject: [PATCH 4/4] Check that updated (root, rfactor] domains are nonsymbolic --- csrc/dynamic_transform.cpp | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 80a37f42f47..404506ac3f8 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -627,17 +627,26 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Determine the output IterType IterType iter_type = IterType::Symbolic; - for (auto inp_id : ir_utils::filterByType(expr->inputs())) { + const auto input_ids = + ir_utils::filterByType(expr->inputs()).vector(); + for (auto i : c10::irange(input_ids.size())) { + auto inp_id = input_ids.at(i); auto updated_id = maybeMutated(inp_id)->as(); - // ops::promoteIterType will favor Symbolic if it encounters it - // alongside Broadcast. This is preferable at fusion definition, but - // here we are propagating, and if we only see Broadcast in some - // dimension, then we should not retain Symbolic. To work around this, - // we always overwrite Symbolic with the first concrete IterType we - // encounter. - iter_type = iter_type == IterType::Symbolic - ? updated_id->getIterType() - : ops::promoteIterType(iter_type, updated_id->getIterType()); + TORCH_CHECK( + updated_id == inp_id || !updated_id->isSymbolic(), + "Mutated IterDomains between root and rfactor should not be symbolic"); + if (i == 0) { + // ops::promoteIterType will favor Symbolic if it encounters it + // alongside Broadcast. This is preferable at fusion definition, but + // here we are propagating, and if we only see Broadcast in some + // dimension, then we should not retain Symbolic. To work around this, + // we always overwrite Symbolic with the first concrete IterType we + // encounter. + iter_type = updated_id->getIterType(); + } else { + iter_type = + ops::promoteIterType(iter_type, updated_id->getIterType()); + } } // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) {