diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp
index 6993699632b..677ac93e834 100644
--- a/csrc/device_lower/analysis/trivial_broadcast.cpp
+++ b/csrc/device_lower/analysis/trivial_broadcast.cpp
@@ -64,6 +64,53 @@ void ConcretizedBroadcastDomains::handle(BroadcastOp* bop) {
   }
 }
 
+void ConcretizedBroadcastDomains::handle(CatOp* op) {
+  auto id =
+      op->out()->as<TensorView>()->getMaybeRFactorDomain().at(
+          op->concatenatedDim());
+  if (id->isBroadcast()) {
+    broadcast_origin_map_.emplace(id, std::unordered_set<IterDomain*>({id}));
+  }
+}
+
+void ConcretizedBroadcastDomains::handle(PadOp* op) {
+  for (auto i : op->getPaddedAxes()) {
+    // Instead of the root domain of the output, as with BroadcastOp, we set the
+    // origin as the RFactor domain, since PadOp inserts Resize ops between root
+    // and rfactor
+    auto id = op->out()->as<TensorView>()->getMaybeRFactorDomain().at(i);
+    if (id->isBroadcast()) {
+      broadcast_origin_map_.emplace(id, std::unordered_set<IterDomain*>({id}));
+    }
+  }
+}
+
+void ConcretizedBroadcastDomains::handle(SliceOp* op) {
+  auto consumer_root = op->out()->as<TensorView>()->getMaybeRFactorDomain();
+  auto producer_rfactor = TensorDomain::noReductions(
+      op->in()->as<TensorView>()->getMaybeRFactorDomain());
+  TORCH_INTERNAL_ASSERT(
+      consumer_root.size() == producer_rfactor.size(),
+      "Consumer root size ",
+      consumer_root.size(),
+      " does not match producer rfactor size ",
+      producer_rfactor.size());
+  for (auto i : c10::irange(consumer_root.size())) {
+    auto cid = consumer_root.at(i);
+    auto pid = producer_rfactor.at(i);
+    if (cid->isBroadcast()) {
+      // Map to producer ID if it was already broadcast. Otherwise to consumer
+      // ID
+      if (pid->isBroadcast()) {
+        broadcast_origin_map_.emplace(
+            pid, std::unordered_set<IterDomain*>({cid, pid}));
+      } else {
+        broadcast_origin_map_.emplace(
+            cid, std::unordered_set<IterDomain*>({cid}));
+      }
+    }
+  }
+}
+
 void ConcretizedBroadcastDomains::handle(Expr* expr) {
   IterVisitor::handle(expr);
 
diff --git a/csrc/device_lower/analysis/trivial_broadcast.h b/csrc/device_lower/analysis/trivial_broadcast.h
index 841b23c501f..ab94641d40a 100644
--- a/csrc/device_lower/analysis/trivial_broadcast.h
+++ b/csrc/device_lower/analysis/trivial_broadcast.h
@@ -47,6 +47,13 @@ class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {
 
   void handle(BroadcastOp* bop) final;
 
+  // After concretization, ops with Resized IterDomains in their outputs may set
+  // the broadcast flag, even though they are not BroadcastOps themselves. In
+  // these cases, we set the output as the origin.
+  void handle(CatOp* op) final;
+  void handle(PadOp* op) final;
+  void handle(SliceOp* op) final;
+
   void handle(Expr* expr) final;
 
   void markAsConcretized(
diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp
index df0ad505649..f56691c2f14 100644
--- a/csrc/dynamic_transform.cpp
+++ b/csrc/dynamic_transform.cpp
@@ -542,7 +542,10 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
 
   // Update the IterType of each output
   for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
-    if (!out_id->isSymbolic()) {
+    if (!out_id->isSymbolic() ||
+        mutations_.find(out_id) != mutations_.end()) {
+      // Skip symbolic outputs and outputs that have already been registered
+      // for mutation
       continue;
     }
     auto concretized_out_id =
@@ -644,6 +647,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
 
   for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
     PairwiseRootDomainMap root_map(producer, consumer);
+    root_map.mapSymbolic(true);
     auto c2p = root_map.mapConsumerToProducer(
         consumer->domain(), producer->domain());
 
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 857f8ba792c..56439c11f34 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -1863,6 +1863,10 @@ class TORCH_CUDA_CU_API CatOp : public Expr {
     return attribute(0)->as<Attribute<int64_t>>()->value;
   }
 
+  Val* out() const {
+    return output(0);
+  }
+
   //! The index val that determines which input tensor should be used
   //! to fill the particular output position of this expression. Only
   //! valid after indexing
diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp
index 43b7ba18206..c89485009b2 100644
--- a/csrc/ops/utils.cpp
+++ b/csrc/ops/utils.cpp
@@ -223,9 +223,10 @@ std::vector<IterDomain*> newOutputDomain(
   std::vector<int64_t> start_offsets(out_domain.size(), 0);
   std::vector<int64_t> stop_offsets(out_domain.size(), 0);
   std::vector<Val*> extent_vals(out_domain.size(), nullptr);
+  std::vector<bool> mismatched_symbolic_extents(out_domain.size(), false);
   std::vector<Val*> expanded_extent_vals(out_domain.size(), nullptr);
-  std::vector<c10::optional<IterType>> iter_types(
-      out_domain.size(), c10::nullopt);
+  std::vector<std::optional<IterType>> iter_types(
+      out_domain.size(), std::nullopt);
 
   for (auto tv : tvs) {
     auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
@@ -235,7 +236,62 @@ std::vector<IterDomain*> newOutputDomain(
         dom.size(),
         " dimensions but expected ",
         out_domain.size());
+    // If there is any Iteration domain, we should use the first one's
+    // extent.
+    //
+    // If all inputs are Symbolic or Broadcast, then we can use the
+    // symbolic extent if all the symbolic extents agree.
+    //
+    // Otherwise, we don't know the output extent and iter_type should be
+    // Symbolic if there are any Symbolic inputs else Broadcast.
     for (const auto i : c10::irange(dom.size())) {
+      auto iter_type = dom[i]->getIterType();
+      auto prev_iter_type = iter_types[i];
+      if (prev_iter_type.has_value()) {
+        // Clang-tidy complains about unchecked access to optional value here
+        if (iter_type == IterType::Iteration &&
+            prev_iter_type.value() == IterType::Symbolic) {
+          // Prefer the Iteration extent, since Symbolic could be broadcast
+          extent_vals[i] = nullptr;
+        } else if (iter_type == IterType::Symbolic) {
+          switch (prev_iter_type.value()) {
+            case IterType::Iteration:
+              // Previously found Iteration domain, so ignore all Symbolic
+              // domains
+              continue;
+            case IterType::Symbolic:
+              if (extent_vals[i]->sameAs(dom[i]->extent())) {
+                // matching symbolic extent
+                continue;
+              } else {
+                // Mismatched symbolic input extents. Any one of the symbolic
+                // inputs could be a Broadcast or Iteration domain. Until
+                // concretization, we will not know which one holds the true
+                // extent (or whether they all are Broadcast, so that the output
+                // is also Broadcast). We record that these symbolic extents
+                // mismatched so that we can introduce a new symbolic extent
+                // later.
+                mismatched_symbolic_extents[i] = true;
+              }
+              break;
+            case IterType::Broadcast:
+              // Previously found only broadcast, so this will either also
+              // broadcast or resolve those broadcasts. If the expanded
+              // extent of any of the broadcasts is not 1, then it will need to
+              // match that of the dom[i]. In either case, prefer dom[i]'s
+              // extent, so clear iter_types[i] and extent_vals[i] so that the
+              // rest of this iteration will mark output as Symbolic.
+              iter_types[i] = std::nullopt;
+              extent_vals[i] = nullptr;
+              break;
+            default:
+              TORCH_CHECK(
+                  false,
+                  "Encountered unexpected IterType when creating new output domain: ",
+                  prev_iter_type.value());
+          }
+        }
+      }
       if (dom[i]->isBroadcast()) {
         if (dom[i]->hasExpandedExtent()) {
           expanded_extent_vals[i] =
@@ -244,9 +300,9 @@ std::vector<IterDomain*> newOutputDomain(
         continue;
       }
       extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
-      if (iter_types[i].has_value()) {
+      if (prev_iter_type.has_value()) {
        iter_types[i] =
-            promoteIterType(iter_types[i].value(), dom[i]->getIterType());
+            promoteIterType(prev_iter_type.value(), dom[i]->getIterType());
       } else {
         iter_types[i] = dom[i]->getIterType();
       }
@@ -268,15 +324,21 @@ std::vector<IterDomain*> newOutputDomain(
     }
   }
   for (const auto dim_i : c10::irange(out_domain.size())) {
+    auto iter_type = iter_types[dim_i];
+    if (iter_type == IterType::Symbolic && mismatched_symbolic_extents[dim_i]) {
+      // if we have a symbolic output but the input symbolic extents did not
+      // match, create a new extent
+      extent_vals[dim_i] = nullptr;
+    }
     if (extent_vals[dim_i] != nullptr) {
       TORCH_INTERNAL_ASSERT(
-          iter_types[dim_i].has_value(),
+          iter_type.has_value(),
           "Could not deduce iter type for new tensor view.");
       out_domain[dim_i] =
           IterDomainBuilder(
               IrBuilder::create<Int>(start_offsets[dim_i]), extent_vals[dim_i])
               .stop_offset(IrBuilder::create<Int>(stop_offsets[dim_i]))
-              .iter_type(iter_types[dim_i].value())
+              .iter_type(iter_type.value())
               .build();
     } else {
       out_domain[dim_i] = IterDomainBuilder(
diff --git a/csrc/root_domain_map.cpp b/csrc/root_domain_map.cpp
index 19a40803e23..aad690614be 100644
--- a/csrc/root_domain_map.cpp
+++ b/csrc/root_domain_map.cpp
@@ -131,6 +131,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
   //    domains of torch_gather)
   // 3. Squeeze and unsqueeze
   // 4. Broadcast and non broadcast
+  // 5. Symbolic IDs
 
   // Condition 1: when the producer ID is the dim of a select-like op
   if (producer_id == indexed_producer_id) {
@@ -181,6 +182,15 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
       continue;
     }
 
+    // Condition 5
+    if (!map_symbolic_ &&
+        (producer_id->getIterType() == IterType::Symbolic ||
+         consumer_id->getIterType() == IterType::Symbolic)) {
+      itc++;
+      itp++;
+      continue;
+    }
+
     IterDomain* map_key_id = producer_id;
     IterDomain* map_value_id = consumer_id;
     if (!producer_to_consumer) {
@@ -861,7 +871,9 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped(
   }
 
   if (consumer_id->isBroadcast()) {
-    TORCH_INTERNAL_ASSERT(producer_id->isBroadcast());
+    // Note that consumer may be broadcast even though producer is not if it is
+    // the output of a Resize op.
+    // Get bcast_map_ entry for consumer_id
    const auto consumer_bcast_domains =
        root_map_.getConcretizedKeys(consumer_td, consumer_id);
 
diff --git a/csrc/root_domain_map.h b/csrc/root_domain_map.h
index a3e7a30e977..ff0728b7e33 100644
--- a/csrc/root_domain_map.h
+++ b/csrc/root_domain_map.h
@@ -100,6 +100,11 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap {
     return *this;
   }
 
+  PairwiseRootDomainMap& mapSymbolic(bool b) {
+    map_symbolic_ = b;
+    return *this;
+  }
+
   PairwiseRootDomainMap& mapDifferentExtents(bool b) {
     map_different_extents_ = b;
     return *this;
@@ -136,6 +141,8 @@ class TORCH_CUDA_CU_API PairwiseRootDomainMap : public RootDomainMap {
   //! Map broadcast and non-broadcast domains. Note that this is on by
   //! default
   bool map_broadcast_ = true;
+  //! Map symbolic domains with one another.
+  bool map_symbolic_ = false;
   //! Map domains that may have different extents, e.g., torch_gather
   bool map_different_extents_ = false;
   //! Map domains that are indirectly accessed, e.g., index_select
diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp
index 78e23e41040..ad8e1272dea 100644
--- a/test/test_dynamic_transform.cpp
+++ b/test/test_dynamic_transform.cpp
@@ -1001,6 +1001,67 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) {
   reductionDynamicPadAddFusion(invocations);
 }
 
+// Test dynamic pad followed by broadcast resolution
+TEST_F(NVFuserTest, DynamicPadBroadcast_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  TensorView* tv1 = makeSymbolicTensor(2);
+  fusion.addInput(tv1);
+
+  // 2d axis order here is YX
+  auto ypad = IrBuilder::create<Int>();
+  fusion.addInput(ypad);
+  auto xpad = IrBuilder::create<Int>();
+  fusion.addInput(xpad);
+
+  // two-way resizes to cut square tv down to broadcastable size in each axis
+  auto tv0_pad = pad(tv0, {fusion.zeroVal(), xpad, fusion.zeroVal(), ypad});
+
+  // This will potentially resolve the y or x broadcast
+  auto p = mul(tv0_pad, tv1);
+  fusion.addOutput(p);
+  fusion.addOutput(tv0_pad);
+
+  fusion.printMath();
+
+  FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor at_x = at::randn({5, 5}, options);
+  at::Tensor at_y = at::randn({5, 5}, options);
+
+  // trivial resize
+  std::vector<c10::IValue> aten_inputs({at_x, at_y, 0, 0});
+  std::vector<at::Tensor> outputs;
+
+  /*
+  aten_inputs[2] = 0;
+  aten_inputs[3] = 0;
+  outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
+  testValidate(fusion_executor_cache.fusion(), outputs, aten_inputs, {at_x *
+  at_y}, __LINE__, __FILE__);
+  */
+
+  // shrink first axis
+  aten_inputs[2] = -4;
+  aten_inputs[3] = 0;
+  outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
+  std::cout << outputs << std::endl;
+  std::cout << at_x.slice(0, 0, 1) * at_y << std::endl;
+  std::cout << at_x.slice(0, 0, 1) << std::endl;
+  testValidate(
+      fusion_executor_cache.fusion(),
+      outputs,
+      aten_inputs,
+      {at_x.slice(0, 0, 1) * at_y, at_x.slice(0, 0, 1)},
+      __LINE__,
+      __FILE__);
+}
+
 // Test that a Symbolic root/Broadcast rfactor is not concretized to
 // Iteration/Iteration
 TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) {