diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 2e3a7a3d49d..74e46a8d018 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -202,10 +202,17 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { if (!id->definition() || id->getIterType() != IterType::Symbolic) { continue; } - if (id->definition()->isA()) { + if (auto rop = dynamic_cast(id->definition())) { info_.dynamic_resized_ids_.push_back(id); - // extent of output determines its IterType - loop_dynamic_vals_.push_back(id->extent()); + if (isOptionEnabled(EnableOption::ConcretizeResizeExtents)) { + // The input extent and both expand vals are all concretized + loop_dynamic_vals_.push_back(rop->in()->extent()); + loop_dynamic_vals_.push_back(rop->leftExpand()); + loop_dynamic_vals_.push_back(rop->rightExpand()); + } else { + // extent of output determines its IterType + loop_dynamic_vals_.push_back(id->extent()); + } } } } @@ -399,27 +406,40 @@ void DynamicTransformConcretizationInfo::analyzeResizes( "Found non-dynamic Resize in initial concretization info: ", op->toString()); - auto extent_val = expr_eval->evaluate(out_id->getMaybeExpandedExtent()); - NVF_ERROR( - extent_val.hasValue(), - "Cannot evaluate the extent of a resized domain: ", - out_id->toString()); + PolymorphicValue input_extent = + expr_eval->evaluate(op->in()->getMaybeExpandedExtent()); NVF_ERROR( - extent_val.is(), - "Invalid evaluated value of resized domain extent: ", - out_id->toString()); - auto extent_int = extent_val.as(); + input_extent.hasValue(), + "Could not compute input extent to dynamic resize ", + op->toString()); + + PolymorphicValue left_expand = expr_eval->evaluate(op->leftExpand()); NVF_ERROR( - extent_int >= 0, - "Invalid resized domain extent ", - extent_int, - " for domain ", - out_id->toString()); + left_expand.hasValue(), + "Could not compute left expand of dynamic resize ", + op->toString()); - auto iter_type = - extent_int == 1 ? 
IterType::Broadcast : IterType::Iteration; + PolymorphicValue right_expand = expr_eval->evaluate(op->rightExpand()); + NVF_ERROR( + right_expand.hasValue(), + "Could not compute right expand of dynamic resize ", + op->toString()); - resize_itertypes_.emplace_back(id_index, iter_type); + if (isOptionEnabled(EnableOption::ConcretizeResizeExtents)) { + resize_extents_.emplace_back( + id_index, ConcreteResize{input_extent, left_expand, right_expand}); + } else { + // If this option is disabled, it means we will only concretize IterType, + // not the extents of Resize ops. In such case we standardize the + // concretization info so that the resize_extents_ entries are equivalent + // to saving IterType. This is so that we avoid unnecessary cache misses + // in cases where the IterType does not change but the specific extents or + // expand values do change. + resize_extents_.emplace_back( + id_index, + ConcreteResize{ + 2, input_extent + left_expand + right_expand == 1 ? -1 : 1, 0}); + } } } @@ -507,51 +527,11 @@ bool DynamicTransformConcretizationInfo::operator==( return true; } - if (reshape_transforms_.size() != other.reshape_transforms_.size() || - resize_itertypes_.size() != other.resize_itertypes_.size() || - empty_extents_.size() != other.empty_extents_.size() || - factory_output_itertypes_.size() != - other.factory_output_itertypes_.size()) { - return false; - } - - for (const auto i : c10::irange((int64_t)reshape_transforms_.size())) { - const auto& analysis = reshape_transforms_.at(i); - const auto& other_analysis = other.reshape_transforms_.at(i); - if (analysis != other_analysis) { - return false; - } - } - - for (const auto i : c10::irange((int64_t)resize_itertypes_.size())) { - const auto& itertype = resize_itertypes_.at(i); - const auto& other_itertype = other.resize_itertypes_.at(i); - if (itertype != other_itertype) { - return false; - } - } - - if (factory_output_itertypes_ != other.factory_output_itertypes_) { - return false; - } - - for (const auto i 
: c10::irange((int64_t)expand_axes_.size())) { - const auto& expand_axes = expand_axes_.at(i); - const auto& other_expand_axes = other.expand_axes_.at(i); - if (expand_axes != other_expand_axes) { - return false; - } - } - - for (const auto i : c10::irange((int64_t)empty_extents_.size())) { - const auto& ee = empty_extents_.at(i); - const auto& other_ee = other.empty_extents_.at(i); - if (ee != other_ee) { - return false; - } - } - - return true; + return reshape_transforms_ == other.reshape_transforms_ && + resize_extents_ == other.resize_extents_ && + factory_output_itertypes_ == other.factory_output_itertypes_ && + expand_axes_ == other.expand_axes_ && + empty_extents_ == other.empty_extents_; } std::string DynamicTransformConcretizationInfo::toString() const { @@ -580,12 +560,15 @@ std::string DynamicTransformConcretizationInfo::toString() const { } indent(ss, 1) << "Resize:\n"; NVF_ERROR( - resize_itertypes_.size() == + resize_extents_.size() == initial_info_->getDynamicResizedIterDomains().size()); - for (const auto& [id_index, iter_type] : resize_itertypes_) { + for (const auto& [id_index, concrete_resize] : resize_extents_) { + const auto& [input_extent, left_expand, right_expand] = concrete_resize; auto id = initial_info_->getDynamicResizedIterDomains().at(id_index); - indent(ss, 2) << id->toString() << " (index=" << id_index << "), " - << iter_type << "\n"; + indent(ss, 2) << id->toString() << " (index=" << id_index << ")," + << " input_extent=" << input_extent + << " left_expand=" << left_expand + << " right_expand=" << right_expand << "\n"; } indent(ss, 1) << "Expand:\n"; NVF_ERROR( @@ -668,6 +651,9 @@ class DynamicTransformConcretizer : public OptOutMutator { //! Use this instead of calling registerMutation directly, since it will also //! check that the concretized value is a valid input to all of its uses. 
void registerConcretization(Val* old_val, Val* new_val) { + if (new_val == old_val) { + return; + } symbolic_to_concretized_map_.emplace(old_val, new_val); checkConcretizedUses(old_val, new_val); NVF_ERROR( @@ -933,21 +919,134 @@ void DynamicTransformConcretizer::concretizeReshape() { } void DynamicTransformConcretizer::concretizeResize() { - // Concretize each resize op. - for (const auto& [id_index, iter_type] : info_->getResizeIterTypes()) { - auto id = info_->initialInfo()->getDynamicResizedIterDomains().at(id_index); + if (!isOptionEnabled(EnableOption::ConcretizeResizeExtents)) { + for (const auto& [id_index, concrete_resize] : info_->getResizeExtents()) { + const auto& [input_extent, left_expand, right_expand] = concrete_resize; + + IterDomain* id = + info_->initialInfo()->getDynamicResizedIterDomains().at(id_index); + + IterType iter_type = input_extent + left_expand + right_expand == 1 + ? IterType::Broadcast + : IterType::Iteration; + + NVF_CHECK( + id->definition() && id->definition()->isA(), + "Resized IterDomain must have a Resize definition"); + auto def = id->definition()->as(); + + auto new_id = IterDomain::resize( + def->in(), + def->leftExpand(), + def->rightExpand(), + id->isRFactorProduct(), + iter_type); + + registerConcretization(id, new_id); + } + return; + } + + // First mutate IterDomains so that each Resize producer has constant extent + for (const auto& [id_index, concrete_resize] : info_->getResizeExtents()) { + const auto& [input_extent, left_expand, right_expand] = concrete_resize; + + IterDomain* id = + info_->initialInfo()->getDynamicResizedIterDomains().at(id_index); + + NVF_CHECK( + id->definition() && id->definition()->isA(), + "Resized IterDomain must have a Resize definition"); + IterDomain* orig_in_id = id->definition()->as()->in(); + + IterDomain* in_id = maybeMutated(orig_in_id)->as(); + + bool has_const_extent = in_id->getMaybeExpandedExtent()->isConst(); + + if (has_const_extent && !in_id->isSymbolic()) { + continue; + } + + // Create a 
new concretized IterDomain to replace the input, if it is not + // already replaced (for example one IterDomain could theoretically be + // input to multiple Resize ops). + IterDomainBuilder builder(in_id); + + if (!has_const_extent) { + Val* new_extent = IrBuilder::create( + input_extent, in_id->getMaybeExpandedExtent()->dtype()); + if (in_id->hasExpandedExtent()) { + builder.expanded_extent(new_extent); + } else { + builder.extent(new_extent); + } + } + + if (in_id->isSymbolic()) { + builder.iter_type( + input_extent + left_expand + right_expand == 1 ? IterType::Broadcast + : IterType::Iteration); + } + + IterDomain* new_in_id = builder.build(); + + ir_utils::replaceValInExprInputs(id->definition(), orig_in_id, new_in_id); + + registerConcretization(orig_in_id, new_in_id); + if (!orig_in_id->getMaybeExpandedExtent()->sameAs( + new_in_id->getMaybeExpandedExtent())) { + registerConcretization( + orig_in_id->getMaybeExpandedExtent(), + new_in_id->getMaybeExpandedExtent()); + } + } + + // Concretize each resize op using constant expand values + for (const auto& [id_index, concrete_resize] : info_->getResizeExtents()) { + const auto& [input_extent, left_expand, right_expand] = concrete_resize; + + IterType iter_type = input_extent + left_expand + right_expand == 1 + ? 
IterType::Broadcast + : IterType::Iteration; + + auto orig_id = + info_->initialInfo()->getDynamicResizedIterDomains().at(id_index); + + // id might have been updated in the previous loop + IterDomain* id = maybeMutated(orig_id)->as(); + NVF_CHECK( id->definition() && id->definition()->isA(), "Resized IterDomain must have a Resize definition"); auto def = id->definition()->as(); + + Val* left_expand_val = def->leftExpand(); + if (!def->leftExpand()->isConst()) { + left_expand_val = + IrBuilder::create(left_expand, def->leftExpand()->dtype()); + registerConcretization(def->leftExpand(), left_expand_val); + } + + Val* right_expand_val = def->rightExpand(); + if (!def->rightExpand()->isConst()) { + right_expand_val = + IrBuilder::create(right_expand, def->rightExpand()->dtype()); + registerConcretization(def->rightExpand(), right_expand_val); + } + auto new_id = IterDomain::resize( - def->in(), - def->leftExpand(), - def->rightExpand(), + maybeMutated(def->in())->as(), + left_expand_val, + right_expand_val, id->isRFactorProduct(), iter_type); - registerConcretization(id, new_id); + registerConcretization(orig_id, new_id); + + // Concretize the output shape which is constant + if (!orig_id->extent()->sameAs(new_id->extent())) { + registerConcretization(orig_id->extent(), new_id->extent()); + } } } @@ -1118,6 +1217,7 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Update the IterType of each output for (auto out_id : ir_utils::filterByType(expr->outputs())) { auto mut_id = maybeMutated(out_id)->as(); + if (!mut_id->isSymbolic()) { // We are only concretizing IterType here, so if we have already // concretized the iter_type for this ID, we can skip this. 
@@ -1156,6 +1256,7 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // Set expr as the definition for concretized outputs expr = mutateExprOutputsOnly(expr); + // Replace inputs and attributes that were concretized mutate(expr); } @@ -1189,13 +1290,10 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { return updated_ids; }; - std::vector root_dom = - td->hasRoot() ? updateIdVec(td->root()) : std::vector(); + std::vector root_dom = updateIdVec(td->root()); std::vector logical_dom = updateIdVec(td->logical()); std::vector loop_domain = updateIdVec(td->loop()); - std::vector alloc_dom = td->hasAllocation() - ? updateIdVec(td->allocation()) - : std::vector(); + std::vector alloc_dom = updateIdVec(td->allocation()); if (!mutated) { return; @@ -1350,7 +1448,13 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( bool is_concretized = false; for (const auto i : c10::irange((int64_t)root_domain.size())) { - auto root_id = root_domain.at(i); + IterDomain* orig_root_id = root_domain.at(i); + + // This root ID might have already been marked for concretization. For + // example, if it is used in a Resize op then it will be concretized + // earlier in concretizeResize. + auto root_id = maybeMutated(orig_root_id)->as(); + if (root_id->getIterType() != IterType::Symbolic) { continue; } @@ -1367,7 +1471,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( bool found = false; for (const auto& c2p : c2p_maps) { - auto p_it = c2p.find(root_id); + auto p_it = c2p.find(orig_root_id); // In some cases, we can exact map to one producer, but not to another. 
// This is the case for index_select, for example, whose first input is // the tensor to look up values in and whose second input gives the @@ -1427,20 +1531,16 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( // Propagate expanded IterDomains by swapping the extent into the expanded // extent concretized_id = - IterDomainBuilder(maybeMutated(root_id)->as()) + IterDomainBuilder(root_id) .iter_type(*id_type) .extent(FusionGuard::getCurFusion()->oneVal(DataType::Index)) - .expanded_extent( - maybeMutated(root_id)->as()->extent()) + .expanded_extent(root_id->extent()) .build(); } else { - concretized_id = - IterDomainBuilder(maybeMutated(root_id)->as()) - .iter_type(*id_type) - .build(); + concretized_id = IterDomainBuilder(root_id).iter_type(*id_type).build(); } - registerConcretization(root_id, concretized_id); + registerConcretization(orig_root_id, concretized_id); is_concretized = true; } @@ -1493,8 +1593,11 @@ size_t DynamicTransformConcretizationInfo::hash() const { for (const auto& extent_idx : getEmptyExtents()) { hashCombine(hash, (size_t)extent_idx); } - for (const auto& [id, iter_type] : getResizeIterTypes()) { - hashCombine(hash, (size_t)iter_type); + for (const auto& [id, concrete_resize] : getResizeExtents()) { + const auto& [input_extent, left_expand, right_expand] = concrete_resize; + hashCombine(hash, (size_t)input_extent); + hashCombine(hash, (size_t)left_expand); + hashCombine(hash, (size_t)right_expand); } for (const auto& pair_vec : getFactoryOutputIterTypes()) { for (const auto& [pos, iter_type] : pair_vec) { @@ -1502,9 +1605,6 @@ size_t DynamicTransformConcretizationInfo::hash() const { hashCombine(hash, (size_t)iter_type); } } - for (const auto& [id, iter_type] : getResizeIterTypes()) { - hashCombine(hash, (size_t)iter_type); - } for (const auto& [id, expand_axes] : getExpandAxes()) { hashCombine(hash, (size_t)id); for (bool e : expand_axes) { diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 
e5b25f0dab7..cd82eb75656 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -171,11 +171,15 @@ class DynamicTransformConcretizationInfo { return reshape_transforms_; } + //! Holds input extent, left expansion, right expansion + using ConcreteResize = std::tuple; + //! Return a vector of pairs holding the index of each resized IterDomain in //! the vector returned by initialInfo()->getDynamicResizedIterDomains(), //! along with the IterType it should be concretized to. - const std::vector>& getResizeIterTypes() const { - return resize_itertypes_; + const std::vector>& getResizeExtents() + const { + return resize_extents_; } //! Return a vector of pairs holding the index of each expanded TensorView in @@ -261,8 +265,8 @@ class DynamicTransformConcretizationInfo { //! Holds the index of the resized IterDomain (output of the Resize op) in the //! vector returned by initial_info_->getDynamicResizedIterDomains() along - //! with its concretized IterType - std::vector> resize_itertypes_; + //! with its concretized input extent, left and right expansions + std::vector> resize_extents_; //! Holds the index of the expanded TensorView in the vector returned by //! 
initial_info_->getDynamicExpandedTensorViews(), and a corresponding vector diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index 2e7090373a8..716f9eeca9f 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -84,7 +84,8 @@ void OptOutMutator::mutate(Val* s) {} void OptOutMutator::mutate(NamedScalar* ns) {} -void OptOutMutator::mutate(IterDomain* id) { +void OptOutMutator::mutate(IterDomain* orig_id) { + IterDomain* id = maybeMutated(orig_id)->as(); Val* start = maybeMutated(id->start()); Val* extent = maybeMutated(id->extent()); Val* expanded_extent = nullptr; @@ -106,7 +107,10 @@ void OptOutMutator::mutate(IterDomain* id) { .build(); // This guarantees we replace id in all downstream expressions - registerMutation(id, new_id); + registerMutation(orig_id, new_id); + if (id != orig_id) { + registerMutation(id, new_id); + } // Preserve definition if it exists in id. This is important since otherwise // we might disconnect the root to logical transform path. For example if id @@ -183,22 +187,56 @@ Expr* OptOutMutator::mutateExpr( std::vector mutated_outputs; mutated_outputs.reserve(op->outputs().size()); for (auto output : op->outputs()) { - mutated_outputs.emplace_back( - replace_outputs ? maybeMutated(output) : output); + if (replace_outputs) { + Val* mut_out = maybeMutated(output); + if (mut_out != output && + std::find_if( + op->inputs().begin(), op->inputs().end(), [&](Val* const inp) { + return (replace_inputs ? maybeMutated(inp) : inp) == mut_out; + }) == op->inputs().end()) { + // skip using mutated output if is one of the inputs. + output = mut_out; + } + } + mutated_outputs.emplace_back(output); } std::vector mutated_inputs; mutated_inputs.reserve(op->inputs().size()); for (auto input : op->inputs()) { - mutated_inputs.emplace_back(replace_inputs ? 
maybeMutated(input) : input); + if (replace_inputs) { + Val* mut_inp = maybeMutated(input); + if (mut_inp != input && + std::find_if( + op->outputs().begin(), op->outputs().end(), [&](Val* const outp) { + return (replace_outputs ? maybeMutated(outp) : outp) == mut_inp; + }) == op->outputs().end()) { + // skip using mutated input if is one of the outputs. + input = mut_inp; + } + } + mutated_inputs.emplace_back(input); } std::vector mutated_attrs; mutated_attrs.reserve(op->attributes().size()); for (auto attr : op->attributes()) { if (auto attr_val = dynamic_cast(attr)) { - mutated_attrs.emplace_back( - replace_inputs ? maybeMutated(attr_val) : attr_val); + if (replace_inputs) { + Val* mut_attr = maybeMutated(attr_val); + if (mut_attr != attr_val && + std::find_if( + op->outputs().begin(), + op->outputs().end(), + [&](Val* const outp) { + return (replace_outputs ? maybeMutated(outp) : outp) == + mut_attr; + }) == op->outputs().end()) { + // skip using mutated attr if is one of the outputs. + attr_val = mut_attr; + } + } + mutated_attrs.emplace_back(attr_val); } else { mutated_attrs.emplace_back(attr); } diff --git a/csrc/options.cpp b/csrc/options.cpp index 2641216bd23..0409dff5b45 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -152,6 +152,7 @@ template <> std::unordered_map> Options< EnableOption>::getOptionsFromEnv() { const std::unordered_map available_options = { + {"concretize_resize_extents", EnableOption::ConcretizeResizeExtents}, {"fuse_matmul", EnableOption::FuseMatmul}, {"fuse_multiple_matmuls", EnableOption::FuseMultipleMatmuls}, {"id_model", EnableOption::IdModel}, diff --git a/csrc/options.h b/csrc/options.h index b1b706e8f15..55081e49f90 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -91,6 +91,7 @@ enum class DebugDumpOption { //! These can be set through the `NVFUSER_ENABLE` environment variable //! enum class EnableOption { + ConcretizeResizeExtents, //! Concretize input extents and expand of Resize FuseMatmul, //! 
Enable automatic fusion of matmul and linear ops FuseMultipleMatmuls, //! Allow fusing more than one matmul in a single kernel IdModel, //! Enable IdModel diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 17ecac56cf4..7bad424717c 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -3636,7 +3636,7 @@ TEST_F(ResizeTest, Chunk_NegativeSize) { auto in_tensor = at::randn({13}).cuda(); fec.runFusionWithInputs({in_tensor}); }, - ThrowsMessage(HasSubstr("Invalid resized domain extent"))); + ThrowsMessage(HasSubstr("Unexpected size of axis: -2"))); } TEST_F(ResizeTest, Chunk_SizeZero) {