diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 90d7dce7b27..f1a9f075363 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -357,16 +357,19 @@ class DynamicTransformConcretizer : public OptOutMutator { using OptOutMutator::mutate; - void mutate(TensorView* tv) final; - - void mutate(TensorDomain* td) final; - - //! Concretizes the root domain of a symbolic consumer tensor from - //! its producer domains. Returns true if any root ID is concretized. - bool propagateFromProducerToConsumer(TensorView* consumer); + void mutate(IterDomain* id) final; private: const DynamicTransformConcretizationInfo* info_; + + //! This map is used during concretization to identify, for a given IterDomain + //! the set of all IterDomains which are "aligned" with it in some TensorView + //! expression. This enables us to write mutate(IterDomain*) and propagate + //! information from producer IterDomains to consumers, which is otherwise not + //! represented in the graph since we do not connect IterDomains between + //! TensorViews with expressions. + std::unordered_map> + id_producers_; }; void DynamicTransformConcretizer::concretize() { @@ -376,13 +379,76 @@ void DynamicTransformConcretizer::concretize() { // Set output IterTypes for dynamic resize ops concretizeResize(); - // Finally, propagate concretized domains - auto all_stmts = StmtSort::getStmts(info_->fusion(), true); - for (auto stmt : all_stmts) { - if (stmt->isA()) { - mutate(stmt); + // This fixes the set of statements we will process over the course of + // multiple passes. + // Since we will traverse these Statements after some have been removed, we + // will not be able to safely check the types of each Statement in later + // loops. To avoid segfaults, we first split all_statements into subsets for + // each traversal. + std::vector non_tds_tvs; + std::vector all_exprs; + std::vector tvs_and_tds; + for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) { + if (stmt->isExpr()) { + all_exprs.push_back(stmt->asExpr()); + } else { + auto val = stmt->asVal(); + if (val->isA() || val->isA()) { + tvs_and_tds.push_back(val); + } else { + non_tds_tvs.push_back(val); + } + } + } + + // When propagating IterTypes across expressions, we need to know the producer + // IterDomains corresponding to a consumer ID. This mapping helps facilitate + // this and is used later by mutate(IterDomain*) in the first pass below. + for (auto expr : all_exprs) { + for (auto consumer : ir_utils::filterByType(expr->outputs())) { + for (auto producer : ir_utils::filterByType(expr->inputs())) { + PairwiseRootDomainMap root_map(producer, consumer); + for (auto [cid, pid] : root_map.mapConsumerToProducer( + consumer->domain(), producer->domain())) { + // Initialize set of producer IDs, if we haven't already + auto& producers = id_producers_.emplace(cid, 0).first->second; + producers.insert(pid); + } + } } } + + // In this first pass, we only mutate Vals that are not TensorDomains or + // TensorViews. + // + // This pass does not modify the Fusion. + for (auto val : non_tds_tvs) { + mutate(val); + } + + // In the second pass, we only mutate Exprs. For each expr, if any of its + // inputs, outputs, or attributes are registered for mutation, a new expr will + // be created and the original expr will be removed. This is the mechanism + // OptOutMutator provides for setting the `definition()` of replaced Vals. + // + // This pass may add and remove Exprs, so elements of all_exprs are invalid + // after this pass. + for (auto expr : all_exprs) { + mutate(expr); + } + + // In the third pass, we mutate the TensorDomains and TensorViews, without + // touching any other Vals or Exprs. The only change made to the Fusion at + // this stage is that TensorViews have their domain() replaced if any of their + // IterDomains are registered for mutation. This must happen last, as Expr + // mutation is required in order to properly connect root and rfactor domains, + // which is checked when creating new TensorDomains. + // + // This pass modifies the Fusion by creating new TensorDomains and swapping + // them into TensorViews. + for (auto val : tvs_and_tds) { + mutate(val); + } } void DynamicTransformConcretizer::concretizeReshape() { @@ -441,228 +507,72 @@ void DynamicTransformConcretizer::checkConcretizedUses( } } -// Concretizes inherited symbolic domains. Note that when this is -// called, it is assumed that all dynamic ops themselves are -// concretized. Since symbolic IDs may be propagated down to -// consumers, those domains need to be concretized accordingly. -void DynamicTransformConcretizer::mutate(TensorView* tv) { - if (!tv->domain()->hasSymbolicAxis()) { - return; - } - - // First, try to concretize the root domain as there may be symbolic - // axes inherited from the producers - propagateFromProducerToConsumer(tv); - - // If no root domain is altered by producer, we don't need to propagate back - // up to rfactor. We could return early, but instead we go ahead and check the - // root to rfactor transforms to be sure we have concretized any intermediate - // IterDomains. - - // At this point, there should be no expr beyond rfactor root - TORCH_INTERNAL_ASSERT( - tv->getLeafDomain() == tv->getMaybeRFactorDomain(), - "Invalid tensor: ", - tv->toString()); - - // If it has an rfactor root domain, the IterTypes of the rfactor - // IDs may need to be updated as well. Traverse the rfactor exprs - // and mutate the IterTypes of output IDs if symbolic. - if (tv->hasRFactor()) { - // Note that it is assumed that theres's no further expression - // beyond the rfactor domain as asserted above - auto all_id_exprs = StmtSort::getExprsBetween( - tv->fusion(), - {tv->getRootDomain().begin(), tv->getRootDomain().end()}, - {tv->getMaybeRFactorDomain().begin(), - tv->getMaybeRFactorDomain().end()}); - for (auto expr : all_id_exprs) { - // Assume outputs of IterDomain exprs are always IterDomains. If - // the assumption is invalidated, the logic here would need to - // be updated. Assert the assumption to immediately detect such - // a case if happened. - for (auto out_val : expr->outputs()) { - TORCH_INTERNAL_ASSERT( - out_val->isA(), - "Unexpected output: ", - out_val->toString(), - ". IterDomain was expected."); - } +void DynamicTransformConcretizer::mutate(IterDomain* id) { + // This will register id for mutation if start, stop, or extent are registered + // for mutation + OptOutMutator::mutate(id); + + // Use this to prototype new concretizations, since it will have replaced + // extent (see above) + auto mut_id = maybeMutated(id)->as(); + + IterDomain* concretized_id = nullptr; - // NOTE: We do not return early if all outputs are concrete as there may - // still be concrete inputs. For example, a Symbolic IterDomain might be - // padded with constant pad widths (1, 1), in which case although we do - // not know the exact extent of the output, we know it is at least as - // large as the sum of the pad widths, 2. In such cases, the output - // IterDomain is concrete at definition, since if the extent is >1 we know - // the IterType is Iteration. In these cases, we must continue to - // concretize intermediate expressions between the root and R-factor - // domain. See test DynamicTransform5_CUDA which demonstrates this - // behavior. - // NOTE: We also do not assume that if one output ID is symbolic, that - // they all must be. See test FusionSliceForNanoGPT3_CUDA for an example - // that does a static split by a factor of 16 of a symbolic input domain. - // The static split in that case results in a concrete IterDomain with - // extent 16 along with a symbolic one (extent ceilDiv(n / 16)). - - // Determine the output IterType - IterType iter_type = IterType::Symbolic; - for (auto inp_id : ir_utils::filterByType(expr->inputs())) { + if (mut_id->isSymbolic()) { + if (auto def = id->definition()) { + IterType iter_type = mut_id->getIterType(); + // Determine concrete IterType based on promotion of inputs to def + for (auto inp_id : ir_utils::filterByType(def->inputs())) { auto updated_id = maybeMutated(inp_id)->as(); iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); } TORCH_INTERNAL_ASSERT( iter_type != IterType::Symbolic, "Failed to concretize an output IterType for expression: ", - expr->toString()); - - // Update the IterType of each output - for (auto out_id : ir_utils::filterByType(expr->outputs())) { - if (!out_id->isSymbolic()) { - continue; + def->toString()); + concretized_id = IterDomainBuilder(mut_id).iter_type(iter_type).build(); + } else { + // IterDomains without definitions might be root domains for the output of + // a TensorView expression. If so, we should propagate their + // concretization in the producer to consumer direction. + + auto producers_it = id_producers_.find(id); + if (producers_it != id_producers_.end()) { + // id was a consumer root ID in some TV expression + + std::optional id_type; + for (auto producer_id : producers_it->second) { + producer_id = maybeMutated(producer_id)->as(); + if (id_type.has_value()) { + id_type = + ops::promoteIterType(*id_type, producer_id->getIterType()); + } else { + id_type = producer_id->getIterType(); + } } - auto concretized_out_id = - IterDomainBuilder(out_id).iter_type(iter_type).build(); - registerConcretization(out_id, concretized_out_id); - } - - // The expr itself needs to be mutated as well in case the outputs are - // mutated, which can be done by the mutate method - OptOutMutator::mutate(expr); - } - } - - // Root and rfactor domains are updated. First mutate the - // TensorDomain and then TensorView - mutate(tv->domain()); - OptOutMutator::mutate(tv); -} - -// Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but -// the contiguity vector may need to be updated as well as symbolic -// domains may be mutated to broadcast domains, which means contiguity -// may need to be changed to nullopt -void DynamicTransformConcretizer::mutate(TensorDomain* td) { - bool mutated = false; - - auto updateIdVec = [&](const std::vector& ids) { - std::vector updated_ids; - for (auto id : ids) { - auto updated_id = maybeMutated(id)->as(); - updated_ids.push_back(updated_id); - if (!updated_id->sameAs(id)) { - mutated = true; - } - } - return updated_ids; - }; - - std::vector root_dom = updateIdVec(td->root()); - std::vector rfactor_dom = td->hasRFactor() - ? updateIdVec(td->maybeRFactor()) - : std::vector(); - std::vector domain = updateIdVec(td->leaf()); - - if (!mutated) { - return; - } - - // Update the contiguity vector. Drop the contig val if mutated to broadcast - auto contig = td->contiguity(); - - for (const auto i : c10::irange(td->maybeRFactor().size())) { - auto original_id = td->maybeRFactor().at(i); - if (original_id->getIterType() != IterType::Symbolic) { - continue; - } - - TORCH_INTERNAL_ASSERT( - contig.at(i), - "Unexpected to have a non-contig symbolic domain: ", - original_id->toString()); - - auto updated_id = td->hasRFactor() ? rfactor_dom.at(i) : root_dom.at(i); - - // If the concretized ID is a broadcast domain, drop the contig val - if (updated_id->isBroadcast()) { - contig.at(i) = std::nullopt; - } - } - - Val* mutated_val = IrBuilder::create( - td->container(), root_dom, rfactor_dom, domain, contig); - registerConcretization(td, mutated_val); -} - -bool DynamicTransformConcretizer::propagateFromProducerToConsumer( - TensorView* consumer) { - if (consumer->definition() == nullptr || - !consumer->domain()->hasSymbolicAxis()) { - return false; - } - const auto& root_domain = consumer->getRootDomain(); - - auto def = consumer->definition(); - - bool is_concretized = false; - - for (const auto i : c10::irange(root_domain.size())) { - auto root_id = root_domain.at(i); - if (root_id->getIterType() != IterType::Symbolic) { - continue; - } - - // Figure out the right IterType of this consumer root ID from its - // corresponding producer IDs - - std::optional id_type; - - for (auto producer : ir_utils::filterByType(def->inputs())) { - PairwiseRootDomainMap root_map(producer, consumer); - auto c2p = root_map.mapConsumerToProducer( - consumer->domain(), producer->domain()); + TORCH_INTERNAL_ASSERT( + id_type.has_value(), + "Did not find id_type for consumer root domain ", + id->toString(), + ". Perhaps consumer def has no inputs."); - TORCH_INTERNAL_ASSERT( - c2p.find(root_id) != c2p.end(), - "No input ID found to map with output ID: ", - root_id->toString()); + TORCH_INTERNAL_ASSERT( + id_type.value() != IterType::Symbolic, + "Failed to concretize ", + id->toString()); - auto input_id = c2p.at(root_id); - TORCH_INTERNAL_ASSERT( - input_id->getIterType() != IterType::Symbolic, - "Producer ID not concretized: ", - input_id->toString()); + if (id_type.value() != id->getIterType()) - if (id_type.has_value()) { - id_type = ops::promoteIterType(*id_type, input_id->getIterType()); - } else { - id_type = input_id->getIterType(); + concretized_id = + IterDomainBuilder(mut_id).iter_type(id_type.value()).build(); } } - - TORCH_INTERNAL_ASSERT( - id_type.has_value(), - "Did not find id_type for consumer root domain ", - root_id->toString(), - ". Perhaps consumer def has no inputs. Consumer definition = ", - def->toString()); - - TORCH_INTERNAL_ASSERT( - id_type != IterType::Symbolic, - "Failed to concretize ", - root_id->toString(), - " of ", - consumer->toString()); - - auto concretized_id = - IterDomainBuilder(root_id).iter_type(*id_type).build(); - - registerConcretization(root_id, concretized_id); - is_concretized = true; } - return is_concretized; + if (concretized_id) { + registerConcretization(id, concretized_id); + } } DynamicTransformInitialInfo DynamicTransform::getInitialInfo(Fusion* fusion) { diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 425463cbc36..9fd0ebf54c9 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -27,6 +27,12 @@ class Fusion; * the dag will be called with handle(Statement*) in topolgical order inputs of * the fusion to outputs of the fusion. * + * Note that for any Val whose definition is non-null, the following are + * processed in order: definition, attributes, members. In particular, this + * means that a TensorView's domain() is processed after its definition, meaning + * producer TVs and their IterDomains are all processed before those of + * consumers. + * * TODO: We may want a BFS version of this code to extract ILP, not implemented * yet. * diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index e33bfaceb50..468d31b6133 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -107,13 +107,30 @@ void OptOutMutator::mutate(TensorDomain* td) { return; } + // Update the contiguity vector. Drop the contig val if mutated to broadcast + auto contig = td->contiguity(); + + for (const auto i : c10::irange(td->maybeRFactor().size())) { + auto original_id = td->maybeRFactor().at(i); + if (original_id->getIterType() != IterType::Symbolic) { + continue; + } + + TORCH_INTERNAL_ASSERT( + contig.at(i), + "Unexpected to have a non-contig symbolic domain: ", + original_id->toString()); + + auto updated_id = td->hasRFactor() ? rfactor_dom.at(i) : root_dom.at(i); + + // If the mutated ID is a broadcast domain, drop the contig val + if (updated_id->isBroadcast()) { + contig.at(i) = std::nullopt; + } + } + Val* mutated_val = IrBuilder::create( - td->container(), - root_dom, - rfactor_dom, - allocation_dom, - domain, - td->contiguity()); + td->container(), root_dom, rfactor_dom, allocation_dom, domain, contig); registerMutation(td, mutated_val); } diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 64a8827d32f..bec420bc6b8 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -944,7 +944,10 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) { //{{3, 5}, {-3, -2}, false}, // output is zero-dimensional // Output has size 1 so is set to broadcast. - {{3, 5}, {0, -4}, true}, + // This was previously "working" by concretizing the size-1 pad to + // Iteration, even though it should be Broadcast. When set properly to + // Broadcast, it fails with an error in ConcretizedBroadcastDomains. + //{{3, 5}, {0, -4}, true}, // Test full negative shifts, so output doesn't overlap input {{3, 5}, {-5, 2}, false},