From 51b64dafc7584a9833b5ce705681f91e4e638fc4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 06:58:16 -0400 Subject: [PATCH 01/14] First draft of slice concretization --- csrc/dynamic_transform.cpp | 175 +++++++++++++++++++++++++++++++++++++ csrc/dynamic_transform.h | 68 +++++++++++++- csrc/ops/alias.cpp | 63 +++++++++++-- csrc/ops/alias.h | 7 +- 4 files changed, 302 insertions(+), 11 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 42581cdeb5d..7281d64fcea 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -37,6 +39,12 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( cloned_info.dynamic_resized_ids_.push_back(ir_cloner.clone(op)); } } + cloned_info.dynamic_sliced_tvs_.reserve(dynamic_sliced_tvs_.size()); + for (const auto v : dynamic_sliced_tvs_) { + if (v) { + cloned_info.dynamic_sliced_tvs_.push_back(ir_cloner.clone(v)); + } + } cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size()); for (const auto v : root_dynamic_vals_) { if (v) { @@ -54,6 +62,10 @@ std::string DynamicTransformInitialInfo::toString() const { for (const auto& op : dynamic_reshaped_tvs_) { ss << indent << indent << op->toString() << "\n"; } + ss << indent << "Dynamic sliced TensorViews:\n"; + for (const auto& op : dynamic_sliced_tvs_) { + ss << indent << indent << op->toString() << "\n"; + } ss << indent << "Dynamic resized IterDomains:\n"; for (const auto& op : dynamic_resized_ids_) { ss << indent << indent << op->toString() << "\n"; @@ -109,6 +121,12 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { //! Detect dynamic IterDomain transforms when handling TensorViews void handle(TensorView* tv) override { + if (tv->definition() && tv->definition()->isA()) { + if (tv->domain()->hasSymbolicAxis()) { + info_.dynamic_sliced_tvs_.push_back(tv); + } + return; + } const auto& rfd = tv->getMaybeRFactorDomain(); for (auto id : rfd) { if (!id->definition() || id->getIterType() != IterType::Symbolic) { @@ -238,6 +256,72 @@ void DynamicTransformConcretizationInfo::analyzeReshapes( } } +void DynamicTransformConcretizationInfo::analyzeSlices( + ExpressionEvaluator* expr_eval) { + const auto& sliced_tvs = initial_info_->getDynamicSlicedTensorViews(); + for (auto tv_index : c10::irange(sliced_tvs.size())) { + auto out_tv = sliced_tvs.at(tv_index); + auto op = out_tv->definition()->as(); + auto inp_tv = op->in()->as(); + const auto inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + const auto ranges = op->getRanges(); + std::vector slice_descs(inp_dom.size()); + for (auto i : c10::irange(inp_dom.size())) { + const auto& range = ranges.at(i); + + auto start_opt = expr_eval->evaluate(range.start); + TORCH_INTERNAL_ASSERT( + start_opt.has_value(), + "Could not evaluate start of slice range ", + range.start); + auto start = start_opt->as(); + + auto stop_opt = expr_eval->evaluate(range.stop); + TORCH_INTERNAL_ASSERT( + stop_opt.has_value(), + "Could not evaluate stop of slice range ", + range.stop); + auto stop = stop_opt->as(); + + auto step_opt = expr_eval->evaluate(range.step); + TORCH_INTERNAL_ASSERT( + step_opt.has_value(), + "Could not evaluate step of slice range ", + range.step); + auto step = step_opt->as(); + + TORCH_INTERNAL_ASSERT(step != 0, "Slice step must not be zero"); + TORCH_INTERNAL_ASSERT( + step == 1, "Slicing with step != 1 is not currently supported"); + + auto extent_opt = + expr_eval->evaluate(inp_dom.at(i)->getMaybeExpandedExtent()); + TORCH_INTERNAL_ASSERT( + extent_opt.has_value(), + "Could not evaluate slice input extent ", + inp_dom.at(i)->getMaybeExpandedExtent()); + auto extent = extent_opt->as(); + + auto getBranch = [&extent](int64_t a) -> SliceIndexBranch { + if (a <= extent) { + return SliceIndexBranch::AlwaysZero; + } else if (a < 0) { + return SliceIndexBranch::Negative; + } else if (a < extent) { + return SliceIndexBranch::Positive; + } else { + return SliceIndexBranch::AlwaysExtent; + } + }; + slice_descs[i].start_branch = getBranch(start); + slice_descs[i].stop_branch = getBranch(stop); + slice_descs[i].is_empty = (stop - start) * step <= 0; + } + slice_descriptors_.emplace_back(tv_index, slice_descs); + } +} + void DynamicTransformConcretizationInfo::analyzeResizes( ExpressionEvaluator* expr_eval) { const auto& resize_ids = initial_info_->getDynamicResizedIterDomains(); @@ -341,6 +425,8 @@ class DynamicTransformConcretizer : public OptOutMutator { void concretizeReshape(); + void concretizeSlice(); + void concretizeResize(); //! Use this instead of calling registerMutation directly, since it will also @@ -373,6 +459,9 @@ void DynamicTransformConcretizer::concretize() { // First, concretize all dynamic reshape ops concretizeReshape(); + // Concretize dynamic slices + concretizeSlice(); + // Set output IterTypes for dynamic resize ops concretizeResize(); @@ -414,6 +503,92 @@ void DynamicTransformConcretizer::concretizeReshape() { } } +void DynamicTransformConcretizer::concretizeSlice() { + auto fusion = FusionGuard::getCurFusion(); + for (const auto& [tv_index, slice_descs] : info_->getSliceDescriptors()) { + auto incomplete_out_tv = + info_->initialInfo()->getDynamicSlicedTensorViews().at(tv_index); + auto slice_op = incomplete_out_tv->definition()->as(); + auto inp_tv = slice_op->input(0)->as(); + + const auto& root_dom = incomplete_out_tv->getRootDomain(); + // Create new rfactor domain with potentially newly-resized root IDs + std::vector new_rfactor(root_dom.size()); + + bool is_empty = false; + bool is_sliced = false; + const auto ranges = slice_op->getRanges(); + auto map_index = [&fusion]( + SliceIndexBranch branch, Val* a, Val* extent) -> Val* { + if (branch == SliceIndexBranch::AlwaysExtent) { + return extent; + } else if (branch == SliceIndexBranch::Negative) { + return SimplifyingIrBuilder::negExpr(a); + } else if (branch == SliceIndexBranch::Positive) { + return a; + } else { + return fusion->zeroVal(); + } + }; + std::vector new_ranges; + new_ranges.reserve(ranges.size()); + for (auto i : c10::irange(root_dom.size())) { + auto desc = slice_descs.at(i); + if (desc.is_empty) { + is_empty = true; + // Use 0:0:1 as the canonical empty slice. + new_ranges.push_back( + {fusion->zeroVal(), fusion->zeroVal(), fusion->oneVal()}); + } else { + auto range = ranges.at(i); + auto inp_extent = root_dom.at(i)->getMaybeExpandedExtent(); + auto new_start = map_index(desc.start_branch, range.start, inp_extent); + auto new_stop = map_index(desc.stop_branch, range.stop, inp_extent); + new_ranges.push_back({new_start, new_stop, range.step}); + if (desc.start_branch != SliceIndexBranch::AlwaysZero || + desc.stop_branch != SliceIndexBranch::AlwaysExtent) { + is_sliced = true; + } + } + } + + TensorView* new_tv = nullptr; + + if (is_empty) { + std::vector new_shape(ranges.size()); + for (auto i : c10::irange(ranges.size())) { + auto new_range = new_ranges.at(i); + // TODO: this assumes new_range.step == 1 + new_shape[i] = + SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start); + } + // TODO: process as empty tensor if is_empty + auto dtype = incomplete_out_tv->getDataType().value(); + new_tv = full(new_shape, fusion->zeroVal(dtype), dtype); + } else if (!is_sliced) { + // Replace the slice with set() + new_tv = set(inp_tv); + } else { + new_tv = slice(inp_tv, new_ranges, /*skip_symbolic*/ true); + } + + // We do the replacement directly here, but we must still check that the + // replacement is valid + checkConcretizedUses(incomplete_out_tv, new_tv); + + // Replace the old tensor with the new concretized tensor + for (auto use_of_old_tv : incomplete_out_tv->uses()) { + ir_utils::replaceValInExpr(use_of_old_tv, incomplete_out_tv, new_tv); + } + + if (incomplete_out_tv->isFusionOutput()) { + incomplete_out_tv->fusion()->replaceOutput(incomplete_out_tv, new_tv); + } + + info_->fusion()->removeVal(incomplete_out_tv); + } +} + void DynamicTransformConcretizer::concretizeResize() { // Concretize each resize op. for (const auto& [id_index, iter_type] : info_->getResizeIterTypes()) { diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 34609a002f1..35e66ab6538 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -40,7 +40,8 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { //! Return whether any dynamic transforms exist in the Fusion bool hasDynamicTransforms() const { - return !dynamic_reshaped_tvs_.empty() || !dynamic_resized_ids_.empty(); + return !dynamic_reshaped_tvs_.empty() || !dynamic_resized_ids_.empty() || + !dynamic_sliced_tvs_.empty(); } //! Return a set of scalars that are inputs or extents of input TensorViews @@ -62,6 +63,11 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { return dynamic_resized_ids_; } + //! Return a vector of outputs of Slice expressions + const std::vector& getDynamicSlicedTensorViews() const { + return dynamic_sliced_tvs_; + } + std::string toString() const; DynamicTransformInitialInfo clone(IrCloner& ir_cloner) const; @@ -93,12 +99,51 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { std::vector dynamic_resized_ids_; + // Slice operations can have complicated output extents. The inputs to slice + // are a start, stop, and step for each sliced dimension. Each of these is an + // integer, and any combination of three finite integers with step != 0 is + // acceptable and should run without error. Normalization of the start and + // stop values must be done, followed by computation of the output extent: + // + // normed_start = min(max(where(start < 0, extent + start, start), 0), + // extent); normed_stop = max(min(max(where(stop < 0, extent + stop, stop), + // 0), extent), normed_start); extent = max((normed_stop - normed_start + 1) + // / step, 0); + // + // These expressions are unwieldy and cannot be significantly simplified + // unless we know certain relations about the start, stop, and step scalars. + // Here we keep track of non-static slices or slices with non-static input + // extents. That way we can restrict to a single branch in each of these + // expressions during concretization. + std::vector dynamic_sliced_tvs_; + // Root Vals that determine concretization std::unordered_set root_dynamic_vals_; friend class DynamicTransformInitialInfoBuilder; }; +//! This enum describes cases that can occur for the start or stop arguments to +//! slice(). Each of these leads to a different branch in the normalized form's +//! general expression. +enum class SliceIndexBranch { + AlwaysZero, // a <= -extent + Negative, // -ext < a < 0 + Positive, // 0 <= a < extent + AlwaysExtent // extent <= a +}; + +//! Describes a 1D slice in terms of the start, stop, and extent values +struct Concrete1DSliceDescriptor { + //! These enums determine the form of the simplified expressions + SliceIndexBranch start_branch = SliceIndexBranch::Positive; + SliceIndexBranch stop_branch = SliceIndexBranch::Positive; + + //! True if normalized values satisfy (stop - start) * step <= 0 in which case + //! we would return an empty tensor. + bool is_empty = false; +}; + //! A set of transformations for a symbolic fusion with concrete sizes //! of the fusion inputs class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { @@ -117,6 +162,8 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { analyzeReshapes(expr_eval); + analyzeSlices(expr_eval); + analyzeResizes(expr_eval); } @@ -136,6 +183,15 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { return resize_itertypes_; } + //! Return a vector of pairs holding the index of each sliced TensorView in + //! the vector returned by initialInfo()->getDynamicSlicedTensorViews(), + //! along with a vector of descriptors indicating how each axis should be + //! concretized. + const std::vector>>& + getSliceDescriptors() const { + return slice_descriptors_; + } + //! Comparison operator for the purposes of determining cache hits. This does //! not guarantee equality of all members. Instead, it returns equal if the //! resulting concretizations would be structurally equivalent. Note that @@ -152,6 +208,10 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { //! during concretization. void analyzeReshapes(ExpressionEvaluator* expr_eval); + //! Given an ExpressionEvaluator which already has input scalars bound to it, + //! determine the branches of expressions in dynamic slice ops. + void analyzeSlices(ExpressionEvaluator* expr_eval); + //! Given an ExpressionEvaluator which already has input scalars bound to it, //! determine the concrete IterType of each resized IterDomain. void analyzeResizes(ExpressionEvaluator* expr_eval); @@ -189,6 +249,12 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { //! vector returned by initial_info_->getDynamicResizedIterDomains() along //! with its concretized IterType std::vector> resize_itertypes_; + + //! Holds the index of the sliced TensorView (output of the SliceOp) in the + //! vector returned by initial_info_->getDynamicSlicedTensorViews() along + //! with a descriptor of how it should be concretized. + std::vector>> + slice_descriptors_; }; class TORCH_CUDA_CU_API DynamicTransform { diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 49ed3b9912d..d1af14004c9 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -639,10 +639,14 @@ TensorView* cat(const std::vector& inputs, int64_t cat_dim) { return out; } -// Currently there's no error check about the actual values of the -// Slice parameters. For example, the start parameter of a range of a -// domain is assumed to be >= 0 and < the extent of the domain. -TensorView* slice(TensorView* inp, const std::vector& ranges) { +// If skip_symbolic is true, then the start and stop parameters of a range of a +// domain is assumed to be >= 0 and < the extent of the domain. Otherwise, +// non-constant inputs will lead to Symbolic IterDomains in the output, which +// must be later concretized. +TensorView* slice( + TensorView* inp, + const std::vector& ranges, + bool skip_symbolic) { const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); const int ndims = static_cast(inp_dom.size()); @@ -666,6 +670,19 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { return range; }; + // Adjust an integer value relative to a given extent. This is + // min(max(where(a < 0, extent + a, a), 0), extent) + auto adjust_start_stop = [](int64_t& a, int64_t extent) { + if (a < 0) { + a += extent; + } + if (a < 0) { + a = 0; + } else if (a > extent) { + a = extent; + } + }; + for (auto& range : ranges) { // Step not supported yet TORCH_CHECK( @@ -693,11 +710,39 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { } else { out_root_id = IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); - out_rf_id = IterDomain::resize( - out_root_id, - SimplifyingIrBuilder::negExpr(range.start), - sub(range.stop, inp_root_id->extent()), - true); + // The start, stop, and extent of the output will all require complicated + // expressions which will be simplified at concretization. Here we set + // the output to Symbolic unless all required scalars are constant. + if (range.start->isConstInt() && range.stop->isConstInt() && + inp_root_id->isConstInt()) { + auto start = range.start->evaluateInt(); + auto stop = range.stop->evaluateInt(); + auto step = range.step->evaluateInt(); + TORCH_INTERNAL_ASSERT(step != 0, "Slice step must be non-zero"); + TORCH_INTERNAL_ASSERT( + step == 1, "Slicing with step != 1 is not currently supported"); + auto inp_extent = inp_root_id->extent()->evaluateInt(); + adjust_start_stop(start, inp_extent); + adjust_start_stop(stop, inp_extent); + out_rf_id = IterDomain::resize( + out_root_id, + SimplifyingIrBuilder::negExpr(IrBuilder::create(start)), + sub(IrBuilder::create(stop), inp_root_id->extent()), + true); + } else if (skip_symbolic) { + out_rf_id = IterDomain::resize( + out_root_id, + SimplifyingIrBuilder::negExpr(range.start), + sub(range.stop, inp_root_id->extent()), + true); + } else { + out_rf_id = IterDomainBuilder( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create()) + .is_rfactor_domain(true) + .iter_type(IterType::Symbolic) + .build(); + } needs_real_slicing = true; } root_ids.at(idx) = out_root_id; diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 1d3299fd6e5..c2625de6355 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -106,8 +106,13 @@ TORCH_CUDA_CU_API TensorView* cat( //! Return a tensor where each dimension is sliced as specified by the //! ranges parameter. Stepping must be one at this moment. +//! If skip_symbolic is true, we assume start < stop and both start and stop are +//! between 0 and extent (inclusive). Otherwise, unless all input scalars are +//! constant, the output will have Symbolic IterDomains that must be concretized +//! later. TORCH_CUDA_CU_API TensorView* slice( TensorView* inp, - const std::vector& ranges); + const std::vector& ranges, + bool skip_symbolic = false); } // namespace nvfuser From 1e5ce513c05351be801790342dfffa16d1bb4f04 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 08:30:39 -0400 Subject: [PATCH 02/14] Bind sliced extents to properly analyze downstream slices --- csrc/dynamic_transform.cpp | 59 +++++++++++++++++++++++++++++++++----- csrc/dynamic_transform.h | 15 ++++++++++ 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 7281d64fcea..a976f08108e 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -265,6 +265,7 @@ void DynamicTransformConcretizationInfo::analyzeSlices( auto inp_tv = op->in()->as(); const auto inp_dom = TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + const auto out_dom = out_tv->getMaybeRFactorDomain(); const auto ranges = op->getRanges(); std::vector slice_descs(inp_dom.size()); for (auto i : c10::irange(inp_dom.size())) { @@ -295,20 +296,20 @@ void DynamicTransformConcretizationInfo::analyzeSlices( TORCH_INTERNAL_ASSERT( step == 1, "Slicing with step != 1 is not currently supported"); - auto extent_opt = + auto inp_extent_opt = expr_eval->evaluate(inp_dom.at(i)->getMaybeExpandedExtent()); TORCH_INTERNAL_ASSERT( - extent_opt.has_value(), + inp_extent_opt.has_value(), "Could not evaluate slice input extent ", inp_dom.at(i)->getMaybeExpandedExtent()); - auto extent = extent_opt->as(); + auto inp_extent = inp_extent_opt->as(); - auto getBranch = [&extent](int64_t a) -> SliceIndexBranch { - if (a <= extent) { + auto getBranch = [&inp_extent](int64_t a) -> SliceIndexBranch { + if (a <= -inp_extent) { return SliceIndexBranch::AlwaysZero; } else if (a < 0) { return SliceIndexBranch::Negative; - } else if (a < extent) { + } else if (a < inp_extent) { return SliceIndexBranch::Positive; } else { return SliceIndexBranch::AlwaysExtent; @@ -317,6 +318,38 @@ void DynamicTransformConcretizationInfo::analyzeSlices( slice_descs[i].start_branch = getBranch(start); slice_descs[i].stop_branch = getBranch(stop); slice_descs[i].is_empty = (stop - start) * step <= 0; + + // The dynamic slice output has a purely symbolic extent. Here we evaluate + // the proper extent to determine the output IterType. + auto map_int_index = [&inp_extent](const int64_t& a) -> int64_t { + if (a <= -inp_extent) { + return 0; + } else if (a < 0) { + return -a; + } else if (a < inp_extent) { + return a; + } else { + return inp_extent; + } + }; + // actual size of sliced dimension is ceilDiv(stop - start, step) when + // step > 0. When step < 0, that expression is off by one and instead the + // extent in that case is ceilDiv(start - stop, -step). + auto concrete_sliced_extent = step > 0 + ? (map_int_index(stop) - map_int_index(start) + step - 1) / step + : (map_int_index(stop) - map_int_index(start) + step + 1) / step; + + if (concrete_sliced_extent == 1) { + slice_descs[i].iter_type = IterType::Broadcast; + } + + // Even though we will eventually replace this TV, there will still be + // references to its extents in downstream uses. We will need to evaluate + // these properly both in this analysis, and at concretization. Here we + // bind the output extent so that downstream extents can be properly + // computed during this analysis. After concretization, this will happen + // via ExpressionEvaluator::propagateBoundValuesThroughExactMaps. + expr_eval->bind(out_dom[i]->extent(), concrete_sliced_extent); } slice_descriptors_.emplace_back(tv_index, slice_descs); } @@ -365,7 +398,8 @@ bool DynamicTransformConcretizationInfo::operator==( } if (reshape_transforms_.size() != other.reshape_transforms_.size() || - resize_itertypes_.size() != other.resize_itertypes_.size()) { + resize_itertypes_.size() != other.resize_itertypes_.size() || + slice_descriptors_.size() != other.slice_descriptors_.size()) { return false; } @@ -377,6 +411,14 @@ bool DynamicTransformConcretizationInfo::operator==( } } + for (const auto i : c10::irange(slice_descriptors_.size())) { + const auto& desc = slice_descriptors_.at(i); + const auto& other_desc = other.slice_descriptors_.at(i); + if (desc != other_desc) { + return false; + } + } + for (const auto i : c10::irange(resize_itertypes_.size())) { const auto& itertype = resize_itertypes_.at(i); const auto& other_itertype = other.resize_itertypes_.at(i); @@ -572,6 +614,9 @@ void DynamicTransformConcretizer::concretizeSlice() { new_tv = slice(inp_tv, new_ranges, /*skip_symbolic*/ true); } + // TODO: We need to update the maybeRFactorDomains of new_tv if there are + // any Broadcast sliced dimensions. + // We do the replacement directly here, but we must still check that the // replacement is valid checkConcretizedUses(incomplete_out_tv, new_tv); diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 35e66ab6538..f44ee2fb1ea 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -142,6 +142,21 @@ struct Concrete1DSliceDescriptor { //! True if normalized values satisfy (stop - start) * step <= 0 in which case //! we would return an empty tensor. bool is_empty = false; + + //! This can be either Iteration or Broadcast (if sliced extent is 1) + IterType iter_type = IterType::Iteration; + + // NOTE: In the future, we can hold branches for "step" here as well, in order + // to specialize when step == 1 + + bool operator==(const Concrete1DSliceDescriptor& other) const { + return start_branch == other.start_branch && + stop_branch == other.stop_branch && is_empty == other.is_empty && + iter_type == other.iter_type; + } + bool operator!=(const Concrete1DSliceDescriptor& other) const { + return !operator==(other); + } }; //! A set of transformations for a symbolic fusion with concrete sizes From fcd9f17c21f153d77fb5235e19248a95d35b2546 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 09:45:27 -0400 Subject: [PATCH 03/14] Add dynamic vals for slice --- csrc/dynamic_transform.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index a976f08108e..67631e796ef 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -124,6 +124,24 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { if (tv->definition() && tv->definition()->isA()) { if (tv->domain()->hasSymbolicAxis()) { info_.dynamic_sliced_tvs_.push_back(tv); + auto root_dom = tv->getRootDomain(); + const auto ranges = tv->definition()->as()->getRanges(); + TORCH_INTERNAL_ASSERT( + ranges.size() == root_dom.size(), + "Mismatch between number of slice ranges ", + ranges.size(), + " and size of root domain ", + root_dom.size()); + for (auto i : c10::irange(root_dom.size())) { + // input extent and start/stop/step values determine slice + // concretization + auto root_ext = root_dom.at(i)->getMaybeExpandedExtent(); + leaf_dynamic_vals_.push_back(root_ext); + auto range = ranges.at(i); + leaf_dynamic_vals_.push_back(range.start); + leaf_dynamic_vals_.push_back(range.stop); + leaf_dynamic_vals_.push_back(range.step); + } } return; } From abb07431632736c3590b2d2bbbf164a212e04bb8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 09:46:13 -0400 Subject: [PATCH 04/14] Add slice step branch to conc info --- csrc/dynamic_transform.cpp | 24 ++++++++++++++++++++---- csrc/dynamic_transform.h | 6 ++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 67631e796ef..6b744098a25 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -605,8 +605,10 @@ void DynamicTransformConcretizer::concretizeSlice() { auto new_start = map_index(desc.start_branch, range.start, inp_extent); auto new_stop = map_index(desc.stop_branch, range.stop, inp_extent); new_ranges.push_back({new_start, new_stop, range.step}); + // Trivial slices correspond to 0:extent:1 if (desc.start_branch != SliceIndexBranch::AlwaysZero || - desc.stop_branch != SliceIndexBranch::AlwaysExtent) { + desc.stop_branch != SliceIndexBranch::AlwaysExtent || + desc.step_branch != SliceStepBranch::One) { is_sliced = true; } } @@ -618,9 +620,23 @@ void DynamicTransformConcretizer::concretizeSlice() { std::vector new_shape(ranges.size()); for (auto i : c10::irange(ranges.size())) { auto new_range = new_ranges.at(i); - // TODO: this assumes new_range.step == 1 - new_shape[i] = - SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start); + auto desc = slice_descs.at(i); + // Depending on the step branch, we can use different output extent + // expressions + switch (desc.step_branch) { + case SliceStepBranch::One: + new_shape[i] = + SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start); + break; + case SliceStepBranch::GreaterThanOne: + new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( + SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start), + new_range.step); + case SliceStepBranch::Negative: + new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( + SimplifyingIrBuilder::subExpr(new_range.start, new_range.stop), + SimplifyingIrBuilder::negExpr(new_range.step)); + } } // TODO: process as empty tensor if is_empty auto dtype = incomplete_out_tv->getDataType().value(); diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index f44ee2fb1ea..2be8575b8aa 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -133,11 +133,17 @@ enum class SliceIndexBranch { AlwaysExtent // extent <= a }; +//! This enum describes the "step" argument to slice, which can be a positive or +//! negative integer (but not zero). We handle the special case of step == 1 +//! separately from step > 1 since this simplifies some expressions. +enum class SliceStepBranch { Negative, One, GreaterThanOne }; + //! Describes a 1D slice in terms of the start, stop, and extent values struct Concrete1DSliceDescriptor { //! These enums determine the form of the simplified expressions SliceIndexBranch start_branch = SliceIndexBranch::Positive; SliceIndexBranch stop_branch = SliceIndexBranch::Positive; + SliceStepBranch step_branch = SliceStepBranch::One; //! True if normalized values satisfy (stop - start) * step <= 0 in which case //! we would return an empty tensor. From cc03c543b66d7e874aaa472cd4c43ba1383b3354 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 09:46:58 -0400 Subject: [PATCH 05/14] Update hash and operator== --- csrc/dynamic_transform.cpp | 5 +++++ csrc/dynamic_transform.h | 16 +++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 6b744098a25..2ded6ccc214 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -938,6 +938,11 @@ size_t DynamicTransformConcretizationInfo::hash() const { for (const auto& [id, iter_type] : getResizeIterTypes()) { hashCombine(hash, (size_t)iter_type); } + for (const auto& [tv, slice_descs] : getSliceDescriptors()) { + for (const auto& desc : slice_descs) { + hashCombine(hash, desc.hash()); + } + } return hash; } diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 2be8575b8aa..491605942b7 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -152,17 +152,23 @@ struct Concrete1DSliceDescriptor { //! This can be either Iteration or Broadcast (if sliced extent is 1) IterType iter_type = IterType::Iteration; - // NOTE: In the future, we can hold branches for "step" here as well, in order - // to specialize when step == 1 - bool operator==(const Concrete1DSliceDescriptor& other) const { return start_branch == other.start_branch && - stop_branch == other.stop_branch && is_empty == other.is_empty && - iter_type == other.iter_type; + stop_branch == other.stop_branch && step_branch == other.step_branch && + is_empty == other.is_empty && iter_type == other.iter_type; } bool operator!=(const Concrete1DSliceDescriptor& other) const { return !operator==(other); } + + size_t hash() const { + size_t h = (size_t)start_branch; + hashCombine(h, (size_t)stop_branch); + hashCombine(h, (size_t)step_branch); + hashCombine(h, (size_t)is_empty); + hashCombine(h, (size_t)iter_type); + return h; + } }; //! A set of transformations for a symbolic fusion with concrete sizes From c0593f507f47903608eb5b9b6f0fa1dd874a9cb5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 09:47:40 -0400 Subject: [PATCH 06/14] Fix bug that looked at id->isConstInt() instead of extent Should use id->extent()->isConstInt() instead. --- csrc/ops/alias.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index d1af14004c9..13a5edc2d9f 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -698,11 +698,12 @@ TensorView* slice( bool needs_real_slicing = false; for (const auto idx : c10::irange(ndims)) { auto inp_root_id = inp_dom[idx]; - auto range = normalize_slice_range(ranges.at(idx), inp_root_id->extent()); + auto inp_extent = inp_root_id->getMaybeExpandedExtent(); + auto range = normalize_slice_range(ranges.at(idx), inp_extent); normalized_ranges.at(idx) = range; IterDomain* out_root_id = nullptr; IterDomain* out_rf_id = nullptr; - if (range.start->isZeroInt() && range.stop->sameAs(inp_root_id->extent()) && + if (range.start->isZeroInt() && range.stop->sameAs(inp_extent) && range.step->isOneInt()) { // This dim doesn't need slicing out_root_id = inp_root_id->cloneWithoutRFactor(); @@ -714,26 +715,26 @@ TensorView* slice( // expressions which will be simplified at concretization. Here we set // the output to Symbolic unless all required scalars are constant. if (range.start->isConstInt() && range.stop->isConstInt() && - inp_root_id->isConstInt()) { + inp_extent->isConstInt()) { auto start = range.start->evaluateInt(); auto stop = range.stop->evaluateInt(); auto step = range.step->evaluateInt(); TORCH_INTERNAL_ASSERT(step != 0, "Slice step must be non-zero"); TORCH_INTERNAL_ASSERT( step == 1, "Slicing with step != 1 is not currently supported"); - auto inp_extent = inp_root_id->extent()->evaluateInt(); - adjust_start_stop(start, inp_extent); - adjust_start_stop(stop, inp_extent); + auto inp_extent_val = inp_extent->evaluateInt(); + adjust_start_stop(start, inp_extent_val); + adjust_start_stop(stop, inp_extent_val); out_rf_id = IterDomain::resize( out_root_id, SimplifyingIrBuilder::negExpr(IrBuilder::create(start)), - sub(IrBuilder::create(stop), inp_root_id->extent()), + sub(IrBuilder::create(stop), inp_extent), true); } else if (skip_symbolic) { out_rf_id = IterDomain::resize( out_root_id, SimplifyingIrBuilder::negExpr(range.start), - sub(range.stop, inp_root_id->extent()), + sub(range.stop, inp_extent), true); } else { out_rf_id = IterDomainBuilder( From 6ee7ab4d0b3d38db728dd83ecc0875ffec62bb7c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 13:18:43 -0400 Subject: [PATCH 07/14] Set definitions of concretized extents This is in lieu of replacing all uses of symbolic extents, which cannot be done reliably since they might appear as attributes or members of objects which are untracked. See #420 --- csrc/dynamic_transform.cpp | 29 ++++++++++++++++++++++++++++- csrc/evaluator_common.cpp | 18 ++++++++++++++++++ csrc/evaluator_common.h | 4 ++++ csrc/ir/utils.cpp | 15 ++++++++++++++- csrc/kernel_cache.cpp | 4 ++++ csrc/kernel_cache.h | 2 ++ 6 files changed, 70 insertions(+), 2 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 2ded6ccc214..681119ba2e8 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -489,6 +489,8 @@ class DynamicTransformConcretizer : public OptOutMutator { void concretizeResize(); + void concretizeIterDomain(IterDomain* old_id, IterDomain* new_id); + //! Use this instead of calling registerMutation directly, since it will also //! check that the concretized value is a valid input to all of its uses. void registerConcretization(Val* old_val, Val* new_val) { @@ -572,6 +574,7 @@ void DynamicTransformConcretizer::concretizeSlice() { auto inp_tv = slice_op->input(0)->as(); const auto& root_dom = incomplete_out_tv->getRootDomain(); + const auto& rfactor_dom = incomplete_out_tv->getRFactorDomain(); // Create new rfactor domain with potentially newly-resized root IDs std::vector new_rfactor(root_dom.size()); @@ -648,6 +651,12 @@ void DynamicTransformConcretizer::concretizeSlice() { new_tv = slice(inp_tv, new_ranges, /*skip_symbolic*/ true); } + for (auto i : c10::irange(root_dom.size())) { + auto id = rfactor_dom.at(i); + auto new_id = new_tv->getRFactorDomain().at(i); + concretizeIterDomain(id, new_id); + } + // TODO: We need to update the maybeRFactorDomains of new_tv if there are // any Broadcast sliced dimensions. @@ -683,10 +692,28 @@ void DynamicTransformConcretizer::concretizeResize() { id->isRFactorProduct(), iter_type); - registerConcretization(id, new_id); + concretizeIterDomain(id, new_id); } } +void DynamicTransformConcretizer::concretizeIterDomain( + IterDomain* old_id, + IterDomain* new_id) { + auto old_ext = old_id->getMaybeExpandedExtent(); + auto new_ext = new_id->getMaybeExpandedExtent(); + + // If symbolic, redefine old_ext to equal new_ext + if (old_ext != new_ext && !old_ext->isConst() && !old_ext->definition()) { + // Set the old_ext->definition() to equal set(new_ext). This way, any + // further uses of old_ext are valid, as they will now be computed using + // new_ext. + IrBuilder::create(LoadStoreOpType::Set, old_ext, new_ext); + } + + registerConcretization(old_ext, new_ext); + registerConcretization(old_id, new_id); +} + void DynamicTransformConcretizer::checkConcretizedUses( Val* old_val, Val* new_val) const { diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 5a3933f4363..8293651b06f 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -367,6 +367,12 @@ NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values) makeUnaryOp(uop); } else if (auto bop = dynamic_cast(def)) { makeBinaryOp(bop); + } else if (auto setop = dynamic_cast(def)) { + TORCH_INTERNAL_ASSERT( + setop->opType() == LoadStoreOpType::Set, + "NaiveValueMachine: unsupported LoadStoreOpType: ", + setop->opType()); + makeSetOp(setop); } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr"); } @@ -448,6 +454,18 @@ void NaiveValueMachine::makeBinaryOp(BinaryOp* bop) { dest_[index] = out; } +void NaiveValueMachine::makeSetOp(LoadStoreOp* lsop) { + int in = lsop->inputs()[0]->evaluatorIndex(); + int out = lsop->outputs()[0]->evaluatorIndex(); + TORCH_INTERNAL_ASSERT(in >= 0, "Integer Machine: unknown input: ", lsop); + TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", lsop); + + int index = makeInstructionEntry(); + inst_type_[index] = InstructionType::SET_OP; + src0_[index] = in; + dest_[index] = out; +} + int NaiveValueMachine::makeInstructionEntry() { int index = num_of_instructions_++; inst_type_.emplace_back(InstructionType::UNARY_OP); diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index 112d770a27e..3bb198bf78b 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -55,6 +55,10 @@ class NaiveValueMachine { //! Convert an binary IR expr to an instruction void makeBinaryOp(BinaryOp* bop); + //! Convert a LoadStoreOp expr to an instruction. This assumes lsop->opType() + //! is equal to LoadStoreOpType::Set. + void makeSetOp(LoadStoreOp* lsop); + //! Create an empty instruction with all default values //! and place it at the end of the instruction buffer. int makeInstructionEntry(); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 3315cb3167d..f77c75fd503 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -654,7 +654,7 @@ struct ReplaceValInIndexVal : public OptInDispatch { // Recursively traverse its defining expr auto def = val->definition(); if (def != nullptr) { - if (def->isOneOf()) { + if (def->isOneOf()) { handle(val->definition()); } else { TORCH_INTERNAL_ASSERT(false, "Unexpected definition: ", def->toString()) @@ -665,6 +665,19 @@ struct ReplaceValInIndexVal : public OptInDispatch { } } + // Clone expression after recurisvely replacing inputs + void handle(LoadStoreOp* lsop) override { + handle(lsop->in()); + auto inp = last_visited_val_; + TORCH_INTERNAL_ASSERT( + lsop->out()->isA() || lsop->out()->isA(), + "Unknown output type for expr ", + lsop->toInlineString()); + auto out = IrBuilder::create(c10::nullopt); + IrBuilder::create(lsop->opType(), out, inp); + last_visited_val_ = out; + } + // Clone expression after recurisvely replacing inputs void handle(UnaryOp* uop) override { handle(uop->in()); diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 6e7702e24b8..32e8e674e07 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -293,6 +293,10 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( FusionExecutorCache::FusionExecutorCache(std::unique_ptr fusion) : fusion_(std::move(fusion)) {} +Fusion* FusionExecutorCache::getMostRecentConcretizedFusion() const { + return most_recent_runtime_->fusionSegments()->completeFusion(); +} + KernelArgumentHolder FusionExecutorCache::prepareInputs( const at::ArrayRef& inputs, std::optional selected_device) { diff --git a/csrc/kernel_cache.h b/csrc/kernel_cache.h index 8fc0cc51a3d..a2fd2cbbb62 100644 --- a/csrc/kernel_cache.h +++ b/csrc/kernel_cache.h @@ -507,6 +507,8 @@ class TORCH_CUDA_CU_API FusionExecutorCache { fusion_->printMath(); } + Fusion* getMostRecentConcretizedFusion() const; + FusionKernelRuntime* getMostRecentKernelRuntime() const { return most_recent_runtime_; } From a53912b5157bf25df9f58a6b4fad679491596e50 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 22 Jun 2023 13:20:20 -0400 Subject: [PATCH 08/14] Use FusionExecutorCache::getMostRecentConcretizedFusion() in tests --- test/test_resize.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 29b16d6ed74..68640aded3a 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1242,7 +1242,7 @@ TEST_F(NVFuserTest, FusionResizeSliceReduceScheduler1_CUDA) { auto ref = t1.sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref}, @@ -1300,7 +1300,7 @@ TEST_F(NVFuserTest, FusionResizeSliceReduceScheduler2_CUDA) { auto t4 = t3.sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2, t4}, @@ -1354,7 +1354,7 @@ TEST_F(NVFuserTest, FusionSliceReduceScheduler3_CUDA) { auto t4 = t3.to(at::kDouble).sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2, t4}, @@ -1391,7 +1391,7 @@ TEST_F(NVFuserTest, FusionResizeCatReduceScheduler1_CUDA) { auto ref = at::cat({t0, t1}, 1).sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref}, @@ -1429,7 +1429,7 @@ TEST_F(NVFuserTest, FusionResizeCatSoftmaxScheduler1_CUDA) { auto ref = at::_softmax(t2.to(at::kDouble), -1, false); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref}, @@ -1466,7 +1466,7 @@ TEST_F(NVFuserTest, FusionResizeReductionSliceScheduler1_CUDA) { auto t2 = t1.index({at::indexing::Slice(1, shape0[0] - 2)}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2}, @@ -1507,7 +1507,7 @@ TEST_F(NVFuserTest, FusionResizeSoftmaxSliceScheduler1_CUDA) { at::indexing::Slice(0, at::indexing::None)}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2}, @@ -1548,7 +1548,7 @@ TEST_F(NVFuserTest, FusionResizeSoftmaxSliceScheduler2_CUDA) { at::indexing::Slice(1, shape0[1] - 2)}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2}, @@ -1710,7 +1710,7 @@ TEST_F(NVFuserTest, FusionSliceForNanoGPT1_CUDA) { auto aten_t3 = torch::add(aten_t0_slice, t1); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {aten_t3, aten_t3}, @@ -1836,7 +1836,7 @@ TEST_F(NVFuserTest, FusionSliceForNanoGPT2_CUDA) { auto aten_t7 = aten_t2 + 1; testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {aten_t4, aten_t4, aten_t7}, @@ -2001,7 +2001,7 @@ TEST_F(NVFuserTest, ResizePermuteAndSlice_CUDA) { auto ref_t4 = ref_t2 + 1; testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref_t3, ref_t4}, From 950e17e56dea3ccff398cf0e902b69226ac3ad0d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 23 Jun 2023 09:01:42 -0400 Subject: [PATCH 09/14] Condense initial info into topo sorted vector of vals This lets us process ops in topological order instead of doing all reshapes followed by all slices etc. This is very helpful since we need to evaluate vals that are not defined until their upstream ops are concretized. --- csrc/dynamic_transform.cpp | 462 +++++++++++++++++++------------------ csrc/dynamic_transform.h | 51 ++-- csrc/ir/nodes.cpp | 9 +- 3 files changed, 262 insertions(+), 260 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 681119ba2e8..cc5b620f0b5 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -27,22 +27,10 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( IrCloner& ir_cloner) const { DynamicTransformInitialInfo cloned_info( static_cast(ir_cloner.container())); - cloned_info.dynamic_reshaped_tvs_.reserve(dynamic_reshaped_tvs_.size()); - for (const auto op : dynamic_reshaped_tvs_) { - if (op) { - cloned_info.dynamic_reshaped_tvs_.push_back(ir_cloner.clone(op)); - } - } - cloned_info.dynamic_resized_ids_.reserve(dynamic_resized_ids_.size()); - for (const auto op : dynamic_resized_ids_) { - if (op) { - cloned_info.dynamic_resized_ids_.push_back(ir_cloner.clone(op)); - } - } - cloned_info.dynamic_sliced_tvs_.reserve(dynamic_sliced_tvs_.size()); - for (const auto v : dynamic_sliced_tvs_) { + cloned_info.dynamic_expr_outputs_.reserve(dynamic_expr_outputs_.size()); + for (const auto v : dynamic_expr_outputs_) { if (v) { - cloned_info.dynamic_sliced_tvs_.push_back(ir_cloner.clone(v)); + cloned_info.dynamic_expr_outputs_.push_back(ir_cloner.clone(v)); } } cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size()); @@ -58,19 +46,11 @@ std::string DynamicTransformInitialInfo::toString() const { std::stringstream ss; ss << "DynamicTransformInitialInfo\n"; std::string indent = " "; - ss << indent << "Dynamic reshaped TensorViews:\n"; - for (const auto& op : dynamic_reshaped_tvs_) { - ss << indent << indent << op->toString() << "\n"; - } - ss << indent << "Dynamic sliced TensorViews:\n"; - for (const auto& op : dynamic_sliced_tvs_) { - ss << indent << indent << op->toString() << "\n"; - } - ss << indent << "Dynamic resized IterDomains:\n"; - for (const auto& op : dynamic_resized_ids_) { - ss << indent << indent << op->toString() << "\n"; + ss << indent << "Dynamic Vals:\n"; + for (const auto& val : dynamic_expr_outputs_) { + ss << indent << indent << val->toString() << "\n"; } - ss << indent << "Root dynamic Vals:\n"; + ss << indent << "Scalars determining concretization:\n"; for (const auto& v : root_dynamic_vals_) { ss << indent << indent << v->toString() << "\n"; } @@ -104,7 +84,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { auto out_tv = op->out()->as(); // If there's no symbolic axis, this is a static reshape op if (out_tv->domain()->hasSymbolicAxis()) { - info_.dynamic_reshaped_tvs_.push_back(out_tv); + info_.dynamic_expr_outputs_.push_back(out_tv); // Input and output extent expressions both affect concretization const auto& inp_dom = @@ -123,7 +103,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { void handle(TensorView* tv) override { if (tv->definition() && tv->definition()->isA()) { if (tv->domain()->hasSymbolicAxis()) { - info_.dynamic_sliced_tvs_.push_back(tv); + info_.dynamic_expr_outputs_.push_back(tv); auto root_dom = tv->getRootDomain(); const auto ranges = tv->definition()->as()->getRanges(); TORCH_INTERNAL_ASSERT( @@ -151,7 +131,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { continue; } if (id->definition()->isA()) { - info_.dynamic_resized_ids_.push_back(id); + info_.dynamic_expr_outputs_.push_back(id); // extent of output determines its IterType leaf_dynamic_vals_.push_back(id->extent()); } @@ -190,223 +170,235 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { std::vector leaf_dynamic_vals_; }; -void DynamicTransformConcretizationInfo::analyzeReshapes( +void DynamicTransformConcretizationInfo::analyze( ExpressionEvaluator* expr_eval) { - const auto& reshape_tvs = initial_info_->getDynamicReshapedTensorViews(); - for (const auto tv_index : c10::irange(reshape_tvs.size())) { - auto out_tv = reshape_tvs.at(tv_index); - auto op = out_tv->definition()->as(); - auto inp_tv = op->in()->as(); - - // If there's no symblic axis, this is a static reshape op - if (!out_tv->domain()->hasSymbolicAxis()) { - return; - } - + const auto& expr_outputs = initial_info_->getDynamicExprOutputs(); + for (auto val_index : c10::irange(expr_outputs.size())) { + auto val = expr_outputs.at(val_index); TORCH_INTERNAL_ASSERT( - out_tv->hasRFactor(), - "Unexpected output tv of ViewOp: ", - out_tv->toString()); - - const auto& inp_dom = - TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); - - // Determine input shape using expr evaluator - std::vector inp_shape(inp_dom.size(), 0); - for (const auto i : c10::irange(inp_dom.size())) { - auto inp_id = inp_dom.at(i); - // This should have been validated when initially creating reshape - // op, but just in case - TORCH_INTERNAL_ASSERT( - !inp_id->maybePartial(), - "Invalid domain to reshape: ", - inp_id->toString()); - auto extent_val = expr_eval->evaluate(inp_id->extent()); - TORCH_INTERNAL_ASSERT( - extent_val.has_value(), - "Cannot evaluate the extent of an input domain to reshape: ", - inp_id->toString()); - TORCH_INTERNAL_ASSERT( - extent_val->isInt(), - "Invalid evaluated value of domain extent: ", - inp_id->toString()); + val->definition(), + "Dynamic Val does not have a definition: ", + val->toString()); + if (val->definition()->isA()) { + analyzeReshape(expr_eval, val_index); + } else if (val->definition()->isA()) { + analyzeSlice(expr_eval, val_index); + } else if (val->definition()->isA()) { + analyzeResize(expr_eval, val_index); + } else { TORCH_INTERNAL_ASSERT( - extent_val->as() > 0, - "Invalid input domain extent: ", - extent_val->as()); - inp_shape.at(i) = extent_val->as(); + false, + "Unhandled dynamic expression: ", + val->definition()->toString()); } + } +} + +void DynamicTransformConcretizationInfo::analyzeReshape( + ExpressionEvaluator* expr_eval, + size_t val_index) { + auto out_tv = + initialInfo()->getDynamicExprOutputs().at(val_index)->as(); + auto op = out_tv->definition()->as(); + auto inp_tv = op->in()->as(); + + // If there's no symblic axis, this is a static reshape op + if (!out_tv->domain()->hasSymbolicAxis()) { + return; + } - const auto& out_dom = out_tv->getMaybeRFactorDomain(); + TORCH_INTERNAL_ASSERT( + out_tv->hasRFactor(), + "Unexpected output tv of ViewOp: ", + out_tv->toString()); + + const auto& inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + + // Determine input shape using expr evaluator + std::vector inp_shape(inp_dom.size(), 0); + for (const auto i : c10::irange(inp_dom.size())) { + auto inp_id = inp_dom.at(i); + // This should have been validated when initially creating reshape + // op, but just in case + TORCH_INTERNAL_ASSERT( + !inp_id->maybePartial(), + "Invalid domain to reshape: ", + inp_id->toString()); + auto extent_val = expr_eval->evaluate(inp_id->extent()); + TORCH_INTERNAL_ASSERT( + extent_val.has_value(), + "Cannot evaluate the extent of an input domain to reshape: ", + inp_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val->isInt(), + "Invalid evaluated value of domain extent: ", + inp_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val->as() > 0, + "Invalid input domain extent: ", + extent_val->as()); + inp_shape.at(i) = extent_val->as(); + } + + const auto& out_dom = out_tv->getMaybeRFactorDomain(); - // Determine output shape using expr evaluator. Note there may be - // one domain of extent -1 - std::vector out_shape(out_dom.size(), 0); - bool extent_m1_found = false; - for (const auto i : c10::irange(out_dom.size())) { - auto out_id = out_dom.at(i); - auto extent_val = expr_eval->evaluate(out_id->extent()); + // Determine output shape using expr evaluator. Note there may be + // one domain of extent -1 + std::vector out_shape(out_dom.size(), 0); + bool extent_m1_found = false; + for (const auto i : c10::irange(out_dom.size())) { + auto out_id = out_dom.at(i); + auto extent_val = expr_eval->evaluate(out_id->extent()); + TORCH_INTERNAL_ASSERT( + extent_val.has_value(), + "Cannot evaluate the extent of an output domain to reshape: ", + out_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val->isInt(), + "Invalid evaluated value of domain extent: ", + out_id->toString()); + const auto extent_int = extent_val->as(); + if (extent_int == -1) { TORCH_INTERNAL_ASSERT( - extent_val.has_value(), - "Cannot evaluate the extent of an output domain to reshape: ", - out_id->toString()); + !extent_m1_found, + "Multiple output domains of size -1 not allowed", + out_tv->toString()); + extent_m1_found = true; + } else { TORCH_INTERNAL_ASSERT( - extent_val->isInt(), - "Invalid evaluated value of domain extent: ", - out_id->toString()); - const auto extent_int = extent_val->as(); - if (extent_int == -1) { - TORCH_INTERNAL_ASSERT( - !extent_m1_found, - "Multiple output domains of size -1 not allowed", - out_tv->toString()); - extent_m1_found = true; - } else { - TORCH_INTERNAL_ASSERT( - extent_int > 0, "Invalid output domain extent: ", extent_int); - } - out_shape.at(i) = extent_int; + extent_int > 0, "Invalid output domain extent: ", extent_int); } + out_shape.at(i) = extent_int; + } - auto view_result = analyzeView(inp_tv, inp_shape, out_shape); + auto view_result = analyzeView(inp_tv, inp_shape, out_shape); - reshape_transforms_.emplace_back(tv_index, view_result); - } + reshape_transforms_.emplace_back(val_index, view_result); } -void DynamicTransformConcretizationInfo::analyzeSlices( - ExpressionEvaluator* expr_eval) { - const auto& sliced_tvs = initial_info_->getDynamicSlicedTensorViews(); - for (auto tv_index : c10::irange(sliced_tvs.size())) { - auto out_tv = sliced_tvs.at(tv_index); - auto op = out_tv->definition()->as(); - auto inp_tv = op->in()->as(); - const auto inp_dom = - TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); - const auto out_dom = out_tv->getMaybeRFactorDomain(); - const auto ranges = op->getRanges(); - std::vector slice_descs(inp_dom.size()); - for (auto i : c10::irange(inp_dom.size())) { - const auto& range = ranges.at(i); - - auto start_opt = expr_eval->evaluate(range.start); - TORCH_INTERNAL_ASSERT( - start_opt.has_value(), - "Could not evaluate start of slice range ", - range.start); - auto start = start_opt->as(); +void DynamicTransformConcretizationInfo::analyzeSlice( + ExpressionEvaluator* expr_eval, + size_t val_index) { + auto out_tv = + initialInfo()->getDynamicExprOutputs().at(val_index)->as(); + auto op = out_tv->definition()->as(); + auto inp_tv = op->in()->as(); + const auto inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + const auto out_dom = out_tv->getMaybeRFactorDomain(); + const auto ranges = op->getRanges(); + std::vector slice_descs(inp_dom.size()); + for (auto i : c10::irange(inp_dom.size())) { + const auto& range = ranges.at(i); + + auto start_opt = expr_eval->evaluate(range.start); + TORCH_INTERNAL_ASSERT( + start_opt.has_value(), + "Could not evaluate start of slice range ", + range.start); + auto start = start_opt->as(); - auto stop_opt = expr_eval->evaluate(range.stop); - TORCH_INTERNAL_ASSERT( - stop_opt.has_value(), - "Could not evaluate stop of slice range ", - range.stop); - auto stop = stop_opt->as(); + auto stop_opt = expr_eval->evaluate(range.stop); + TORCH_INTERNAL_ASSERT( + stop_opt.has_value(), + "Could not evaluate stop of slice range ", + range.stop); + auto stop = stop_opt->as(); - auto step_opt = expr_eval->evaluate(range.step); - TORCH_INTERNAL_ASSERT( - step_opt.has_value(), - "Could not evaluate step of slice range ", - range.step); - auto step = step_opt->as(); + auto step_opt = expr_eval->evaluate(range.step); + TORCH_INTERNAL_ASSERT( + step_opt.has_value(), + "Could not evaluate step of slice range ", + range.step); + auto step = step_opt->as(); - TORCH_INTERNAL_ASSERT(step != 0, "Slice step must not be zero"); - TORCH_INTERNAL_ASSERT( - step == 1, "Slicing with step != 1 is not currently supported"); + TORCH_INTERNAL_ASSERT(step != 0, "Slice step must not be zero"); + TORCH_INTERNAL_ASSERT( + step == 1, "Slicing with step != 1 is not currently supported"); - auto inp_extent_opt = - expr_eval->evaluate(inp_dom.at(i)->getMaybeExpandedExtent()); - TORCH_INTERNAL_ASSERT( - inp_extent_opt.has_value(), - "Could not evaluate slice input extent ", - inp_dom.at(i)->getMaybeExpandedExtent()); - auto inp_extent = inp_extent_opt->as(); - - auto getBranch = [&inp_extent](int64_t a) -> SliceIndexBranch { - if (a <= -inp_extent) { - return SliceIndexBranch::AlwaysZero; - } else if (a < 0) { - return SliceIndexBranch::Negative; - } else if (a < inp_extent) { - return SliceIndexBranch::Positive; - } else { - return SliceIndexBranch::AlwaysExtent; - } - }; - slice_descs[i].start_branch = getBranch(start); - slice_descs[i].stop_branch = getBranch(stop); - slice_descs[i].is_empty = (stop - start) * step <= 0; - - // The dynamic slice output has a purely symbolic extent. Here we evaluate - // the proper extent to determine the output IterType. - auto map_int_index = [&inp_extent](const int64_t& a) -> int64_t { - if (a <= -inp_extent) { - return 0; - } else if (a < 0) { - return -a; - } else if (a < inp_extent) { - return a; - } else { - return inp_extent; - } - }; - // actual size of sliced dimension is ceilDiv(stop - start, step) when - // step > 0. When step < 0, that expression is off by one and instead the - // extent in that case is ceilDiv(start - stop, -step). - auto concrete_sliced_extent = step > 0 - ? (map_int_index(stop) - map_int_index(start) + step - 1) / step - : (map_int_index(stop) - map_int_index(start) + step + 1) / step; - - if (concrete_sliced_extent == 1) { - slice_descs[i].iter_type = IterType::Broadcast; + auto inp_extent_opt = + expr_eval->evaluate(inp_dom.at(i)->getMaybeExpandedExtent()); + TORCH_INTERNAL_ASSERT( + inp_extent_opt.has_value(), + "Could not evaluate slice input extent ", + inp_dom.at(i)->getMaybeExpandedExtent()); + auto inp_extent = inp_extent_opt->as(); + + auto getBranch = [&inp_extent](int64_t a) -> SliceIndexBranch { + if (a <= -inp_extent) { + return SliceIndexBranch::AlwaysZero; + } else if (a < 0) { + return SliceIndexBranch::Negative; + } else if (a < inp_extent) { + return SliceIndexBranch::Positive; + } else { + return SliceIndexBranch::AlwaysExtent; } - - // Even though we will eventually replace this TV, there will still be - // references to its extents in downstream uses. We will need to evaluate - // these properly both in this analysis, and at concretization. Here we - // bind the output extent so that downstream extents can be properly - // computed during this analysis. After concretization, this will happen - // via ExpressionEvaluator::propagateBoundValuesThroughExactMaps. - expr_eval->bind(out_dom[i]->extent(), concrete_sliced_extent); + }; + slice_descs[i].start_branch = getBranch(start); + slice_descs[i].stop_branch = getBranch(stop); + slice_descs[i].is_empty = (stop - start) * step <= 0; + + // The dynamic slice output has a purely symbolic extent. Here we evaluate + // the proper extent to determine the output IterType. + auto map_int_index = [&inp_extent](const int64_t& a) -> int64_t { + if (a <= -inp_extent) { + return 0; + } else if (a < 0) { + return -a; + } else if (a < inp_extent) { + return a; + } else { + return inp_extent; + } + }; + // actual size of sliced dimension is ceilDiv(stop - start, step) when + // step > 0. When step < 0, that expression is off by one and instead the + // extent in that case is ceilDiv(start - stop, -step). + auto concrete_sliced_extent = step > 0 + ? (map_int_index(stop) - map_int_index(start) + step - 1) / step + : (map_int_index(stop) - map_int_index(start) + step + 1) / step; + + if (concrete_sliced_extent == 1) { + slice_descs[i].iter_type = IterType::Broadcast; } - slice_descriptors_.emplace_back(tv_index, slice_descs); } + slice_descriptors_.emplace_back(val_index, slice_descs); } -void DynamicTransformConcretizationInfo::analyzeResizes( - ExpressionEvaluator* expr_eval) { - const auto& resize_ids = initial_info_->getDynamicResizedIterDomains(); - for (const auto id_index : c10::irange(resize_ids.size())) { - auto out_id = resize_ids.at(id_index); - auto op = out_id->definition()->as(); +void DynamicTransformConcretizationInfo::analyzeResize( + ExpressionEvaluator* expr_eval, + size_t val_index) { + auto out_id = + initialInfo()->getDynamicExprOutputs().at(val_index)->as(); + auto op = out_id->definition()->as(); - TORCH_CHECK( - out_id->getIterType() == IterType::Symbolic, - "Found non-dynamic Resize in initial concretization info: ", - op->toString()); + TORCH_CHECK( + out_id->getIterType() == IterType::Symbolic, + "Found non-dynamic Resize in initial concretization info: ", + op->toString()); - auto extent_val = expr_eval->evaluate(out_id->extent()); - TORCH_INTERNAL_ASSERT( - extent_val.has_value(), - "Cannot evaluate the extent of a resized domain: ", - out_id->toString()); - TORCH_INTERNAL_ASSERT( - extent_val->isInt(), - "Invalid evaluated value of resized domain extent: ", - out_id->toString()); - auto extent_int = extent_val->as(); - TORCH_INTERNAL_ASSERT( - extent_int > 0, - "Invalid resized domain extent ", - extent_int, - " for domain ", - out_id->toString()); + auto extent_val = expr_eval->evaluate(out_id->extent()); + TORCH_INTERNAL_ASSERT( + extent_val.has_value(), + "Cannot evaluate the extent of a resized domain: ", + out_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val->isInt(), + "Invalid evaluated value of resized domain extent: ", + out_id->toString()); + auto extent_int = extent_val->as(); + TORCH_INTERNAL_ASSERT( + extent_int > 0, + "Invalid resized domain extent ", + extent_int, + " for domain ", + out_id->toString()); - auto iter_type = - extent_int == 1 ? IterType::Broadcast : IterType::Iteration; + auto iter_type = extent_int == 1 ? IterType::Broadcast : IterType::Iteration; - resize_itertypes_.emplace_back(id_index, iter_type); - } + resize_itertypes_.emplace_back(val_index, iter_type); } bool DynamicTransformConcretizationInfo::operator==( @@ -454,13 +446,13 @@ std::string DynamicTransformConcretizationInfo::toString() const { std::string indent = " "; ss << indent << "Reshape:\n"; for (const auto& [tv_index, analyze_result] : reshape_transforms_) { - auto tv = initial_info_->getDynamicReshapedTensorViews().at(tv_index); + auto tv = initial_info_->getDynamicExprOutputs().at(tv_index); ss << indent << indent << tv->toString() << " (index=" << tv_index << "), " << analyze_result.toString() << "\n"; } ss << indent << "Resize:\n"; for (const auto& [id_index, iter_type] : resize_itertypes_) { - auto id = initial_info_->getDynamicResizedIterDomains().at(id_index); + auto id = initial_info_->getDynamicExprOutputs().at(id_index); ss << indent << indent << id->toString() << " (index=" << id_index << "), " << iter_type << "\n"; } @@ -539,13 +531,22 @@ void DynamicTransformConcretizer::concretize() { void DynamicTransformConcretizer::concretizeReshape() { // Concretize each reshape op. for (const auto& [tv_index, view_analysis] : info_->getReshapeTransforms()) { - auto incomplete_out_tv = - info_->initialInfo()->getDynamicReshapedTensorViews().at(tv_index); + auto incomplete_out_tv = info_->initialInfo() + ->getDynamicExprOutputs() + .at(tv_index) + ->as(); auto view_op = incomplete_out_tv->definition()->as(); auto inp_tv = view_op->in()->as(); auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis); + for (auto i : + c10::irange(incomplete_out_tv->getMaybeRFactorDomain().size())) { + auto id = incomplete_out_tv->getMaybeRFactorDomain().at(i); + auto new_id = concrete_reshape_out_tv->getMaybeRFactorDomain().at(i); + concretizeIterDomain(id, new_id); + } + // We do the replacement directly here, but we must still check that the // replacement is valid checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv); @@ -568,8 +569,10 @@ void DynamicTransformConcretizer::concretizeReshape() { void DynamicTransformConcretizer::concretizeSlice() { auto fusion = FusionGuard::getCurFusion(); for (const auto& [tv_index, slice_descs] : info_->getSliceDescriptors()) { - auto incomplete_out_tv = - info_->initialInfo()->getDynamicSlicedTensorViews().at(tv_index); + auto incomplete_out_tv = info_->initialInfo() + ->getDynamicExprOutputs() + .at(tv_index) + ->as(); auto slice_op = incomplete_out_tv->definition()->as(); auto inp_tv = slice_op->input(0)->as(); @@ -680,7 +683,10 @@ void DynamicTransformConcretizer::concretizeSlice() { void DynamicTransformConcretizer::concretizeResize() { // Concretize each resize op. for (const auto& [id_index, iter_type] : info_->getResizeIterTypes()) { - auto id = info_->initialInfo()->getDynamicResizedIterDomains().at(id_index); + auto id = info_->initialInfo() + ->getDynamicExprOutputs() + .at(id_index) + ->as(); TORCH_CHECK( id->definition() && id->definition()->isA(), "Resized IterDomain must have a Resize definition"); diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 491605942b7..9493620d0f8 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -40,8 +40,7 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { //! Return whether any dynamic transforms exist in the Fusion bool hasDynamicTransforms() const { - return !dynamic_reshaped_tvs_.empty() || !dynamic_resized_ids_.empty() || - !dynamic_sliced_tvs_.empty(); + return !dynamic_expr_outputs_.empty(); } //! Return a set of scalars that are inputs or extents of input TensorViews @@ -51,21 +50,8 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { return root_dynamic_vals_; } - //! Return a vector of outputs of ViewOp expressions that have dynamic output - //! shapes - const std::vector& getDynamicReshapedTensorViews() const { - return dynamic_reshaped_tvs_; - } - - //! Return a vector of outputs of Resize expressions that have symbolic output - //! IterTypes - const std::vector& getDynamicResizedIterDomains() const { - return dynamic_resized_ids_; - } - - //! Return a vector of outputs of Slice expressions - const std::vector& getDynamicSlicedTensorViews() const { - return dynamic_sliced_tvs_; + const std::vector& getDynamicExprOutputs() const { + return dynamic_expr_outputs_; } std::string toString() const; @@ -95,9 +81,9 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { // definitions will merely be altered. When the ops are replaced, if we had // referred to them directly here, we would run into segfaults. Referring only // to the outputs avoids this issue. - std::vector dynamic_reshaped_tvs_; + // std::vector dynamic_reshaped_tvs_; - std::vector dynamic_resized_ids_; + // std::vector dynamic_resized_ids_; // Slice operations can have complicated output extents. The inputs to slice // are a start, stop, and step for each sliced dimension. Each of these is an @@ -115,7 +101,10 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { // Here we keep track of non-static slices or slices with non-static input // extents. That way we can restrict to a single branch in each of these // expressions during concretization. - std::vector dynamic_sliced_tvs_; + // std::vector dynamic_sliced_tvs_; + + // This is a topologically sorted list of outputs of dynamic operations. + std::vector dynamic_expr_outputs_; // Root Vals that determine concretization std::unordered_set root_dynamic_vals_; @@ -187,11 +176,7 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { // evaluator when any one of the IDs has a known value expr_eval->propagateBoundValuesThroughExactMaps(initial_info->fusion()); - analyzeReshapes(expr_eval); - - analyzeSlices(expr_eval); - - analyzeResizes(expr_eval); + analyze(expr_eval); } //! Return a vector of pairs holding the index of each reshaped TensorView in @@ -231,17 +216,21 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { } //! Given an ExpressionEvaluator which already has input scalars bound to it, - //! determine the decomposition of each dynamic reshape operation to use + //! analyze all dynamic ops in topological order. + void analyze(ExpressionEvaluator* expr_eval); + + //! Given an ExpressionEvaluator which already has input scalars bound to it, + //! determine the decomposition of a dynamic reshape operation to use //! during concretization. - void analyzeReshapes(ExpressionEvaluator* expr_eval); + void analyzeReshape(ExpressionEvaluator* expr_eval, size_t val_index); //! Given an ExpressionEvaluator which already has input scalars bound to it, - //! determine the branches of expressions in dynamic slice ops. - void analyzeSlices(ExpressionEvaluator* expr_eval); + //! determine the branches of expressions in a dynamic slice op. + void analyzeSlice(ExpressionEvaluator* expr_eval, size_t val_index); //! Given an ExpressionEvaluator which already has input scalars bound to it, - //! determine the concrete IterType of each resized IterDomain. - void analyzeResizes(ExpressionEvaluator* expr_eval); + //! determine the concrete IterType of a resized IterDomain. + void analyzeResize(ExpressionEvaluator* expr_eval, size_t val_index); const DynamicTransformInitialInfo* initialInfo() const { return initial_info_; diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 295e11ad789..1bf336c9272 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2113,7 +2113,14 @@ std::string LoadStoreOp::toString(int indent_size) const { } std::string LoadStoreOp::toInlineString(int indent_size) const { - TORCH_CHECK(false, "Tensor op can not be printed inline"); + if (opType() == LoadStoreOpType::Set) { + TORCH_CHECK( + !in()->isA(), "Cannot print TensorView set() inline"); + std::stringstream ss; + indent(ss, indent_size) << "set(" << in()->toInlineString() << ")"; + return ss.str(); + } + TORCH_CHECK(false, "Non-'Set' LoadStoreOp cannot be printed inline"); } bool LoadStoreOp::hasTranspose() const { From 7337070a596b3d54cf6d03e54bbc58d4afc6f01a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 23 Jun 2023 09:02:43 -0400 Subject: [PATCH 10/14] Update FusionSliceForNanoGPT2 to reflect slice concretization --- test/test_resize.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/test_resize.cpp b/test/test_resize.cpp index 718f5120975..4c312ef7dea 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1777,10 +1777,15 @@ TEST_F(NVFuserTest, FusionSliceForNanoGPT2_CUDA) { continue; } auto out_tv = ir_utils::getTvOutput(expr); - if (out_tv->name() == tv3->name() || out_tv->name() == tv5->name()) { + if ( + // Note: We want to check tv3 and tv5 here, which are the original slice + // outputs. During concretization, T3 gets concretized to T8 and T5 gets + // concretized to T9. Here we manually translate those since we don't + // maintain a mapping from pre-concretized to concretized Vals. + out_tv->name() == 8 || out_tv->name() == 9) { TORCH_CHECK( expr->isA(), - "Unexpected defintion of slice output tensor: ", + "Unexpected definition of slice output tensor: ", out_tv->toString(), ", ", expr->toString()); From 7f65db74de900f539b78bd5b07790f14624ff7ee Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 23 Jun 2023 09:29:53 -0400 Subject: [PATCH 11/14] Use unordered_set to uniquify dynamic_expr_outputs_ --- csrc/dynamic_transform.cpp | 42 +++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index cc5b620f0b5..ff9957aaeaa 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -78,13 +78,21 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { private: using IterVisitor::handle; + bool maybeInsertDynamicExprOutput(Val* val) { + auto inserted = inserted_dynamic_expr_outputs_.insert(val).second; + if (inserted) { + info_.dynamic_expr_outputs_.push_back(val); + } + return inserted; + } + //! Find views that have symbolic outputs void handle(ViewOp* op) override { auto inp_tv = op->in()->as(); auto out_tv = op->out()->as(); // If there's no symbolic axis, this is a static reshape op if (out_tv->domain()->hasSymbolicAxis()) { - info_.dynamic_expr_outputs_.push_back(out_tv); + maybeInsertDynamicExprOutput(out_tv); // Input and output extent expressions both affect concretization const auto& inp_dom = @@ -99,11 +107,34 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { } } + //! Find slices that have symbolic outputs + void handle(SliceOp* op) override { + auto inp_tv = op->in()->as(); + auto out_tv = op->out()->as(); + // If there's no symbolic axis, this is a static slice op + if (out_tv->domain()->hasSymbolicAxis()) { + maybeInsertDynamicExprOutput(out_tv); + + // Input extent and ranges affect concretization + const auto& inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + for (const auto id : inp_dom) { + leaf_dynamic_vals_.push_back(id->extent()); + } + for (const auto& range : op->getRanges()) { + leaf_dynamic_vals_.push_back(range.start); + leaf_dynamic_vals_.push_back(range.stop); + leaf_dynamic_vals_.push_back(range.step); + } + } + } + //! Detect dynamic IterDomain transforms when handling TensorViews void handle(TensorView* tv) override { if (tv->definition() && tv->definition()->isA()) { if (tv->domain()->hasSymbolicAxis()) { - info_.dynamic_expr_outputs_.push_back(tv); + maybeInsertDynamicExprOutput(tv); + auto root_dom = tv->getRootDomain(); const auto ranges = tv->definition()->as()->getRanges(); TORCH_INTERNAL_ASSERT( @@ -131,7 +162,8 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { continue; } if (id->definition()->isA()) { - info_.dynamic_expr_outputs_.push_back(id); + maybeInsertDynamicExprOutput(id); + // extent of output determines its IterType leaf_dynamic_vals_.push_back(id->extent()); } @@ -168,6 +200,10 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { //! scalars that influence concretization. That list of scalars is then used //! to compute a minimal cache key in InputsIdLookup::lookupId(). std::vector leaf_dynamic_vals_; + + //! In order to prevent redundant processing of dynamic ops, we use an + //! unordered_set here to track which values have already been inserted. + std::unordered_set inserted_dynamic_expr_outputs_; }; void DynamicTransformConcretizationInfo::analyze( From d201c97536f6d67a9e85c9d00011494aea03c667 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 27 Jun 2023 08:15:57 -0400 Subject: [PATCH 12/14] Set iter_type for slice with skip_symbolic=true --- csrc/ops/alias.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 13a5edc2d9f..35a604a2e32 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -735,7 +735,8 @@ TensorView* slice( out_root_id, SimplifyingIrBuilder::negExpr(range.start), sub(range.stop, inp_extent), - true); + true, + IterType::Iteration); } else { out_rf_id = IterDomainBuilder( FusionGuard::getCurFusion()->zeroVal(), From f551c78f454b2e3dbda0e38a2c7cc01fe27061b0 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 27 Jun 2023 08:16:49 -0400 Subject: [PATCH 13/14] Concretize in topological order during mutation. Previously we concretized all reshapes then all slices then all resizes. This failed the FusionSliceForNanoGPT3_CUDA test which has the slice->reshape->slice->reshape pattern, since we were concretizing reshapes that had inputs which were not yet concretized and which actually needed to be replaced. Since concretization of slices and reshapes can actually replace a TensorView, we need to ensure that those are done in the correct order. --- csrc/dynamic_transform.cpp | 384 +++++++++++++++++++++---------------- 1 file changed, 214 insertions(+), 170 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index ff9957aaeaa..c5c8a5cecf1 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -399,6 +399,8 @@ void DynamicTransformConcretizationInfo::analyzeSlice( if (concrete_sliced_extent == 1) { slice_descs[i].iter_type = IterType::Broadcast; } + + expr_eval->bind(out_dom.at(i)->extent(), concrete_sliced_extent); } slice_descriptors_.emplace_back(val_index, slice_descs); } @@ -505,23 +507,48 @@ class DynamicTransformConcretizer : public OptOutMutator { TORCH_INTERNAL_ASSERT( fusion == info->fusion(), "Invalid DynamicTransformInitialInfo. The associated Fusion is different from the given Fusion"); + for (auto& [tv_index, view_analysis] : info_->getReshapeTransforms()) { + view_analyses_[info_->initialInfo() + ->getDynamicExprOutputs() + .at(tv_index) + ->as()] = view_analysis; + } + for (auto& [tv_index, slice_descs] : info_->getSliceDescriptors()) { + slice_descriptors_[info_->initialInfo() + ->getDynamicExprOutputs() + .at(tv_index) + ->as()] = slice_descs; + } + for (auto& [id_index, iter_type] : info_->getResizeIterTypes()) { + iter_types_[info_->initialInfo() + ->getDynamicExprOutputs() + .at(id_index) + ->as()] = iter_type; + } + concretize(); } private: void concretize(); - void concretizeReshape(); + TensorView* maybeConcretizeReshape(TensorView* tv); - void concretizeSlice(); + TensorView* maybeConcretizeSlice(TensorView* tv); - void concretizeResize(); + IterDomain* maybeConcretizeResize(IterDomain* id); void concretizeIterDomain(IterDomain* old_id, IterDomain* new_id); //! Use this instead of calling registerMutation directly, since it will also //! check that the concretized value is a valid input to all of its uses. void registerConcretization(Val* old_val, Val* new_val) { + TORCH_INTERNAL_ASSERT( + old_val->vtype() == new_val->vtype(), + "Concretization should not change ValType, but attempted concretization of ", + old_val->toString(), + " with ", + new_val->toString()); checkConcretizedUses(old_val, new_val); registerMutation(old_val, new_val); } @@ -537,205 +564,219 @@ class DynamicTransformConcretizer : public OptOutMutator { void mutate(TensorDomain* td) final; + void mutate(IterDomain* id) final; + //! Concretizes the root domain of a symbolic consumer tensor from //! its producer domains. Returns true if any root ID is concretized. bool propagateFromProducerToConsumer(TensorView* consumer); private: const DynamicTransformConcretizationInfo* info_; + + //! These are populated at construction, so that we can easily check whether a + //! Val is the output of a dynamic op. + std::unordered_map view_analyses_; + std::unordered_map> + slice_descriptors_; + std::unordered_map iter_types_; }; void DynamicTransformConcretizer::concretize() { - // First, concretize all dynamic reshape ops - concretizeReshape(); - - // Concretize dynamic slices - concretizeSlice(); - - // Set output IterTypes for dynamic resize ops - concretizeResize(); - - // Finally, propagate concretized domains - auto all_stmts = StmtSort::getStmts(info_->fusion(), true); - for (auto stmt : all_stmts) { - if (stmt->isA()) { - mutate(stmt); + // Proceed in order from inputs to outputs. For each statement, propagate + // concretized inputs from producers to consumers, then concretize statement + // if it is a dynamic op. + std::vector vals; + for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) { + // When we concretize reshape and slice, we replace some TVs with others. + // This means we would change the Statements listed in getStmts(), + // resulting in segfaults if we have removed expressions containing replaced + // TVs. So, here we extract all Vals first so that we don't try and + // dereference a removed Expr. Then, when mutating we only need to be sure + // that we do not replace any downstream Vals (replacing downstream Exprs, + // i.e. immediate uses, is fine). + if (stmt->isVal()) { + vals.push_back(stmt->asVal()); } } + for (auto val : vals) { + mutate(val); + } } -void DynamicTransformConcretizer::concretizeReshape() { - // Concretize each reshape op. - for (const auto& [tv_index, view_analysis] : info_->getReshapeTransforms()) { - auto incomplete_out_tv = info_->initialInfo() - ->getDynamicExprOutputs() - .at(tv_index) - ->as(); - auto view_op = incomplete_out_tv->definition()->as(); - auto inp_tv = view_op->in()->as(); - - auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis); - - for (auto i : - c10::irange(incomplete_out_tv->getMaybeRFactorDomain().size())) { - auto id = incomplete_out_tv->getMaybeRFactorDomain().at(i); - auto new_id = concrete_reshape_out_tv->getMaybeRFactorDomain().at(i); - concretizeIterDomain(id, new_id); - } +TensorView* DynamicTransformConcretizer::maybeConcretizeReshape( + TensorView* incomplete_out_tv) { + const auto it = view_analyses_.find(incomplete_out_tv); + if (it == view_analyses_.end()) { + return incomplete_out_tv; + } + const auto& view_analysis = it->second; - // We do the replacement directly here, but we must still check that the - // replacement is valid - checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv); + auto view_op = incomplete_out_tv->definition()->as(); + auto inp_tv = view_op->in()->as(); - // Replace the old tensor with the new concretized tensor - for (auto use_of_old_tv : incomplete_out_tv->uses()) { - ir_utils::replaceValInExpr( - use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); - } + auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis); - if (incomplete_out_tv->isFusionOutput()) { - incomplete_out_tv->fusion()->replaceOutput( - incomplete_out_tv, concrete_reshape_out_tv); - } + for (auto i : + c10::irange(incomplete_out_tv->getMaybeRFactorDomain().size())) { + auto id = incomplete_out_tv->getMaybeRFactorDomain().at(i); + auto new_id = concrete_reshape_out_tv->getMaybeRFactorDomain().at(i); + concretizeIterDomain(id, new_id); + } + + // registerConcretization(incomplete_out_tv, concrete_reshape_out_tv); - info_->fusion()->removeVal(incomplete_out_tv); + // Replace the old tensor with the new concretized tensor + for (auto use_of_old_tv : incomplete_out_tv->uses()) { + ir_utils::replaceValInExpr( + use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); } + + if (incomplete_out_tv->isFusionOutput()) { + incomplete_out_tv->fusion()->replaceOutput( + incomplete_out_tv, concrete_reshape_out_tv); + } + + info_->fusion()->removeVal(incomplete_out_tv); + + return concrete_reshape_out_tv; } -void DynamicTransformConcretizer::concretizeSlice() { +TensorView* DynamicTransformConcretizer::maybeConcretizeSlice( + TensorView* incomplete_out_tv) { + const auto it = slice_descriptors_.find(incomplete_out_tv); + if (it == slice_descriptors_.end()) { + return incomplete_out_tv; + } + const auto& slice_descs = it->second; + auto fusion = FusionGuard::getCurFusion(); - for (const auto& [tv_index, slice_descs] : info_->getSliceDescriptors()) { - auto incomplete_out_tv = info_->initialInfo() - ->getDynamicExprOutputs() - .at(tv_index) - ->as(); - auto slice_op = incomplete_out_tv->definition()->as(); - auto inp_tv = slice_op->input(0)->as(); - - const auto& root_dom = incomplete_out_tv->getRootDomain(); - const auto& rfactor_dom = incomplete_out_tv->getRFactorDomain(); - // Create new rfactor domain with potentially newly-resized root IDs - std::vector new_rfactor(root_dom.size()); - - bool is_empty = false; - bool is_sliced = false; - const auto ranges = slice_op->getRanges(); - auto map_index = [&fusion]( - SliceIndexBranch branch, Val* a, Val* extent) -> Val* { - if (branch == SliceIndexBranch::AlwaysExtent) { - return extent; - } else if (branch == SliceIndexBranch::Negative) { - return SimplifyingIrBuilder::negExpr(a); - } else if (branch == SliceIndexBranch::Positive) { - return a; - } else { - return fusion->zeroVal(); - } - }; - std::vector new_ranges; - new_ranges.reserve(ranges.size()); - for (auto i : c10::irange(root_dom.size())) { - auto desc = slice_descs.at(i); - if (desc.is_empty) { - is_empty = true; - // Use 0:0:1 as the canonical empty slice. - new_ranges.push_back( - {fusion->zeroVal(), fusion->zeroVal(), fusion->oneVal()}); - } else { - auto range = ranges.at(i); - auto inp_extent = root_dom.at(i)->getMaybeExpandedExtent(); - auto new_start = map_index(desc.start_branch, range.start, inp_extent); - auto new_stop = map_index(desc.stop_branch, range.stop, inp_extent); - new_ranges.push_back({new_start, new_stop, range.step}); - // Trivial slices correspond to 0:extent:1 - if (desc.start_branch != SliceIndexBranch::AlwaysZero || - desc.stop_branch != SliceIndexBranch::AlwaysExtent || - desc.step_branch != SliceStepBranch::One) { - is_sliced = true; - } - } + auto slice_op = incomplete_out_tv->definition()->as(); + auto inp_tv = slice_op->input(0)->as(); + + const auto& root_dom = incomplete_out_tv->getRootDomain(); + const auto& rfactor_dom = incomplete_out_tv->getRFactorDomain(); + // Create new rfactor domain with potentially newly-resized root IDs + std::vector new_rfactor(root_dom.size()); + + bool is_empty = false; + bool is_sliced = false; + const auto ranges = slice_op->getRanges(); + auto map_index = [&fusion]( + SliceIndexBranch branch, Val* a, Val* extent) -> Val* { + if (branch == SliceIndexBranch::AlwaysExtent) { + return extent; + } else if (branch == SliceIndexBranch::Negative) { + return SimplifyingIrBuilder::negExpr(a); + } else if (branch == SliceIndexBranch::Positive) { + return a; + } else { + return fusion->zeroVal(); } - - TensorView* new_tv = nullptr; - - if (is_empty) { - std::vector new_shape(ranges.size()); - for (auto i : c10::irange(ranges.size())) { - auto new_range = new_ranges.at(i); - auto desc = slice_descs.at(i); - // Depending on the step branch, we can use different output extent - // expressions - switch (desc.step_branch) { - case SliceStepBranch::One: - new_shape[i] = - SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start); - break; - case SliceStepBranch::GreaterThanOne: - new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( - SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start), - new_range.step); - case SliceStepBranch::Negative: - new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( - SimplifyingIrBuilder::subExpr(new_range.start, new_range.stop), - SimplifyingIrBuilder::negExpr(new_range.step)); - } - } - // TODO: process as empty tensor if is_empty - auto dtype = incomplete_out_tv->getDataType().value(); - new_tv = full(new_shape, fusion->zeroVal(dtype), dtype); - } else if (!is_sliced) { - // Replace the slice with set() - new_tv = set(inp_tv); + }; + std::vector new_ranges; + new_ranges.reserve(ranges.size()); + for (auto i : c10::irange(root_dom.size())) { + auto desc = slice_descs.at(i); + if (desc.is_empty) { + is_empty = true; + // Use 0:0:1 as the canonical empty slice. + new_ranges.push_back( + {fusion->zeroVal(), fusion->zeroVal(), fusion->oneVal()}); } else { - new_tv = slice(inp_tv, new_ranges, /*skip_symbolic*/ true); + auto range = ranges.at(i); + auto inp_extent = root_dom.at(i)->getMaybeExpandedExtent(); + auto new_start = map_index(desc.start_branch, range.start, inp_extent); + auto new_stop = map_index(desc.stop_branch, range.stop, inp_extent); + new_ranges.push_back({new_start, new_stop, range.step}); + // Trivial slices correspond to 0:extent:1 + if (desc.start_branch != SliceIndexBranch::AlwaysZero || + desc.stop_branch != SliceIndexBranch::AlwaysExtent || + desc.step_branch != SliceStepBranch::One) { + is_sliced = true; + } } + } - for (auto i : c10::irange(root_dom.size())) { - auto id = rfactor_dom.at(i); - auto new_id = new_tv->getRFactorDomain().at(i); - concretizeIterDomain(id, new_id); + TensorView* new_tv = nullptr; + + if (is_empty) { + std::vector new_shape(ranges.size()); + for (auto i : c10::irange(ranges.size())) { + auto new_range = new_ranges.at(i); + auto desc = slice_descs.at(i); + // Depending on the step branch, we can use different output extent + // expressions + switch (desc.step_branch) { + case SliceStepBranch::One: + new_shape[i] = + SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start); + break; + case SliceStepBranch::GreaterThanOne: + new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( + SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start), + new_range.step); + case SliceStepBranch::Negative: + new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( + SimplifyingIrBuilder::subExpr(new_range.start, new_range.stop), + SimplifyingIrBuilder::negExpr(new_range.step)); + } } + // TODO: process as empty tensor if is_empty + auto dtype = incomplete_out_tv->getDataType().value(); + new_tv = full(new_shape, fusion->zeroVal(dtype), dtype); + } else if (!is_sliced) { + // Replace the slice with set() + new_tv = set(inp_tv); + } else { + new_tv = slice(inp_tv, new_ranges, /*skip_symbolic*/ true); + } - // TODO: We need to update the maybeRFactorDomains of new_tv if there are - // any Broadcast sliced dimensions. + for (auto i : c10::irange(root_dom.size())) { + auto id = rfactor_dom.at(i); + auto new_id = new_tv->getRFactorDomain().at(i); + concretizeIterDomain(id, new_id); + } - // We do the replacement directly here, but we must still check that the - // replacement is valid - checkConcretizedUses(incomplete_out_tv, new_tv); + // TODO: We need to update the maybeRFactorDomains of new_tv if there are + // any Broadcast sliced dimensions. - // Replace the old tensor with the new concretized tensor - for (auto use_of_old_tv : incomplete_out_tv->uses()) { - ir_utils::replaceValInExpr(use_of_old_tv, incomplete_out_tv, new_tv); - } + // We do the replacement directly here, but we must still check that the + // replacement is valid + // registerConcretization(incomplete_out_tv, new_tv); - if (incomplete_out_tv->isFusionOutput()) { - incomplete_out_tv->fusion()->replaceOutput(incomplete_out_tv, new_tv); - } + // Replace the old tensor with the new concretized tensor + for (auto use_of_old_tv : incomplete_out_tv->uses()) { + ir_utils::replaceValInExpr(use_of_old_tv, incomplete_out_tv, new_tv); + } - info_->fusion()->removeVal(incomplete_out_tv); + if (incomplete_out_tv->isFusionOutput()) { + incomplete_out_tv->fusion()->replaceOutput(incomplete_out_tv, new_tv); } -} -void DynamicTransformConcretizer::concretizeResize() { - // Concretize each resize op. - for (const auto& [id_index, iter_type] : info_->getResizeIterTypes()) { - auto id = info_->initialInfo() - ->getDynamicExprOutputs() - .at(id_index) - ->as(); - TORCH_CHECK( - id->definition() && id->definition()->isA(), - "Resized IterDomain must have a Resize definition"); - auto def = id->definition()->as(); - auto new_id = IterDomain::resize( - def->in(), - def->leftExpand(), - def->rightExpand(), - id->isRFactorProduct(), - iter_type); + info_->fusion()->removeVal(incomplete_out_tv); - concretizeIterDomain(id, new_id); + return new_tv; +} + +IterDomain* DynamicTransformConcretizer::maybeConcretizeResize(IterDomain* id) { + const auto it = iter_types_.find(id); + if (it == iter_types_.end()) { + return id; } + const auto& iter_type = it->second; + TORCH_CHECK( + id->definition() && id->definition()->isA(), + "Resized IterDomain must have a Resize definition"); + auto def = id->definition()->as(); + auto new_id = IterDomain::resize( + def->in(), + def->leftExpand(), + def->rightExpand(), + id->isRFactorProduct(), + iter_type); + + concretizeIterDomain(id, new_id); + return new_id; } void DynamicTransformConcretizer::concretizeIterDomain( @@ -751,9 +792,6 @@ void DynamicTransformConcretizer::concretizeIterDomain( // new_ext. IrBuilder::create(LoadStoreOpType::Set, old_ext, new_ext); } - - registerConcretization(old_ext, new_ext); - registerConcretization(old_id, new_id); } void DynamicTransformConcretizer::checkConcretizedUses( @@ -772,6 +810,8 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { if (!tv->domain()->hasSymbolicAxis()) { return; } + tv = maybeConcretizeReshape(tv); + tv = maybeConcretizeSlice(tv); // First, try to concretize the root domain as there may be symbolic // axes inherited from the producers @@ -917,6 +957,10 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { registerConcretization(td, mutated_val); } +void DynamicTransformConcretizer::mutate(IterDomain* id) { + id = maybeConcretizeResize(id); +} + bool DynamicTransformConcretizer::propagateFromProducerToConsumer( TensorView* consumer) { if (consumer->definition() == nullptr || From 843cb8e687a948b195606786a97ae2d555473908 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 5 Jul 2023 08:53:23 -0400 Subject: [PATCH 14/14] Concretize offset == 0 as Zero. This will remove more trivial slices, but means we trigger a recompile between x[0:m] and x[1:m]. --- csrc/dynamic_transform.cpp | 29 ++++++++++++++++------------- csrc/dynamic_transform.h | 12 ++++++------ 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index c5c8a5cecf1..e9986162c86 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -362,14 +362,14 @@ void DynamicTransformConcretizationInfo::analyzeSlice( auto inp_extent = inp_extent_opt->as(); auto getBranch = [&inp_extent](int64_t a) -> SliceIndexBranch { - if (a <= -inp_extent) { - return SliceIndexBranch::AlwaysZero; + if (a == 0 || a <= -inp_extent) { + return SliceIndexBranch::Zero; } else if (a < 0) { return SliceIndexBranch::Negative; } else if (a < inp_extent) { return SliceIndexBranch::Positive; } else { - return SliceIndexBranch::AlwaysExtent; + return SliceIndexBranch::Extent; } }; slice_descs[i].start_branch = getBranch(start); @@ -663,15 +663,18 @@ TensorView* DynamicTransformConcretizer::maybeConcretizeSlice( const auto ranges = slice_op->getRanges(); auto map_index = [&fusion]( SliceIndexBranch branch, Val* a, Val* extent) -> Val* { - if (branch == SliceIndexBranch::AlwaysExtent) { - return extent; - } else if (branch == SliceIndexBranch::Negative) { - return SimplifyingIrBuilder::negExpr(a); - } else if (branch == SliceIndexBranch::Positive) { - return a; - } else { - return fusion->zeroVal(); + Val* normalized_index = nullptr; + switch (branch) { + case SliceIndexBranch::Zero: + normalized_index = fusion->zeroVal(); + case SliceIndexBranch::Extent: + normalized_index = extent; + case SliceIndexBranch::Negative: + normalized_index = SimplifyingIrBuilder::addExpr(extent, a); + case SliceIndexBranch::Positive: + normalized_index = a; } + return normalized_index; }; std::vector new_ranges; new_ranges.reserve(ranges.size()); @@ -689,8 +692,8 @@ TensorView* DynamicTransformConcretizer::maybeConcretizeSlice( auto new_stop = map_index(desc.stop_branch, range.stop, inp_extent); new_ranges.push_back({new_start, new_stop, range.step}); // Trivial slices correspond to 0:extent:1 - if (desc.start_branch != SliceIndexBranch::AlwaysZero || - desc.stop_branch != SliceIndexBranch::AlwaysExtent || + if (desc.start_branch != SliceIndexBranch::Zero || + desc.stop_branch != SliceIndexBranch::Extent || desc.step_branch != SliceStepBranch::One) { is_sliced = true; } diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 9493620d0f8..4620f83a65d 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -116,10 +116,10 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { //! slice(). Each of these leads to a different branch in the normalized form's //! general expression. enum class SliceIndexBranch { - AlwaysZero, // a <= -extent - Negative, // -ext < a < 0 - Positive, // 0 <= a < extent - AlwaysExtent // extent <= a + Negative, // -extent < a < 0 + Zero, // a == 0 OR a <= -extent + Positive, // 0 < a < extent + Extent // extent <= a }; //! This enum describes the "step" argument to slice, which can be a positive or @@ -130,8 +130,8 @@ enum class SliceStepBranch { Negative, One, GreaterThanOne }; //! Describes a 1D slice in terms of the start, stop, and extent values struct Concrete1DSliceDescriptor { //! These enums determine the form of the simplified expressions - SliceIndexBranch start_branch = SliceIndexBranch::Positive; - SliceIndexBranch stop_branch = SliceIndexBranch::Positive; + SliceIndexBranch start_branch = SliceIndexBranch::Zero; + SliceIndexBranch stop_branch = SliceIndexBranch::Extent; SliceStepBranch step_branch = SliceStepBranch::One; //! True if normalized values satisfy (stop - start) * step <= 0 in which case