diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 42581cdeb5d..e9986162c86 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -25,16 +27,10 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( IrCloner& ir_cloner) const { DynamicTransformInitialInfo cloned_info( static_cast(ir_cloner.container())); - cloned_info.dynamic_reshaped_tvs_.reserve(dynamic_reshaped_tvs_.size()); - for (const auto op : dynamic_reshaped_tvs_) { - if (op) { - cloned_info.dynamic_reshaped_tvs_.push_back(ir_cloner.clone(op)); - } - } - cloned_info.dynamic_resized_ids_.reserve(dynamic_resized_ids_.size()); - for (const auto op : dynamic_resized_ids_) { - if (op) { - cloned_info.dynamic_resized_ids_.push_back(ir_cloner.clone(op)); + cloned_info.dynamic_expr_outputs_.reserve(dynamic_expr_outputs_.size()); + for (const auto v : dynamic_expr_outputs_) { + if (v) { + cloned_info.dynamic_expr_outputs_.push_back(ir_cloner.clone(v)); } } cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size()); @@ -50,15 +46,11 @@ std::string DynamicTransformInitialInfo::toString() const { std::stringstream ss; ss << "DynamicTransformInitialInfo\n"; std::string indent = " "; - ss << indent << "Dynamic reshaped TensorViews:\n"; - for (const auto& op : dynamic_reshaped_tvs_) { - ss << indent << indent << op->toString() << "\n"; + ss << indent << "Dynamic Vals:\n"; + for (const auto& val : dynamic_expr_outputs_) { + ss << indent << indent << val->toString() << "\n"; } - ss << indent << "Dynamic resized IterDomains:\n"; - for (const auto& op : dynamic_resized_ids_) { - ss << indent << indent << op->toString() << "\n"; - } - ss << indent << "Root dynamic Vals:\n"; + ss << indent << "Scalars determining concretization:\n"; for (const auto& v : root_dynamic_vals_) { ss << indent << indent << v->toString() << "\n"; } @@ -86,13 +78,21 @@ class 
DynamicTransformInitialInfoBuilder : public IterVisitor { private: using IterVisitor::handle; + bool maybeInsertDynamicExprOutput(Val* val) { + auto inserted = inserted_dynamic_expr_outputs_.insert(val).second; + if (inserted) { + info_.dynamic_expr_outputs_.push_back(val); + } + return inserted; + } + //! Find views that have symbolic outputs void handle(ViewOp* op) override { auto inp_tv = op->in()->as(); auto out_tv = op->out()->as(); // If there's no symbolic axis, this is a static reshape op if (out_tv->domain()->hasSymbolicAxis()) { - info_.dynamic_reshaped_tvs_.push_back(out_tv); + maybeInsertDynamicExprOutput(out_tv); // Input and output extent expressions both affect concretization const auto& inp_dom = @@ -107,15 +107,63 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { } } + //! Find slices that have symbolic outputs + void handle(SliceOp* op) override { + auto inp_tv = op->in()->as(); + auto out_tv = op->out()->as(); + // If there's no symbolic axis, this is a static slice op + if (out_tv->domain()->hasSymbolicAxis()) { + maybeInsertDynamicExprOutput(out_tv); + + // Input extent and ranges affect concretization + const auto& inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + for (const auto id : inp_dom) { + leaf_dynamic_vals_.push_back(id->extent()); + } + for (const auto& range : op->getRanges()) { + leaf_dynamic_vals_.push_back(range.start); + leaf_dynamic_vals_.push_back(range.stop); + leaf_dynamic_vals_.push_back(range.step); + } + } + } + //! 
Detect dynamic IterDomain transforms when handling TensorViews void handle(TensorView* tv) override { + if (tv->definition() && tv->definition()->isA()) { + if (tv->domain()->hasSymbolicAxis()) { + maybeInsertDynamicExprOutput(tv); + + auto root_dom = tv->getRootDomain(); + const auto ranges = tv->definition()->as()->getRanges(); + TORCH_INTERNAL_ASSERT( + ranges.size() == root_dom.size(), + "Mismatch between number of slice ranges ", + ranges.size(), + " and size of root domain ", + root_dom.size()); + for (auto i : c10::irange(root_dom.size())) { + // input extent and start/stop/step values determine slice + // concretization + auto root_ext = root_dom.at(i)->getMaybeExpandedExtent(); + leaf_dynamic_vals_.push_back(root_ext); + auto range = ranges.at(i); + leaf_dynamic_vals_.push_back(range.start); + leaf_dynamic_vals_.push_back(range.stop); + leaf_dynamic_vals_.push_back(range.step); + } + } + return; + } const auto& rfd = tv->getMaybeRFactorDomain(); for (auto id : rfd) { if (!id->definition() || id->getIterType() != IterType::Symbolic) { continue; } if (id->definition()->isA()) { - info_.dynamic_resized_ids_.push_back(id); + maybeInsertDynamicExprOutput(id); + // extent of output determines its IterType leaf_dynamic_vals_.push_back(id->extent()); } @@ -152,126 +200,243 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { //! scalars that influence concretization. That list of scalars is then used //! to compute a minimal cache key in InputsIdLookup::lookupId(). std::vector leaf_dynamic_vals_; + + //! In order to prevent redundant processing of dynamic ops, we use an + //! unordered_set here to track which values have already been inserted. 
+ std::unordered_set inserted_dynamic_expr_outputs_; }; -void DynamicTransformConcretizationInfo::analyzeReshapes( +void DynamicTransformConcretizationInfo::analyze( ExpressionEvaluator* expr_eval) { - const auto& reshape_tvs = initial_info_->getDynamicReshapedTensorViews(); - for (const auto tv_index : c10::irange(reshape_tvs.size())) { - auto out_tv = reshape_tvs.at(tv_index); - auto op = out_tv->definition()->as(); - auto inp_tv = op->in()->as(); - - // If there's no symblic axis, this is a static reshape op - if (!out_tv->domain()->hasSymbolicAxis()) { - return; - } - + const auto& expr_outputs = initial_info_->getDynamicExprOutputs(); + for (auto val_index : c10::irange(expr_outputs.size())) { + auto val = expr_outputs.at(val_index); TORCH_INTERNAL_ASSERT( - out_tv->hasRFactor(), - "Unexpected output tv of ViewOp: ", - out_tv->toString()); - - const auto& inp_dom = - TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); - - // Determine input shape using expr evaluator - std::vector inp_shape(inp_dom.size(), 0); - for (const auto i : c10::irange(inp_dom.size())) { - auto inp_id = inp_dom.at(i); - // This should have been validated when initially creating reshape - // op, but just in case - TORCH_INTERNAL_ASSERT( - !inp_id->maybePartial(), - "Invalid domain to reshape: ", - inp_id->toString()); - auto extent_val = expr_eval->evaluate(inp_id->extent()); - TORCH_INTERNAL_ASSERT( - extent_val.has_value(), - "Cannot evaluate the extent of an input domain to reshape: ", - inp_id->toString()); + val->definition(), + "Dynamic Val does not have a definition: ", + val->toString()); + if (val->definition()->isA()) { + analyzeReshape(expr_eval, val_index); + } else if (val->definition()->isA()) { + analyzeSlice(expr_eval, val_index); + } else if (val->definition()->isA()) { + analyzeResize(expr_eval, val_index); + } else { TORCH_INTERNAL_ASSERT( - extent_val->isInt(), - "Invalid evaluated value of domain extent: ", - inp_id->toString()); - TORCH_INTERNAL_ASSERT( - 
extent_val->as() > 0, - "Invalid input domain extent: ", - extent_val->as()); - inp_shape.at(i) = extent_val->as(); - } - - const auto& out_dom = out_tv->getMaybeRFactorDomain(); - - // Determine output shape using expr evaluator. Note there may be - // one domain of extent -1 - std::vector out_shape(out_dom.size(), 0); - bool extent_m1_found = false; - for (const auto i : c10::irange(out_dom.size())) { - auto out_id = out_dom.at(i); - auto extent_val = expr_eval->evaluate(out_id->extent()); - TORCH_INTERNAL_ASSERT( - extent_val.has_value(), - "Cannot evaluate the extent of an output domain to reshape: ", - out_id->toString()); - TORCH_INTERNAL_ASSERT( - extent_val->isInt(), - "Invalid evaluated value of domain extent: ", - out_id->toString()); - const auto extent_int = extent_val->as(); - if (extent_int == -1) { - TORCH_INTERNAL_ASSERT( - !extent_m1_found, - "Multiple output domains of size -1 not allowed", - out_tv->toString()); - extent_m1_found = true; - } else { - TORCH_INTERNAL_ASSERT( - extent_int > 0, "Invalid output domain extent: ", extent_int); - } - out_shape.at(i) = extent_int; + false, + "Unhandled dynamic expression: ", + val->definition()->toString()); } + } +} - auto view_result = analyzeView(inp_tv, inp_shape, out_shape); +void DynamicTransformConcretizationInfo::analyzeReshape( + ExpressionEvaluator* expr_eval, + size_t val_index) { + auto out_tv = + initialInfo()->getDynamicExprOutputs().at(val_index)->as(); + auto op = out_tv->definition()->as(); + auto inp_tv = op->in()->as(); - reshape_transforms_.emplace_back(tv_index, view_result); + // If there's no symblic axis, this is a static reshape op + if (!out_tv->domain()->hasSymbolicAxis()) { + return; } -} -void DynamicTransformConcretizationInfo::analyzeResizes( - ExpressionEvaluator* expr_eval) { - const auto& resize_ids = initial_info_->getDynamicResizedIterDomains(); - for (const auto id_index : c10::irange(resize_ids.size())) { - auto out_id = resize_ids.at(id_index); - auto op = 
out_id->definition()->as(); + TORCH_INTERNAL_ASSERT( + out_tv->hasRFactor(), + "Unexpected output tv of ViewOp: ", + out_tv->toString()); + + const auto& inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + + // Determine input shape using expr evaluator + std::vector inp_shape(inp_dom.size(), 0); + for (const auto i : c10::irange(inp_dom.size())) { + auto inp_id = inp_dom.at(i); + // This should have been validated when initially creating reshape + // op, but just in case + TORCH_INTERNAL_ASSERT( + !inp_id->maybePartial(), + "Invalid domain to reshape: ", + inp_id->toString()); + auto extent_val = expr_eval->evaluate(inp_id->extent()); + TORCH_INTERNAL_ASSERT( + extent_val.has_value(), + "Cannot evaluate the extent of an input domain to reshape: ", + inp_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val->isInt(), + "Invalid evaluated value of domain extent: ", + inp_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val->as() > 0, + "Invalid input domain extent: ", + extent_val->as()); + inp_shape.at(i) = extent_val->as(); + } - TORCH_CHECK( - out_id->getIterType() == IterType::Symbolic, - "Found non-dynamic Resize in initial concretization info: ", - op->toString()); + const auto& out_dom = out_tv->getMaybeRFactorDomain(); + // Determine output shape using expr evaluator. 
Note there may be + // one domain of extent -1 + std::vector out_shape(out_dom.size(), 0); + bool extent_m1_found = false; + for (const auto i : c10::irange(out_dom.size())) { + auto out_id = out_dom.at(i); auto extent_val = expr_eval->evaluate(out_id->extent()); TORCH_INTERNAL_ASSERT( extent_val.has_value(), - "Cannot evaluate the extent of a resized domain: ", + "Cannot evaluate the extent of an output domain to reshape: ", out_id->toString()); TORCH_INTERNAL_ASSERT( extent_val->isInt(), - "Invalid evaluated value of resized domain extent: ", + "Invalid evaluated value of domain extent: ", out_id->toString()); - auto extent_int = extent_val->as(); + const auto extent_int = extent_val->as(); + if (extent_int == -1) { + TORCH_INTERNAL_ASSERT( + !extent_m1_found, + "Multiple output domains of size -1 not allowed", + out_tv->toString()); + extent_m1_found = true; + } else { + TORCH_INTERNAL_ASSERT( + extent_int > 0, "Invalid output domain extent: ", extent_int); + } + out_shape.at(i) = extent_int; + } + + auto view_result = analyzeView(inp_tv, inp_shape, out_shape); + + reshape_transforms_.emplace_back(val_index, view_result); +} + +void DynamicTransformConcretizationInfo::analyzeSlice( + ExpressionEvaluator* expr_eval, + size_t val_index) { + auto out_tv = + initialInfo()->getDynamicExprOutputs().at(val_index)->as(); + auto op = out_tv->definition()->as(); + auto inp_tv = op->in()->as(); + const auto inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + const auto out_dom = out_tv->getMaybeRFactorDomain(); + const auto ranges = op->getRanges(); + std::vector slice_descs(inp_dom.size()); + for (auto i : c10::irange(inp_dom.size())) { + const auto& range = ranges.at(i); + + auto start_opt = expr_eval->evaluate(range.start); TORCH_INTERNAL_ASSERT( - extent_int > 0, - "Invalid resized domain extent ", - extent_int, - " for domain ", - out_id->toString()); + start_opt.has_value(), + "Could not evaluate start of slice range ", + range.start); + auto 
start = start_opt->as(); - auto iter_type = - extent_int == 1 ? IterType::Broadcast : IterType::Iteration; + auto stop_opt = expr_eval->evaluate(range.stop); + TORCH_INTERNAL_ASSERT( + stop_opt.has_value(), + "Could not evaluate stop of slice range ", + range.stop); + auto stop = stop_opt->as(); - resize_itertypes_.emplace_back(id_index, iter_type); + auto step_opt = expr_eval->evaluate(range.step); + TORCH_INTERNAL_ASSERT( + step_opt.has_value(), + "Could not evaluate step of slice range ", + range.step); + auto step = step_opt->as(); + + TORCH_INTERNAL_ASSERT(step != 0, "Slice step must not be zero"); + TORCH_INTERNAL_ASSERT( + step == 1, "Slicing with step != 1 is not currently supported"); + + auto inp_extent_opt = + expr_eval->evaluate(inp_dom.at(i)->getMaybeExpandedExtent()); + TORCH_INTERNAL_ASSERT( + inp_extent_opt.has_value(), + "Could not evaluate slice input extent ", + inp_dom.at(i)->getMaybeExpandedExtent()); + auto inp_extent = inp_extent_opt->as(); + + auto getBranch = [&inp_extent](int64_t a) -> SliceIndexBranch { + if (a == 0 || a <= -inp_extent) { + return SliceIndexBranch::Zero; + } else if (a < 0) { + return SliceIndexBranch::Negative; + } else if (a < inp_extent) { + return SliceIndexBranch::Positive; + } else { + return SliceIndexBranch::Extent; + } + }; + slice_descs[i].start_branch = getBranch(start); + slice_descs[i].stop_branch = getBranch(stop); + slice_descs[i].is_empty = (stop - start) * step <= 0; + + // The dynamic slice output has a purely symbolic extent. Here we evaluate + // the proper extent to determine the output IterType. + auto map_int_index = [&inp_extent](const int64_t& a) -> int64_t { + if (a <= -inp_extent) { + return 0; + } else if (a < 0) { + return -a; + } else if (a < inp_extent) { + return a; + } else { + return inp_extent; + } + }; + // actual size of sliced dimension is ceilDiv(stop - start, step) when + // step > 0. 
When step < 0, that expression is off by one and instead the + // extent in that case is ceilDiv(start - stop, -step). + auto concrete_sliced_extent = step > 0 + ? (map_int_index(stop) - map_int_index(start) + step - 1) / step + : (map_int_index(stop) - map_int_index(start) + step + 1) / step; + + if (concrete_sliced_extent == 1) { + slice_descs[i].iter_type = IterType::Broadcast; + } + + expr_eval->bind(out_dom.at(i)->extent(), concrete_sliced_extent); } + slice_descriptors_.emplace_back(val_index, slice_descs); +} + +void DynamicTransformConcretizationInfo::analyzeResize( + ExpressionEvaluator* expr_eval, + size_t val_index) { + auto out_id = + initialInfo()->getDynamicExprOutputs().at(val_index)->as(); + auto op = out_id->definition()->as(); + + TORCH_CHECK( + out_id->getIterType() == IterType::Symbolic, + "Found non-dynamic Resize in initial concretization info: ", + op->toString()); + + auto extent_val = expr_eval->evaluate(out_id->extent()); + TORCH_INTERNAL_ASSERT( + extent_val.has_value(), + "Cannot evaluate the extent of a resized domain: ", + out_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val->isInt(), + "Invalid evaluated value of resized domain extent: ", + out_id->toString()); + auto extent_int = extent_val->as(); + TORCH_INTERNAL_ASSERT( + extent_int > 0, + "Invalid resized domain extent ", + extent_int, + " for domain ", + out_id->toString()); + + auto iter_type = extent_int == 1 ? 
IterType::Broadcast : IterType::Iteration; + + resize_itertypes_.emplace_back(val_index, iter_type); } bool DynamicTransformConcretizationInfo::operator==( @@ -281,7 +446,8 @@ bool DynamicTransformConcretizationInfo::operator==( } if (reshape_transforms_.size() != other.reshape_transforms_.size() || - resize_itertypes_.size() != other.resize_itertypes_.size()) { + resize_itertypes_.size() != other.resize_itertypes_.size() || + slice_descriptors_.size() != other.slice_descriptors_.size()) { return false; } @@ -293,6 +459,14 @@ bool DynamicTransformConcretizationInfo::operator==( } } + for (const auto i : c10::irange(slice_descriptors_.size())) { + const auto& desc = slice_descriptors_.at(i); + const auto& other_desc = other.slice_descriptors_.at(i); + if (desc != other_desc) { + return false; + } + } + for (const auto i : c10::irange(resize_itertypes_.size())) { const auto& itertype = resize_itertypes_.at(i); const auto& other_itertype = other.resize_itertypes_.at(i); @@ -310,13 +484,13 @@ std::string DynamicTransformConcretizationInfo::toString() const { std::string indent = " "; ss << indent << "Reshape:\n"; for (const auto& [tv_index, analyze_result] : reshape_transforms_) { - auto tv = initial_info_->getDynamicReshapedTensorViews().at(tv_index); + auto tv = initial_info_->getDynamicExprOutputs().at(tv_index); ss << indent << indent << tv->toString() << " (index=" << tv_index << "), " << analyze_result.toString() << "\n"; } ss << indent << "Resize:\n"; for (const auto& [id_index, iter_type] : resize_itertypes_) { - auto id = initial_info_->getDynamicResizedIterDomains().at(id_index); + auto id = initial_info_->getDynamicExprOutputs().at(id_index); ss << indent << indent << id->toString() << " (index=" << id_index << "), " << iter_type << "\n"; } @@ -333,19 +507,48 @@ class DynamicTransformConcretizer : public OptOutMutator { TORCH_INTERNAL_ASSERT( fusion == info->fusion(), "Invalid DynamicTransformInitialInfo. 
The associated Fusion is different from the given Fusion"); + for (auto& [tv_index, view_analysis] : info_->getReshapeTransforms()) { + view_analyses_[info_->initialInfo() + ->getDynamicExprOutputs() + .at(tv_index) + ->as()] = view_analysis; + } + for (auto& [tv_index, slice_descs] : info_->getSliceDescriptors()) { + slice_descriptors_[info_->initialInfo() + ->getDynamicExprOutputs() + .at(tv_index) + ->as()] = slice_descs; + } + for (auto& [id_index, iter_type] : info_->getResizeIterTypes()) { + iter_types_[info_->initialInfo() + ->getDynamicExprOutputs() + .at(id_index) + ->as()] = iter_type; + } + concretize(); } private: void concretize(); - void concretizeReshape(); + TensorView* maybeConcretizeReshape(TensorView* tv); + + TensorView* maybeConcretizeSlice(TensorView* tv); - void concretizeResize(); + IterDomain* maybeConcretizeResize(IterDomain* id); + + void concretizeIterDomain(IterDomain* old_id, IterDomain* new_id); //! Use this instead of calling registerMutation directly, since it will also //! check that the concretized value is a valid input to all of its uses. void registerConcretization(Val* old_val, Val* new_val) { + TORCH_INTERNAL_ASSERT( + old_val->vtype() == new_val->vtype(), + "Concretization should not change ValType, but attempted concretization of ", + old_val->toString(), + " with ", + new_val->toString()); checkConcretizedUses(old_val, new_val); registerMutation(old_val, new_val); } @@ -361,75 +564,236 @@ class DynamicTransformConcretizer : public OptOutMutator { void mutate(TensorDomain* td) final; + void mutate(IterDomain* id) final; + //! Concretizes the root domain of a symbolic consumer tensor from //! its producer domains. Returns true if any root ID is concretized. bool propagateFromProducerToConsumer(TensorView* consumer); private: const DynamicTransformConcretizationInfo* info_; + + //! These are populated at construction, so that we can easily check whether a + //! Val is the output of a dynamic op. 
+ std::unordered_map view_analyses_; + std::unordered_map> + slice_descriptors_; + std::unordered_map iter_types_; }; void DynamicTransformConcretizer::concretize() { - // First, concretize all dynamic reshape ops - concretizeReshape(); - - // Set output IterTypes for dynamic resize ops - concretizeResize(); - - // Finally, propagate concretized domains - auto all_stmts = StmtSort::getStmts(info_->fusion(), true); - for (auto stmt : all_stmts) { - if (stmt->isA()) { - mutate(stmt); + // Proceed in order from inputs to outputs. For each statement, propagate + // concretized inputs from producers to consumers, then concretize statement + // if it is a dynamic op. + std::vector vals; + for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) { + // When we concretize reshape and slice, we replace some TVs with others. + // This means we would change the Statements listed in getStmts(), + // resulting in segfaults if we have removed expressions containing replaced + // TVs. So, here we extract all Vals first so that we don't try and + // dereference a removed Expr. Then, when mutating we only need to be sure + // that we do not replace any downstream Vals (replacing downstream Exprs, + // i.e. immediate uses, is fine). + if (stmt->isVal()) { + vals.push_back(stmt->asVal()); } } + for (auto val : vals) { + mutate(val); + } } -void DynamicTransformConcretizer::concretizeReshape() { - // Concretize each reshape op. 
- for (const auto& [tv_index, view_analysis] : info_->getReshapeTransforms()) { - auto incomplete_out_tv = - info_->initialInfo()->getDynamicReshapedTensorViews().at(tv_index); - auto view_op = incomplete_out_tv->definition()->as(); - auto inp_tv = view_op->in()->as(); +TensorView* DynamicTransformConcretizer::maybeConcretizeReshape( + TensorView* incomplete_out_tv) { + const auto it = view_analyses_.find(incomplete_out_tv); + if (it == view_analyses_.end()) { + return incomplete_out_tv; + } + const auto& view_analysis = it->second; + + auto view_op = incomplete_out_tv->definition()->as(); + auto inp_tv = view_op->in()->as(); - auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis); + auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis); + + for (auto i : + c10::irange(incomplete_out_tv->getMaybeRFactorDomain().size())) { + auto id = incomplete_out_tv->getMaybeRFactorDomain().at(i); + auto new_id = concrete_reshape_out_tv->getMaybeRFactorDomain().at(i); + concretizeIterDomain(id, new_id); + } - // We do the replacement directly here, but we must still check that the - // replacement is valid - checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv); + // registerConcretization(incomplete_out_tv, concrete_reshape_out_tv); + + // Replace the old tensor with the new concretized tensor + for (auto use_of_old_tv : incomplete_out_tv->uses()) { + ir_utils::replaceValInExpr( + use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); + } + + if (incomplete_out_tv->isFusionOutput()) { + incomplete_out_tv->fusion()->replaceOutput( + incomplete_out_tv, concrete_reshape_out_tv); + } - // Replace the old tensor with the new concretized tensor - for (auto use_of_old_tv : incomplete_out_tv->uses()) { - ir_utils::replaceValInExpr( - use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); + info_->fusion()->removeVal(incomplete_out_tv); + + return concrete_reshape_out_tv; +} + +TensorView* DynamicTransformConcretizer::maybeConcretizeSlice( + 
TensorView* incomplete_out_tv) { + const auto it = slice_descriptors_.find(incomplete_out_tv); + if (it == slice_descriptors_.end()) { + return incomplete_out_tv; + } + const auto& slice_descs = it->second; + + auto fusion = FusionGuard::getCurFusion(); + auto slice_op = incomplete_out_tv->definition()->as(); + auto inp_tv = slice_op->input(0)->as(); + + const auto& root_dom = incomplete_out_tv->getRootDomain(); + const auto& rfactor_dom = incomplete_out_tv->getRFactorDomain(); + // Create new rfactor domain with potentially newly-resized root IDs + std::vector new_rfactor(root_dom.size()); + + bool is_empty = false; + bool is_sliced = false; + const auto ranges = slice_op->getRanges(); + auto map_index = [&fusion]( + SliceIndexBranch branch, Val* a, Val* extent) -> Val* { + Val* normalized_index = nullptr; + switch (branch) { + case SliceIndexBranch::Zero: + normalized_index = fusion->zeroVal(); + case SliceIndexBranch::Extent: + normalized_index = extent; + case SliceIndexBranch::Negative: + normalized_index = SimplifyingIrBuilder::addExpr(extent, a); + case SliceIndexBranch::Positive: + normalized_index = a; } + return normalized_index; + }; + std::vector new_ranges; + new_ranges.reserve(ranges.size()); + for (auto i : c10::irange(root_dom.size())) { + auto desc = slice_descs.at(i); + if (desc.is_empty) { + is_empty = true; + // Use 0:0:1 as the canonical empty slice. 
+ new_ranges.push_back( + {fusion->zeroVal(), fusion->zeroVal(), fusion->oneVal()}); + } else { + auto range = ranges.at(i); + auto inp_extent = root_dom.at(i)->getMaybeExpandedExtent(); + auto new_start = map_index(desc.start_branch, range.start, inp_extent); + auto new_stop = map_index(desc.stop_branch, range.stop, inp_extent); + new_ranges.push_back({new_start, new_stop, range.step}); + // Trivial slices correspond to 0:extent:1 + if (desc.start_branch != SliceIndexBranch::Zero || + desc.stop_branch != SliceIndexBranch::Extent || + desc.step_branch != SliceStepBranch::One) { + is_sliced = true; + } + } + } - if (incomplete_out_tv->isFusionOutput()) { - incomplete_out_tv->fusion()->replaceOutput( - incomplete_out_tv, concrete_reshape_out_tv); + TensorView* new_tv = nullptr; + + if (is_empty) { + std::vector new_shape(ranges.size()); + for (auto i : c10::irange(ranges.size())) { + auto new_range = new_ranges.at(i); + auto desc = slice_descs.at(i); + // Depending on the step branch, we can use different output extent + // expressions + switch (desc.step_branch) { + case SliceStepBranch::One: + new_shape[i] = + SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start); + break; + case SliceStepBranch::GreaterThanOne: + new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( + SimplifyingIrBuilder::subExpr(new_range.stop, new_range.start), + new_range.step); + case SliceStepBranch::Negative: + new_shape[i] = SimplifyingIrBuilder::ceilDivExpr( + SimplifyingIrBuilder::subExpr(new_range.start, new_range.stop), + SimplifyingIrBuilder::negExpr(new_range.step)); + } } + // TODO: process as empty tensor if is_empty + auto dtype = incomplete_out_tv->getDataType().value(); + new_tv = full(new_shape, fusion->zeroVal(dtype), dtype); + } else if (!is_sliced) { + // Replace the slice with set() + new_tv = set(inp_tv); + } else { + new_tv = slice(inp_tv, new_ranges, /*skip_symbolic*/ true); + } - info_->fusion()->removeVal(incomplete_out_tv); + for (auto i : 
c10::irange(root_dom.size())) { + auto id = rfactor_dom.at(i); + auto new_id = new_tv->getRFactorDomain().at(i); + concretizeIterDomain(id, new_id); } + + // TODO: We need to update the maybeRFactorDomains of new_tv if there are + // any Broadcast sliced dimensions. + + // We do the replacement directly here, but we must still check that the + // replacement is valid + // registerConcretization(incomplete_out_tv, new_tv); + + // Replace the old tensor with the new concretized tensor + for (auto use_of_old_tv : incomplete_out_tv->uses()) { + ir_utils::replaceValInExpr(use_of_old_tv, incomplete_out_tv, new_tv); + } + + if (incomplete_out_tv->isFusionOutput()) { + incomplete_out_tv->fusion()->replaceOutput(incomplete_out_tv, new_tv); + } + + info_->fusion()->removeVal(incomplete_out_tv); + + return new_tv; +} + +IterDomain* DynamicTransformConcretizer::maybeConcretizeResize(IterDomain* id) { + const auto it = iter_types_.find(id); + if (it == iter_types_.end()) { + return id; + } + const auto& iter_type = it->second; + TORCH_CHECK( + id->definition() && id->definition()->isA(), + "Resized IterDomain must have a Resize definition"); + auto def = id->definition()->as(); + auto new_id = IterDomain::resize( + def->in(), + def->leftExpand(), + def->rightExpand(), + id->isRFactorProduct(), + iter_type); + + concretizeIterDomain(id, new_id); + return new_id; } -void DynamicTransformConcretizer::concretizeResize() { - // Concretize each resize op. 
- for (const auto& [id_index, iter_type] : info_->getResizeIterTypes()) { - auto id = info_->initialInfo()->getDynamicResizedIterDomains().at(id_index); - TORCH_CHECK( - id->definition() && id->definition()->isA(), - "Resized IterDomain must have a Resize definition"); - auto def = id->definition()->as(); - auto new_id = IterDomain::resize( - def->in(), - def->leftExpand(), - def->rightExpand(), - id->isRFactorProduct(), - iter_type); - - registerConcretization(id, new_id); +void DynamicTransformConcretizer::concretizeIterDomain( + IterDomain* old_id, + IterDomain* new_id) { + auto old_ext = old_id->getMaybeExpandedExtent(); + auto new_ext = new_id->getMaybeExpandedExtent(); + + // If symbolic, redefine old_ext to equal new_ext + if (old_ext != new_ext && !old_ext->isConst() && !old_ext->definition()) { + // Set the old_ext->definition() to equal set(new_ext). This way, any + // further uses of old_ext are valid, as they will now be computed using + // new_ext. + IrBuilder::create(LoadStoreOpType::Set, old_ext, new_ext); } } @@ -449,6 +813,8 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { if (!tv->domain()->hasSymbolicAxis()) { return; } + tv = maybeConcretizeReshape(tv); + tv = maybeConcretizeSlice(tv); // First, try to concretize the root domain as there may be symbolic // axes inherited from the producers @@ -594,6 +960,10 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) { registerConcretization(td, mutated_val); } +void DynamicTransformConcretizer::mutate(IterDomain* id) { + id = maybeConcretizeResize(id); +} + bool DynamicTransformConcretizer::propagateFromProducerToConsumer( TensorView* consumer) { if (consumer->definition() == nullptr || @@ -684,6 +1054,11 @@ size_t DynamicTransformConcretizationInfo::hash() const { for (const auto& [id, iter_type] : getResizeIterTypes()) { hashCombine(hash, (size_t)iter_type); } + for (const auto& [tv, slice_descs] : getSliceDescriptors()) { + for (const auto& desc : slice_descs) { + 
hashCombine(hash, desc.hash()); + } + } return hash; } diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 34609a002f1..4620f83a65d 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -40,7 +40,7 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { //! Return whether any dynamic transforms exist in the Fusion bool hasDynamicTransforms() const { - return !dynamic_reshaped_tvs_.empty() || !dynamic_resized_ids_.empty(); + return !dynamic_expr_outputs_.empty(); } //! Return a set of scalars that are inputs or extents of input TensorViews @@ -50,16 +50,8 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { return root_dynamic_vals_; } - //! Return a vector of outputs of ViewOp expressions that have dynamic output - //! shapes - const std::vector& getDynamicReshapedTensorViews() const { - return dynamic_reshaped_tvs_; - } - - //! Return a vector of outputs of Resize expressions that have symbolic output - //! IterTypes - const std::vector& getDynamicResizedIterDomains() const { - return dynamic_resized_ids_; + const std::vector& getDynamicExprOutputs() const { + return dynamic_expr_outputs_; } std::string toString() const; @@ -89,9 +81,30 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { // definitions will merely be altered. When the ops are replaced, if we had // referred to them directly here, we would run into segfaults. Referring only // to the outputs avoids this issue. - std::vector dynamic_reshaped_tvs_; - - std::vector dynamic_resized_ids_; + // std::vector dynamic_reshaped_tvs_; + + // std::vector dynamic_resized_ids_; + + // Slice operations can have complicated output extents. The inputs to slice + // are a start, stop, and step for each sliced dimension. Each of these is an + // integer, and any combination of three finite integers with step != 0 is + // acceptable and should run without error. 
Normalization of the start and + // stop values must be done, followed by computation of the output extent: + // + // normed_start = min(max(where(start < 0, extent + start, start), 0), + // extent); normed_stop = max(min(max(where(stop < 0, extent + stop, stop), + // 0), extent), normed_start); extent = max((normed_stop - normed_start + 1) + // / step, 0); + // + // These expressions are unwieldy and cannot be significantly simplified + // unless we know certain relations about the start, stop, and step scalars. + // Here we keep track of non-static slices or slices with non-static input + // extents. That way we can restrict to a single branch in each of these + // expressions during concretization. + // std::vector dynamic_sliced_tvs_; + + // This is a topologically sorted list of outputs of dynamic operations. + std::vector dynamic_expr_outputs_; // Root Vals that determine concretization std::unordered_set root_dynamic_vals_; @@ -99,6 +112,54 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { friend class DynamicTransformInitialInfoBuilder; }; +//! This enum describes cases that can occur for the start or stop arguments to +//! slice(). Each of these leads to a different branch in the normalized form's +//! general expression. +enum class SliceIndexBranch { + Negative, // -extent < a < 0 + Zero, // a == 0 OR a <= -extent + Positive, // 0 < a < extent + Extent // extent <= a +}; + +//! This enum describes the "step" argument to slice, which can be a positive or +//! negative integer (but not zero). We handle the special case of step == 1 +//! separately from step > 1 since this simplifies some expressions. +enum class SliceStepBranch { Negative, One, GreaterThanOne }; + +//! Describes a 1D slice in terms of the start, stop, and extent values +struct Concrete1DSliceDescriptor { + //! 
These enums determine the form of the simplified expressions + SliceIndexBranch start_branch = SliceIndexBranch::Zero; + SliceIndexBranch stop_branch = SliceIndexBranch::Extent; + SliceStepBranch step_branch = SliceStepBranch::One; + + //! True if normalized values satisfy (stop - start) * step <= 0 in which case + //! we would return an empty tensor. + bool is_empty = false; + + //! This can be either Iteration or Broadcast (if sliced extent is 1) + IterType iter_type = IterType::Iteration; + + bool operator==(const Concrete1DSliceDescriptor& other) const { + return start_branch == other.start_branch && + stop_branch == other.stop_branch && step_branch == other.step_branch && + is_empty == other.is_empty && iter_type == other.iter_type; + } + bool operator!=(const Concrete1DSliceDescriptor& other) const { + return !operator==(other); + } + + size_t hash() const { + size_t h = (size_t)start_branch; + hashCombine(h, (size_t)stop_branch); + hashCombine(h, (size_t)step_branch); + hashCombine(h, (size_t)is_empty); + hashCombine(h, (size_t)iter_type); + return h; + } +}; + //! A set of transformations for a symbolic fusion with concrete sizes //! of the fusion inputs class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { @@ -115,9 +176,7 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { // evaluator when any one of the IDs has a known value expr_eval->propagateBoundValuesThroughExactMaps(initial_info->fusion()); - analyzeReshapes(expr_eval); - - analyzeResizes(expr_eval); + analyze(expr_eval); } //! Return a vector of pairs holding the index of each reshaped TensorView in @@ -136,6 +195,15 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { return resize_itertypes_; } + //! Return a vector of pairs holding the index of each sliced TensorView in + //! the vector returned by initialInfo()->getDynamicSlicedTensorViews(), + //! along with a vector of descriptors indicating how each axis should be + //! concretized. 
+ const std::vector>>& + getSliceDescriptors() const { + return slice_descriptors_; + } + //! Comparison operator for the purposes of determining cache hits. This does //! not guarantee equality of all members. Instead, it returns equal if the //! resulting concretizations would be structurally equivalent. Note that @@ -148,13 +216,21 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { } //! Given an ExpressionEvaluator which already has input scalars bound to it, - //! determine the decomposition of each dynamic reshape operation to use + //! analyze all dynamic ops in topological order. + void analyze(ExpressionEvaluator* expr_eval); + + //! Given an ExpressionEvaluator which already has input scalars bound to it, + //! determine the decomposition of a dynamic reshape operation to use //! during concretization. - void analyzeReshapes(ExpressionEvaluator* expr_eval); + void analyzeReshape(ExpressionEvaluator* expr_eval, size_t val_index); + + //! Given an ExpressionEvaluator which already has input scalars bound to it, + //! determine the branches of expressions in a dynamic slice op. + void analyzeSlice(ExpressionEvaluator* expr_eval, size_t val_index); //! Given an ExpressionEvaluator which already has input scalars bound to it, - //! determine the concrete IterType of each resized IterDomain. - void analyzeResizes(ExpressionEvaluator* expr_eval); + //! determine the concrete IterType of a resized IterDomain. + void analyzeResize(ExpressionEvaluator* expr_eval, size_t val_index); const DynamicTransformInitialInfo* initialInfo() const { return initial_info_; @@ -189,6 +265,12 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { //! vector returned by initial_info_->getDynamicResizedIterDomains() along //! with its concretized IterType std::vector> resize_itertypes_; + + //! Holds the index of the sliced TensorView (output of the SliceOp) in the + //! vector returned by initial_info_->getDynamicExprOutputs() along + //!
with a descriptor of how it should be concretized. + std::vector>> + slice_descriptors_; }; class TORCH_CUDA_CU_API DynamicTransform { diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 5a3933f4363..8293651b06f 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -367,6 +367,12 @@ NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values) makeUnaryOp(uop); } else if (auto bop = dynamic_cast(def)) { makeBinaryOp(bop); + } else if (auto setop = dynamic_cast(def)) { + TORCH_INTERNAL_ASSERT( + setop->opType() == LoadStoreOpType::Set, + "NaiveValueMachine: unsupported LoadStoreOpType: ", + setop->opType()); + makeSetOp(setop); } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr"); } @@ -448,6 +454,18 @@ void NaiveValueMachine::makeBinaryOp(BinaryOp* bop) { dest_[index] = out; } +void NaiveValueMachine::makeSetOp(LoadStoreOp* lsop) { + int in = lsop->inputs()[0]->evaluatorIndex(); + int out = lsop->outputs()[0]->evaluatorIndex(); + TORCH_INTERNAL_ASSERT(in >= 0, "Integer Machine: unknown input: ", lsop); + TORCH_INTERNAL_ASSERT(out >= 0, "Integer Machine: unknown out: ", lsop); + + int index = makeInstructionEntry(); + inst_type_[index] = InstructionType::SET_OP; + src0_[index] = in; + dest_[index] = out; +} + int NaiveValueMachine::makeInstructionEntry() { int index = num_of_instructions_++; inst_type_.emplace_back(InstructionType::UNARY_OP); diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index 112d770a27e..3bb198bf78b 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -55,6 +55,10 @@ class NaiveValueMachine { //! Convert an binary IR expr to an instruction void makeBinaryOp(BinaryOp* bop); + //! Convert a LoadStoreOp expr to an instruction. This assumes lsop->opType() + //! is equal to LoadStoreOpType::Set. + void makeSetOp(LoadStoreOp* lsop); + //! Create an empty instruction with all default values //! and place it at the end of the instruction buffer. 
int makeInstructionEntry(); diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 295e11ad789..1bf336c9272 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2113,7 +2113,14 @@ std::string LoadStoreOp::toString(int indent_size) const { } std::string LoadStoreOp::toInlineString(int indent_size) const { - TORCH_CHECK(false, "Tensor op can not be printed inline"); + if (opType() == LoadStoreOpType::Set) { + TORCH_CHECK( + !in()->isA(), "Cannot print TensorView set() inline"); + std::stringstream ss; + indent(ss, indent_size) << "set(" << in()->toInlineString() << ")"; + return ss.str(); + } + TORCH_CHECK(false, "Non-'Set' LoadStoreOp cannot be printed inline"); } bool LoadStoreOp::hasTranspose() const { diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 3315cb3167d..f77c75fd503 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -654,7 +654,7 @@ struct ReplaceValInIndexVal : public OptInDispatch { // Recursively traverse its defining expr auto def = val->definition(); if (def != nullptr) { - if (def->isOneOf()) { + if (def->isOneOf()) { handle(val->definition()); } else { TORCH_INTERNAL_ASSERT(false, "Unexpected definition: ", def->toString()) @@ -665,6 +665,19 @@ struct ReplaceValInIndexVal : public OptInDispatch { } } + // Clone expression after recursively replacing inputs + void handle(LoadStoreOp* lsop) override { + handle(lsop->in()); + auto inp = last_visited_val_; + TORCH_INTERNAL_ASSERT( + lsop->out()->isA() || lsop->out()->isA(), + "Unknown output type for expr ", + lsop->toInlineString()); + auto out = IrBuilder::create(c10::nullopt); + IrBuilder::create(lsop->opType(), out, inp); + last_visited_val_ = out; + } + // Clone expression after recurisvely replacing inputs void handle(UnaryOp* uop) override { handle(uop->in()); diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 6e7702e24b8..32e8e674e07 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -293,6 +293,10 @@ InputsIdLookup::IdLookupReturn
InputsIdLookup::lookupId( FusionExecutorCache::FusionExecutorCache(std::unique_ptr fusion) : fusion_(std::move(fusion)) {} +Fusion* FusionExecutorCache::getMostRecentConcretizedFusion() const { + return most_recent_runtime_->fusionSegments()->completeFusion(); +} + KernelArgumentHolder FusionExecutorCache::prepareInputs( const at::ArrayRef& inputs, std::optional selected_device) { diff --git a/csrc/kernel_cache.h b/csrc/kernel_cache.h index 8fc0cc51a3d..a2fd2cbbb62 100644 --- a/csrc/kernel_cache.h +++ b/csrc/kernel_cache.h @@ -507,6 +507,8 @@ class TORCH_CUDA_CU_API FusionExecutorCache { fusion_->printMath(); } + Fusion* getMostRecentConcretizedFusion() const; + FusionKernelRuntime* getMostRecentKernelRuntime() const { return most_recent_runtime_; } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 49ed3b9912d..35a604a2e32 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -639,10 +639,14 @@ TensorView* cat(const std::vector& inputs, int64_t cat_dim) { return out; } -// Currently there's no error check about the actual values of the -// Slice parameters. For example, the start parameter of a range of a -// domain is assumed to be >= 0 and < the extent of the domain. -TensorView* slice(TensorView* inp, const std::vector& ranges) { +// If skip_symbolic is true, then the start and stop parameters of a range of a +// domain is assumed to be >= 0 and < the extent of the domain. Otherwise, +// non-constant inputs will lead to Symbolic IterDomains in the output, which +// must be later concretized. +TensorView* slice( + TensorView* inp, + const std::vector& ranges, + bool skip_symbolic) { const auto inp_dom = TensorDomain::noReductions(inp->getMaybeRFactorDomain()); const int ndims = static_cast(inp_dom.size()); @@ -666,6 +670,19 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { return range; }; + // Adjust an integer value relative to a given extent. 
This is + // min(max(where(a < 0, extent + a, a), 0), extent) + auto adjust_start_stop = [](int64_t& a, int64_t extent) { + if (a < 0) { + a += extent; + } + if (a < 0) { + a = 0; + } else if (a > extent) { + a = extent; + } + }; + for (auto& range : ranges) { // Step not supported yet TORCH_CHECK( @@ -681,11 +698,12 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { bool needs_real_slicing = false; for (const auto idx : c10::irange(ndims)) { auto inp_root_id = inp_dom[idx]; - auto range = normalize_slice_range(ranges.at(idx), inp_root_id->extent()); + auto inp_extent = inp_root_id->getMaybeExpandedExtent(); + auto range = normalize_slice_range(ranges.at(idx), inp_extent); normalized_ranges.at(idx) = range; IterDomain* out_root_id = nullptr; IterDomain* out_rf_id = nullptr; - if (range.start->isZeroInt() && range.stop->sameAs(inp_root_id->extent()) && + if (range.start->isZeroInt() && range.stop->sameAs(inp_extent) && range.step->isOneInt()) { // This dim doesn't need slicing out_root_id = inp_root_id->cloneWithoutRFactor(); @@ -693,11 +711,40 @@ TensorView* slice(TensorView* inp, const std::vector& ranges) { } else { out_root_id = IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build(); - out_rf_id = IterDomain::resize( - out_root_id, - SimplifyingIrBuilder::negExpr(range.start), - sub(range.stop, inp_root_id->extent()), - true); + // The start, stop, and extent of the output will all require complicated + // expressions which will be simplified at concretization. Here we set + // the output to Symbolic unless all required scalars are constant. 
+ if (range.start->isConstInt() && range.stop->isConstInt() && + inp_extent->isConstInt()) { + auto start = range.start->evaluateInt(); + auto stop = range.stop->evaluateInt(); + auto step = range.step->evaluateInt(); + TORCH_INTERNAL_ASSERT(step != 0, "Slice step must be non-zero"); + TORCH_INTERNAL_ASSERT( + step == 1, "Slicing with step != 1 is not currently supported"); + auto inp_extent_val = inp_extent->evaluateInt(); + adjust_start_stop(start, inp_extent_val); + adjust_start_stop(stop, inp_extent_val); + out_rf_id = IterDomain::resize( + out_root_id, + SimplifyingIrBuilder::negExpr(IrBuilder::create(start)), + sub(IrBuilder::create(stop), inp_extent), + true); + } else if (skip_symbolic) { + out_rf_id = IterDomain::resize( + out_root_id, + SimplifyingIrBuilder::negExpr(range.start), + sub(range.stop, inp_extent), + true, + IterType::Iteration); + } else { + out_rf_id = IterDomainBuilder( + FusionGuard::getCurFusion()->zeroVal(), + IrBuilder::create()) + .is_rfactor_domain(true) + .iter_type(IterType::Symbolic) + .build(); + } needs_real_slicing = true; } root_ids.at(idx) = out_root_id; diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 1d3299fd6e5..c2625de6355 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -106,8 +106,13 @@ TORCH_CUDA_CU_API TensorView* cat( //! Return a tensor where each dimension is sliced as specified by the //! ranges parameter. Stepping must be one at this moment. +//! If skip_symbolic is true, we assume start < stop and both start and stop are +//! between 0 and extent (inclusive). Otherwise, unless all input scalars are +//! constant, the output will have Symbolic IterDomains that must be concretized +//! later. 
TORCH_CUDA_CU_API TensorView* slice( TensorView* inp, - const std::vector& ranges); + const std::vector& ranges, + bool skip_symbolic = false); } // namespace nvfuser diff --git a/test/test_resize.cpp b/test/test_resize.cpp index f96fc05b9a2..4c312ef7dea 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -1242,7 +1242,7 @@ TEST_F(NVFuserTest, FusionResizeSliceReduceScheduler1_CUDA) { auto ref = t1.sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref}, @@ -1300,7 +1300,7 @@ TEST_F(NVFuserTest, FusionResizeSliceReduceScheduler2_CUDA) { auto t4 = t3.sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2, t4}, @@ -1354,7 +1354,7 @@ TEST_F(NVFuserTest, FusionSliceReduceScheduler3_CUDA) { auto t4 = t3.to(at::kDouble).sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2, t4}, @@ -1391,7 +1391,7 @@ TEST_F(NVFuserTest, FusionResizeCatReduceScheduler1_CUDA) { auto ref = at::cat({t0, t1}, 1).sum({1}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref}, @@ -1429,7 +1429,7 @@ TEST_F(NVFuserTest, FusionResizeCatSoftmaxScheduler1_CUDA) { auto ref = at::_softmax(t2.to(at::kDouble), -1, false); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref}, @@ -1466,7 +1466,7 @@ TEST_F(NVFuserTest, FusionResizeReductionSliceScheduler1_CUDA) { auto t2 = t1.index({at::indexing::Slice(1, shape0[0] - 2)}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2}, @@ -1507,7 +1507,7 @@ TEST_F(NVFuserTest, FusionResizeSoftmaxSliceScheduler1_CUDA) { at::indexing::Slice(0, at::indexing::None)}); testValidate( - executor_cache.fusion(), + 
executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2}, @@ -1548,7 +1548,7 @@ TEST_F(NVFuserTest, FusionResizeSoftmaxSliceScheduler2_CUDA) { at::indexing::Slice(1, shape0[1] - 2)}); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {t2}, @@ -1710,7 +1710,7 @@ TEST_F(NVFuserTest, FusionSliceForNanoGPT1_CUDA) { auto aten_t3 = torch::add(aten_t0_slice, t1); testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {aten_t3, aten_t3}, @@ -1777,10 +1777,15 @@ TEST_F(NVFuserTest, FusionSliceForNanoGPT2_CUDA) { continue; } auto out_tv = ir_utils::getTvOutput(expr); - if (out_tv->name() == tv3->name() || out_tv->name() == tv5->name()) { + if ( + // Note: We want to check tv3 and tv5 here, which are the original slice + // outputs. During concretization, T3 gets concretized to T8 and T5 gets + // concretized to T9. Here we manually translate those since we don't + // maintain a mapping from pre-concretized to concretized Vals. + out_tv->name() == 8 || out_tv->name() == 9) { TORCH_CHECK( expr->isA(), - "Unexpected defintion of slice output tensor: ", + "Unexpected definition of slice output tensor: ", out_tv->toString(), ", ", expr->toString()); @@ -1836,7 +1841,7 @@ TEST_F(NVFuserTest, FusionSliceForNanoGPT2_CUDA) { auto aten_t7 = aten_t2 + 1; testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {aten_t4, aten_t4, aten_t7}, @@ -2001,7 +2006,7 @@ TEST_F(NVFuserTest, ResizePermuteAndSlice_CUDA) { auto ref_t4 = ref_t2 + 1; testValidate( - executor_cache.fusion(), + executor_cache.getMostRecentConcretizedFusion(), cg_outputs, aten_inputs, {ref_t3, ref_t4},