diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 7fcbb0fd42d..37729c63739 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -52,6 +52,30 @@ std::unordered_set ConcretizedBroadcastDomains:: return {}; } +// In some cases an op like pad or slice will introduce a broadcast domain by +// truncating a longer dimension or expanding an empty dimension to size 1. In +// these cases tv will have RFactor Broadcast IterDomains that are not present +// in the root domain. Contrast this with BroadcastOp, whose output does not +// have RFactor domains and instead places new broadcast domains in the output +// root domain. +void ConcretizedBroadcastDomains::handle(TensorView* tv) { + if (!tv->hasRFactor()) { + return; + } + for (auto id : tv->getMaybeRFactorDomain()) { + // Register broadcast rfactor domains that are not root domains as new + // broadcast origins. + if (id->isBroadcast() && + std::find(tv->getRootDomain().begin(), tv->getRootDomain().end(), id) == + tv->getRootDomain().end()) { + broadcast_origin_map_.emplace(id, std::unordered_set({id})); + } + } +} + +// Most broadcasts are handled with this method, since Broadcast domains are +// usually introduced through a BroadcastOp. Others are handled by the +// handle(TensorView*) method. 
void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) { // Create a new entry for each of new broadcast domains auto out = bop->out()->as(); diff --git a/csrc/device_lower/analysis/trivial_broadcast.h b/csrc/device_lower/analysis/trivial_broadcast.h index a9a3598ed72..a952df7aaff 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.h +++ b/csrc/device_lower/analysis/trivial_broadcast.h @@ -46,6 +46,8 @@ class ConcretizedBroadcastDomains : private IterVisitor { private: using IterVisitor::handle; + void handle(TensorView* tv) final; + void handle(BroadcastOp* bop) final; void dispatch(Expr* expr) final; diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index b4f205ca487..3016ebcc682 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -678,7 +678,10 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { } // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { - if (!out_id->isSymbolic()) { + auto mut_id = maybeMutated(out_id)->as(); + if (!mut_id->isSymbolic()) { + // We are only concretizing IterType here, so if we have already + // concretized the iter_type for this ID, we can skip this. continue; } @@ -690,9 +693,7 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { expr->toString()); auto concretized_out_id = - IterDomainBuilder(maybeMutated(out_id)->as()) - .iter_type(iter_type) - .build(); + IterDomainBuilder(mut_id).iter_type(iter_type).build(); registerConcretization(out_id, concretized_out_id); } @@ -794,6 +795,23 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( auto def = consumer->definition(); + // We will loop over IterDomains in the consumer root. For each, we need to + // inspect the consumer to producer map to all producers. Instead of + // recomputing these for each root IterDomain, we precompute them for each + // producer here then re-use them in the following loop. 
+ std::vector> c2p_maps; + for (auto producer : ir_utils::filterByType(def->inputs())) { + PairwiseRootDomainMap root_map(producer, consumer); + // We map symbolic domains here regardless of whether their extents match. + // This is safe because we are propagating from a producer which should have + // already been concretized. The consumer might have a different extent + // which will be equivalent to (but not necessarily sameAs) the producer's, + // and we just want to use its IterType to concretize the consumer ID. + root_map.mapSymbolic(true); + c2p_maps.push_back( + root_map.mapConsumerToProducer(consumer->domain(), producer->domain())); + } + bool is_concretized = false; for (const auto i : c10::irange(root_domain.size())) { @@ -807,17 +825,13 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( std::optional id_type; - for (auto producer : ir_utils::filterByType(def->inputs())) { - PairwiseRootDomainMap root_map(producer, consumer); - auto c2p = root_map.mapConsumerToProducer( - consumer->domain(), producer->domain()); - + for (const auto& c2p : c2p_maps) { + auto p_it = c2p.find(root_id); NVF_ERROR( - c2p.find(root_id) != c2p.end(), + p_it != c2p.end(), "No input ID found to map with output ID: ", root_id->toString()); - - auto input_id = c2p.at(root_id); + auto input_id = p_it->second; NVF_ERROR( input_id == maybeMutated(input_id), "Consumer IterDomain ", diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index b5c8ed71b12..e6dce264fc1 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -259,6 +259,9 @@ void ExpressionEvaluator::print() const { } void ExpressionEvaluator::propagateBoundValuesThroughExactMaps(Fusion* fusion) { + // We map Symbolic IterDomains here only if their extents match. This avoids + // mapping between symbolic domains that might concretize to an (Iteration, + // Broadcast) pair from a resolved broadcast. 
const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); for (const auto& set : mapped_sets.disjointSets()) { diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp index 62a60fd631b..a0bab3cf35e 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/root_domain_map.cpp @@ -132,6 +132,7 @@ std::unordered_map PairwiseRootDomainMap::map( // domains of torch_gather) // 3. Squeeze and unsqueeze // 4. Broadcast and non broadcast + // 5. Symbolic ID with different extent from other ID // Condition 1: when the producer ID is the dim of a select-like op if (producer_id == indexed_producer_id) { @@ -182,6 +183,27 @@ std::unordered_map PairwiseRootDomainMap::map( continue; } + // Condition 5 + // At least one ID is symbolic. + // + // If map_symbolic_ is true: + // Map these IDs regardless of other considerations. + // + // If map_symbolic_ is false (default): + // Map these only if their extents are identical. IterType::Symbolic + // reflects that the extent might evaluate to 1 for some inputs, in which + // case it may be valid to use those domains in a broadcast op. If the + // extents are exactly the same between two aligned IterDomains, the + // Symbolic one will be concretized to the same IterType as the other, so + // they should be mapped with one another. + if (!map_symbolic_ && + (producer_id->isSymbolic() || consumer_id->isSymbolic()) && + (!producer_id->extent()->sameAs(consumer_id->extent()))) { + itc++; + itp++; + continue; + } + IterDomain* map_key_id = producer_id; IterDomain* map_value_id = consumer_id; if (!producer_to_consumer) { @@ -1185,7 +1207,14 @@ void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) { if (root_set.find(id) == root_set.end() || rf_id == id) { continue; } - setMaybeMapped(td, id, td, rf_id); + // Usually, the itertypes between IterDomain expression inputs and + // outputs will match. 
However, it is possible for a Resize operation to + // take an Iteration input and reduce it to size 1, after which it + // becomes Broadcast. This check avoids mapping an Iteration and + // Broadcast domain in such a case. + if (id->getIterType() == rf_id->getIterType()) { + setMaybeMapped(td, id, td, rf_id); + } } } // Once mappings for rfactor axes are propagated to root axes, diff --git a/csrc/root_domain_map.h b/csrc/root_domain_map.h index 0fbe983f7ab..40b5f9f434c 100644 --- a/csrc/root_domain_map.h +++ b/csrc/root_domain_map.h @@ -101,6 +101,14 @@ class PairwiseRootDomainMap : public RootDomainMap { return *this; } + //! If b is true: map symbolic domains with other IterDomains even if their + //! extents don't match. If b is false (default): map symbolic domains with + //! other IterDomains only if their extents match. + PairwiseRootDomainMap& mapSymbolic(bool b) { + map_symbolic_ = b; + return *this; + } + PairwiseRootDomainMap& mapDifferentExtents(bool b) { map_different_extents_ = b; return *this; @@ -137,6 +145,10 @@ class PairwiseRootDomainMap : public RootDomainMap { //! Map broadcast and non-broadcast domains. Note that this is on by //! default bool map_broadcast_ = true; + //! Map symbolic domains with other IterDomains, even if their extents don't + //! match. Note that this is off by default, in which case they are mapped + //! only if their extents match. + bool map_symbolic_ = false; //! Map domains that may have different extents, e.g., torch_gather bool map_different_extents_ = false; //! 
Map domains that are indirectly accessed, e.g., index_select diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 92e7efdaae4..eab238a3d4c 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -63,6 +63,10 @@ TEST_F(NVFuserTest, DynamicTransform1_CUDA) { expr_eval.bind(tv0->axis(1)->extent(), 3L); expr_eval.bind(reshape_shape0, 3L); expr_eval.bind(reshape_shape1, 4L); + // We cannot infer the shape of tv1 from the above bound values, since + // either axis of tv2 might be broadcast against one from tv1. + expr_eval.bind(tv1->axis(0)->extent(), 3L); + expr_eval.bind(tv1->axis(1)->extent(), 4L); auto initial_info = DynamicTransform::getInitialInfo(&fusion); auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); @@ -187,6 +191,11 @@ TEST_F(NVFuserTest, DynamicTransform3_CUDA) { expr_eval.bind(tv0->axis(1)->extent(), shape_before.at(1)); expr_eval.bind(tv1->axis(0)->extent(), shape_after.at(0)); expr_eval.bind(tv1->axis(1)->extent(), shape_after.at(1)); + // We cannot infer reshape_shape0 and reshape_shape1 from tv0's and tv1's + // extents alone, since either of these reshaped extents could either match + // that of tv1 or be 1, resulting in a broadcast. + expr_eval.bind(reshape_shape0, shape_after.at(0)); + expr_eval.bind(reshape_shape1, shape_after.at(1)); auto initial_info = DynamicTransform::getInitialInfo(&fusion); auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); @@ -251,6 +260,13 @@ TEST_F(NVFuserTest, DynamicTransform4_CUDA) { for (const auto i : c10::irange(after_shape.size())) { expr_eval.bind(tv2->axis((int)i)->extent(), after_shape.at(i)); + // We must bind tv1's extents, since they cannot be inferred until after + // concretization. Because tv2 is a dynamic reshape both its IterDomains + // are Symbolic, which means both of tv3's IterDomains are also Symbolic. 
+ // tv1 has both IterDomains of type Iteration, but since we add tv3 to + // it to get tv4, we do not know whether this will resolve broadcasts from + // tv3 or not until concretization. + expr_eval.bind(tv1->axis((int)i)->extent(), after_shape.at(i)); } auto initial_info = DynamicTransform::getInitialInfo(&fusion); diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 1fec7dbcb5b..cf052605f60 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -2281,6 +2281,200 @@ TEST_F(ResizeTest, SliceVectorization) { testValidate(&fusion, cg_outputs, inputs, {ref}, __LINE__, __FILE__); } +// Concretize a symbolic pad that results in a broadcast (static pads) +// In this test, the sizes and pad widths are static, so there should be nothing +// to concretize. +TEST_F(NVFuserTest, ResizePadToBroadcastStatic_CUDA) { + std::vector t0_size = {2, 3, 2, 5, 6}; + std::vector t1_size = {2, 4, 4, 3, 5}; + // Note there are only 8 input scalars for 5D input. Implicit no-pad of dim 0 + std::vector pad_widths = { + 0, + -1, // dim=4 trim last element + 0, + -4, // dim=3 pad to broadcast of first element + 1, + 1, // dim=2 pad with zeros on either side + -1, + -1, // dim=1 pad to broadcast of second element + // dim=0 is implicit 0, 0 + }; + std::vector expected_itertypes = { + IterType::Iteration, + IterType::Broadcast, + IterType::Iteration, + IterType::Broadcast, + IterType::Iteration, + }; + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor(t0_size); + fusion->addInput(tv0); + auto tv1 = makeConcreteTensor(t1_size); + fusion->addInput(tv1); + + std::vector pad_width_vals; + pad_width_vals.reserve(pad_widths.size()); + for (auto w : pad_widths) { + pad_width_vals.push_back(IrBuilder::create(w)); + } + + auto tv2 = pad(tv0, pad_width_vals); + auto tv3 = mul(tv1, tv2); + fusion->addOutput(tv3); + + EXPECT_FALSE(fusion->hasDynamicTransform()); + + auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn(t0_size, options); + auto t1 = at::randn(t1_size, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + auto concretized_fusion = runtime->fusionSegments()->completeFusion(); + + auto conc_t2 = concretized_fusion->outputs()[0] + ->definition() + ->inputs()[1] + ->as(); + for (auto i : c10::irange(expected_itertypes.size())) { + EXPECT_EQ(conc_t2->axis(i)->getIterType(), expected_itertypes.at(i)); + } + + auto t2_padded = at::pad(t0, pad_widths); + auto ref_t2 = t1 * t2_padded; + + testValidate( + concretized_fusion, + cg_outputs, + aten_inputs, + {ref_t2}, + __LINE__, + __FILE__); +} + +// Concretize a symbolic pad that results in a broadcast (dynamic pads) +TEST_F(NVFuserTest, ResizePadToBroadcastDynamic_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(5); + fusion->addInput(tv0); + auto tv1 = makeSymbolicTensor(5); + fusion->addInput(tv1); + + // Note there are only 8 input scalars for 5D input. 
Implicit no-pad of dim 0 + std::vector pad_widths = { + 0, + -1, // dim=4 trim last element + 0, + -4, // dim=3 pad to broadcast of first element + 1, + 1, // dim=2 pad with zeros on either side + -1, + -1, // dim=1 pad to broadcast of second element + // dim=0 is implicit 0, 0 + }; + std::vector pad_width_vals; + pad_width_vals.reserve(pad_widths.size()); + for ([[maybe_unused]] auto _ : pad_widths) { + auto w_val = IrBuilder::create(DataType::Int); + fusion->addInput(w_val); + pad_width_vals.push_back(w_val); + } + + auto tv2 = pad(tv0, pad_width_vals); + auto tv3 = mul(tv1, tv2); + fusion->addOutput(tv3); + + EXPECT_TRUE(fusion->hasDynamicTransform()); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn({2, 3, 2, 5, 6}, options); + auto t1 = at::randn({2, 4, 4, 3, 5}, options); + // Keep dimension 0, pad to broadcast in dimension 1 and 3. Pad with zero in + // dimension 2. Trim by one element in dimension 4. + std::vector aten_inputs({ + t0, + t1, + }); + aten_inputs.insert(aten_inputs.end(), pad_widths.begin(), pad_widths.end()); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto runtime = executor_cache.getMostRecentKernelRuntime(); + auto concretized_fusion = runtime->fusionSegments()->completeFusion(); + + auto conc_t2 = concretized_fusion->outputs()[0] + ->definition() + ->inputs()[1] + ->as(); + EXPECT_EQ(conc_t2->axis(0)->getIterType(), IterType::Iteration); + EXPECT_EQ(conc_t2->axis(1)->getIterType(), IterType::Broadcast); + EXPECT_EQ(conc_t2->axis(2)->getIterType(), IterType::Iteration); + EXPECT_EQ(conc_t2->axis(3)->getIterType(), IterType::Broadcast); + EXPECT_EQ(conc_t2->axis(4)->getIterType(), IterType::Iteration); + + auto t2_padded = at::pad(t0, pad_widths); + auto ref_t2 = t1 * t2_padded; + + testValidate( + concretized_fusion, + cg_outputs, + aten_inputs, + {ref_t2}, + __LINE__, + __FILE__); +} + +// See 
https://github.com/NVIDIA/Fuser/issues/596 +TEST_F(NVFuserTest, ResizePadToBroadcastIssue596_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({2}); + auto tv1 = makeConcreteTensor({3}); + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = pad(tv0, {fusion->zeroVal(), IrBuilder::create(-1)}); + auto tv3 = mul(tv1, tv2); + fusion->addOutput(tv3); + + // Fusion is not dynamic + EXPECT_FALSE(fusion->hasDynamicTransform()); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto t0 = at::randn({2}, options); + auto t1 = at::randn({3}, options); + std::vector aten_inputs({t0, t1}); + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion), args); + runtime.compileFusionParallel(args); + auto cg_outputs = runtime.runWithInputs(args); + + auto t2_padded = at::pad(t0, {0, -1}); + auto ref_t2 = t1 * t2_padded; + + testValidate( + runtime.fusionSegments()->completeFusion(), + cg_outputs, + aten_inputs, + {ref_t2}, + __LINE__, + __FILE__); +} + // An input is sliced and then reshaped TEST_F(ResizeTest, SliceAndReshape1) { auto fusion_ptr = std::make_unique();