From e29e3a1ee86a45d931881094e1bcdc654818e6f2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 26 Jul 2023 14:13:42 -0400 Subject: [PATCH 1/6] Initial broken half implementation --- csrc/dynamic_transform.cpp | 435 +++++++++++++++++++------------------ csrc/dynamic_transform.h | 85 +++----- 2 files changed, 256 insertions(+), 264 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index c82520e8a11..27f55741132 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -25,36 +25,32 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( IrCloner& ir_cloner) const { DynamicTransformInitialInfo cloned_info( static_cast(ir_cloner.container())); - cloned_info.dynamic_reshaped_tvs_.reserve(dynamic_reshaped_tvs_.size()); - for (const auto op : dynamic_reshaped_tvs_) { - if (op) { - cloned_info.dynamic_reshaped_tvs_.push_back(ir_cloner.clone(op)); - } - } - cloned_info.dynamic_resized_ids_.reserve(dynamic_resized_ids_.size()); - for (const auto op : dynamic_resized_ids_) { - if (op) { - cloned_info.dynamic_resized_ids_.push_back(ir_cloner.clone(op)); - } - } + + cloned_info.dynamic_exprs_.reserve(dynamic_exprs_.size()); + std::transform( + dynamic_exprs_.begin(), + dynamic_exprs_.end(), + std::back_inserter(cloned_info.dynamic_exprs_), + [&ir_cloner](TensorView* tv) { return ir_cloner.clone(tv); }); + cloned_info.maybe_zero_extents_set_.reserve(maybe_zero_extents_set_.size()); - for (const auto v : maybe_zero_extents_set_) { - if (v) { - cloned_info.maybe_zero_extents_set_.insert(ir_cloner.clone(v)); - } - } - cloned_info.maybe_zero_extents_.reserve(maybe_zero_extents_.size()); - for (const auto v : maybe_zero_extents_) { - if (v) { - cloned_info.maybe_zero_extents_.push_back(ir_cloner.clone(v)); - } - } + std::transform( + maybe_zero_extents_set_.begin(), + maybe_zero_extents_set_.end(), + std::inserter( + cloned_info.maybe_zero_extents_set_, + cloned_info.maybe_zero_extents_set_.begin()), + [&ir_cloner](TensorView* tv) { return ir_cloner.clone(tv); }); + cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size()); - for (const auto v : root_dynamic_vals_) { - if (v) { - cloned_info.root_dynamic_vals_.insert(ir_cloner.clone(v)); - } - } + std::transform( + root_dynamic_vals_.begin(), + root_dynamic_vals_.end(), + std::inserter( + cloned_info.root_dynamic_vals_, + cloned_info.root_dynamic_vals_.begin()), + [&ir_cloner](TensorView* tv) { return ir_cloner.clone(tv); }); + return cloned_info; } @@ -62,13 +58,9 @@ std::string DynamicTransformInitialInfo::toString() const { std::stringstream ss; ss << "DynamicTransformInitialInfo\n"; std::string indent = " "; - ss << indent << "Dynamic reshaped TensorViews:\n"; - for (const auto& op : dynamic_reshaped_tvs_) { - ss << indent << indent << op->toString() << "\n"; - } - ss << indent << "Dynamic resized IterDomains:\n"; - for (const auto& op : dynamic_resized_ids_) { - ss << indent << indent << op->toString() << "\n"; + ss << indent << "Dynamic expressions:\n"; + for (const auto& expr : dynamic_exprs_) { + ss << indent << indent << expr->toString(); } ss << indent << "Dynamic extent Vals:\n"; for (const auto& v : maybe_zero_extents_) { @@ -78,6 +70,10 @@ std::string DynamicTransformInitialInfo::toString() const { for (const auto& v : root_dynamic_vals_) { ss << indent << indent << v->toString() << "\n"; } + ss << indent << "Input positions affecting concretization:\n"; + for (const auto& pos : scalar_inputs_affecting_concretization_) { + ss << indent << indent << std::to_string(pos) << "\n"; + } return ss.str(); } @@ -110,7 +106,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { auto out_tv = op->out()->as(); // If there's no symbolic axis, this is a static reshape op if (out_tv->domain()->hasSymbolicAxis()) { - info_.dynamic_reshaped_tvs_.push_back(out_tv); + info_.dynamic_exprs_.push_back(op); // Input and output extent expressions both affect concretization for (const auto& id : @@ -136,7 +132,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { continue; } if (id->definition()->isA()) { - info_.dynamic_resized_ids_.push_back(id); + info_.dynamic_exprs_.push_back(id->definition()); // extent of output determines its IterType leaf_dynamic_vals_.push_back(id->extent()); } @@ -196,9 +192,15 @@ DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo( // evaluator when any one of the IDs has a known value expr_eval->propagateBoundValuesThroughExactMaps(initial_info_->fusion()); - analyzeReshapes(expr_eval); - - analyzeResizes(expr_eval); + for (auto expr : initial_info_->getDynamicExprs()) { + if (auto vop = dynamic_cast(expr)) { + analyze(vop, expr_eval); + } else if (auto rop = dynamic_cast(expr)) { + analyze(rop, expr_eval); + } else { + TORCH_CHECK(false, "Unhandled dynamic Expr type: ", expr->toString()); + } + } auto maybe_zero_extents = initial_info_->getMaybeZeroExtents(); for (auto i : c10::irange(maybe_zero_extents.size())) { @@ -214,119 +216,112 @@ DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo( } } -void DynamicTransformConcretizationInfo::analyzeReshapes( +void DynamicTransformConcretizationInfo::analyze( + ViewOp* op, ExpressionEvaluator* expr_eval) { - const auto& reshape_tvs = initial_info_->getDynamicReshapedTensorViews(); - for (const auto tv_index : c10::irange(reshape_tvs.size())) { - auto out_tv = reshape_tvs.at(tv_index); - auto op = out_tv->definition()->as(); - auto inp_tv = op->in()->as(); + auto out_tv = op->out()->as(); + auto inp_tv = op->in()->as(); - // If there's no symblic axis, this is a static reshape op - if (!out_tv->domain()->hasSymbolicAxis()) { - return; - } + // If there's no symblic axis, this is a static reshape op + if (!out_tv->domain()->hasSymbolicAxis()) { + return; + } + TORCH_INTERNAL_ASSERT( + out_tv->hasRFactor(), + "Unexpected output tv of ViewOp: ", + out_tv->toString()); + + const auto& inp_dom = + TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); + + // Determine input shape using expr evaluator + std::vector inp_shape(inp_dom.size(), 0); + for (const auto i : c10::irange(inp_dom.size())) { + auto inp_id = inp_dom.at(i); + // This should have been validated when initially creating reshape + // op, but just in case TORCH_INTERNAL_ASSERT( - out_tv->hasRFactor(), - "Unexpected output tv of ViewOp: ", - out_tv->toString()); - - const auto& inp_dom = - TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); - - // Determine input shape using expr evaluator - std::vector inp_shape(inp_dom.size(), 0); - for (const auto i : c10::irange(inp_dom.size())) { - auto inp_id = inp_dom.at(i); - // This should have been validated when initially creating reshape - // op, but just in case - TORCH_INTERNAL_ASSERT( - !inp_id->maybePartial(), - "Invalid domain to reshape: ", - inp_id->toString()); - auto extent_val = expr_eval->evaluate(inp_id->extent()); - TORCH_INTERNAL_ASSERT( - extent_val.hasValue(), - "Cannot evaluate the extent of an input domain to reshape: ", - inp_id->toString()); - TORCH_INTERNAL_ASSERT( - extent_val.is(), - "Invalid evaluated value of domain extent: ", - inp_id->toString()); - TORCH_INTERNAL_ASSERT( - extent_val.as() > 0, - "Invalid input domain extent: ", - extent_val.as()); - inp_shape.at(i) = extent_val.as(); - } - - const auto& out_dom = out_tv->getMaybeRFactorDomain(); - - // Determine output shape using expr evaluator. Note there may be - // one domain of extent -1 - std::vector out_shape(out_dom.size(), 0); - for (const auto i : c10::irange(out_dom.size())) { - auto out_id = out_dom.at(i); - auto extent_val = expr_eval->evaluate(out_id->extent()); - TORCH_INTERNAL_ASSERT( - extent_val.hasValue(), - "Cannot evaluate the extent of an output domain to reshape: ", - out_id->toString()); - TORCH_INTERNAL_ASSERT( - extent_val.is(), - "Invalid evaluated value of domain extent: ", - out_id->toString()); - auto extent_int = extent_val.as(); - if (extent_int == -1) { - // For non-constant Scalar sizes, check that we have not passed -1. - TORCH_CHECK( - out_id->extent()->isConst(), - "Values of -1 passed to reshape must be constant at definition.") - } - out_shape.at(i) = extent_int; - } - - auto view_result = analyzeView(inp_tv, inp_shape, out_shape); - - reshape_transforms_.emplace_back(tv_index, view_result); + !inp_id->maybePartial(), + "Invalid domain to reshape: ", + inp_id->toString()); + auto extent_val = expr_eval->evaluate(inp_id->extent()); + TORCH_INTERNAL_ASSERT( + extent_val.hasValue(), + "Cannot evaluate the extent of an input domain to reshape: ", + inp_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val.is(), + "Invalid evaluated value of domain extent: ", + inp_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val.as() > 0, + "Invalid input domain extent: ", + extent_val.as()); + inp_shape.at(i) = extent_val.as(); } -} -void DynamicTransformConcretizationInfo::analyzeResizes( - ExpressionEvaluator* expr_eval) { - const auto& resize_ids = initial_info_->getDynamicResizedIterDomains(); - for (const auto id_index : c10::irange(resize_ids.size())) { - auto out_id = resize_ids.at(id_index); - auto op = out_id->definition()->as(); - - TORCH_CHECK( - out_id->getIterType() == IterType::Symbolic, - "Found non-dynamic Resize in initial concretization info: ", - op->toString()); + const auto& out_dom = out_tv->getMaybeRFactorDomain(); + // Determine output shape using expr evaluator. Note there may be + // one domain of extent -1 + std::vector out_shape(out_dom.size(), 0); + for (const auto i : c10::irange(out_dom.size())) { + auto out_id = out_dom.at(i); auto extent_val = expr_eval->evaluate(out_id->extent()); TORCH_INTERNAL_ASSERT( extent_val.hasValue(), - "Cannot evaluate the extent of a resized domain: ", + "Cannot evaluate the extent of an output domain to reshape: ", out_id->toString()); TORCH_INTERNAL_ASSERT( extent_val.is(), - "Invalid evaluated value of resized domain extent: ", + "Invalid evaluated value of domain extent: ", out_id->toString()); auto extent_int = extent_val.as(); - TORCH_INTERNAL_ASSERT( - extent_int > 0, - "Invalid resized domain extent ", - extent_int, - " for domain ", - out_id->toString()); + if (extent_int == -1) { + // For non-constant Scalar sizes, check that we have not passed -1. + TORCH_CHECK( + out_id->extent()->isConst(), + "Values of -1 passed to reshape must be constant at definition.") + } + out_shape.at(i) = extent_int; + } - auto iter_type = - extent_int == 1 ? IterType::Broadcast : IterType::Iteration; + auto view_result = analyzeView(inp_tv, inp_shape, out_shape); - resize_itertypes_.emplace_back(id_index, iter_type); - } + concretization_descriptors_.emplace_back(view_result); +} + +void DynamicTransformConcretizationInfo::analyze( + Resize* op, + ExpressionEvaluator* expr_eval) { + auto out_id = op->out()->as(); + + TORCH_CHECK( + out_id->getIterType() == IterType::Symbolic, + "Found non-dynamic Resize in initial concretization info: ", + op->toString()); + + auto extent_val = expr_eval->evaluate(out_id->extent()); + TORCH_INTERNAL_ASSERT( + extent_val.hasValue(), + "Cannot evaluate the extent of a resized domain: ", + out_id->toString()); + TORCH_INTERNAL_ASSERT( + extent_val.is(), + "Invalid evaluated value of resized domain extent: ", + out_id->toString()); + auto extent_int = extent_val.as(); + TORCH_INTERNAL_ASSERT( + extent_int > 0, + "Invalid resized domain extent ", + extent_int, + " for domain ", + out_id->toString()); + + auto iter_type = extent_int == 1 ? IterType::Broadcast : IterType::Iteration; + + concretization_descriptors_.emplace_back(iter_type); } bool DynamicTransformConcretizationInfo::operator==( @@ -335,28 +330,15 @@ bool DynamicTransformConcretizationInfo::operator==( return true; } - if (reshape_transforms_.size() != other.reshape_transforms_.size() || - resize_itertypes_.size() != other.resize_itertypes_.size() || - empty_extents_.size() != other.empty_extents_.size()) { + if (concretization_descriptors_.size() != + other.concretization_descriptors_.size() || + !std::equal( + concretization_descriptors_.begin(), + concretization_descriptors_.end(), + other.concretization_descriptors_.begin())) { return false; } - for (const auto i : c10::irange(reshape_transforms_.size())) { - const auto& analysis = reshape_transforms_.at(i); - const auto& other_analysis = other.reshape_transforms_.at(i); - if (analysis != other_analysis) { - return false; - } - } - - for (const auto i : c10::irange(resize_itertypes_.size())) { - const auto& itertype = resize_itertypes_.at(i); - const auto& other_itertype = other.resize_itertypes_.at(i); - if (itertype != other_itertype) { - return false; - } - } - for (const auto i : c10::irange(empty_extents_.size())) { const auto& ee = empty_extents_.at(i); const auto& other_ee = other.empty_extents_.at(i); @@ -377,19 +359,16 @@ std::string DynamicTransformConcretizationInfo::toString() const { auto ext = initial_info_->getMaybeZeroExtents().at(i); ss << indent << indent << ext->toString() << " is zero\n"; } - ss << indent << "Reshape:\n"; - for (const auto& [tv_index, analyze_result] : reshape_transforms_) { - auto tv = initial_info_->getDynamicReshapedTensorViews().at(tv_index); - ss << indent << indent << tv->toString() << " (index=" << tv_index << "), " - << analyze_result.toString() << "\n"; - } - ss << indent << "Resize:\n"; - for (const auto& [id_index, iter_type] : resize_itertypes_) { - auto id = initial_info_->getDynamicResizedIterDomains().at(id_index); - ss << indent << indent << id->toString() << " (index=" << id_index << "), " - << iter_type << "\n"; + ss << indent << "Dynamic expression concretization descriptors:\n"; + for (const auto& desc_var : concretization_descriptors_) { + if (auto analyze_result_ptr = std::get_if(&desc_var)) { + ss << indent << indent + << "AnalyzeViewResult: " << analyze_result_ptr->toString() << "\n"; + } else if (auto iter_type_ptr = std::get_if(&desc_var)) { + ss << indent << indent << "IterType: " << (*iter_type_ptr) << "\n"; + } + return ss.str(); } - return ss.str(); } //! Concretize a symbolic fusion with concrete transformation info @@ -409,22 +388,23 @@ class DynamicTransformConcretizer : public OptOutMutator { private: void concretize(); - void concretizeReshape(); + void concretizeReshape(ViewOp* op, AnalyzeViewResult& view_analysis) const; - void concretizeResize(); + void concretizeResize(Resize* op, IterType iter_type) const; void concretizeEmptyExtents(); - //! Use this instead of calling registerMutation directly, since it will also - //! check that the concretized value is a valid input to all of its uses. + //! Use this instead of calling registerMutation directly, since it will + //! also check that the concretized value is a valid input to all of its + //! uses. void registerConcretization(Val* old_val, Val* new_val) { checkConcretizedUses(old_val, new_val); registerMutation(old_val, new_val); } //! Check uses of old_val to ensure that new_val does not violate - //! assumptions. This is currently only used to check that inputs to SqueezeOp - //! are marked broadcast during concretization. + //! assumptions. This is currently only used to check that inputs to + //! SqueezeOp are marked broadcast during concretization. void checkConcretizedUses(Val* old_val, Val* new_val) const; using OptOutMutator::mutate; @@ -442,13 +422,47 @@ class DynamicTransformConcretizer : public OptOutMutator { }; void DynamicTransformConcretizer::concretize() { - // Concretize all dynamic reshape ops - concretizeReshape(); - - // Set output IterTypes for dynamic resize ops - concretizeResize(); + // First, concretize dynamic Exprs. We do this in reverse topological order. + // + // To understand why we use reverse topo order, consider what happens when + // we concretize a dynamic reshape `auto tv3 = reshape(tv2, sh);`. To do + // this, we create a new TensorView, let's call it `conc_tv` using the + // particular AnalyzeViewResult in info_. Then, we replace `tv3` with + // `conc_tv` in all uses of `tv3` using `ir_utils::replaceValInExpr`. That + // function actually replaces those use Exprs with new ones and removes the + // old pointers from the Fusion. As long as we do not dereference any of + // these Exprs after this pass, and we do this pass in reverse topological + // order, we will not encounter Expr* pointers that have been invalidated + // due to `replaceValInExpr` like this. + for (int i = info_->getExprConcretizationDescriptors().size() - 1; i > 0; + --i) { + auto op = info_->initialInfo()->getDynamicExprs().at(i); + auto desc_var = info_->getExprConcretizationDescriptors().at(i); + if (auto vop = dynamic_cast(op)) { + if (auto analyze_result_ptr = std::get_if(desc_var)) { + concretizeReshape(op, *analyze_result_ptr); + } else { + TORCH_CHECK( + false, + "Dynamic ViewOp expects AnalyzeViewResult descriptor but found variant index ", + desc_var.index()); + } + } else if (auto rop = dynamic_cast(op)) { + if (auto iter_type_ptr = std::get_if(desc_var)) { + concretizeResize(op, *iter_type_ptr); + } else { + TORCH_CHECK( + false, + "Dynamic Resize expects IterType descriptor but found variant index ", + desc_var.index()); + } + } else { + TORCH_CHECK(false, "Unhandled dynamic Expr type: ", op->toString()); + } + } - // Registers replacement of all empty extents with zeroVal() + // Registers replacement of all empty extents with zeroVal(). This does not + // modify the Fusion. concretizeEmptyExtents(); // Finally, propagate concretized domains @@ -471,19 +485,21 @@ void DynamicTransformConcretizer::concretizeEmptyExtents() { for (auto use : uses) { ir_utils::replaceValInExpr(use, ext, zero); } - // Register the concretization of this scalar, which allows us to replace it - // whenever it is used as an extent member of an IterDomain. + // Register the concretization of this scalar, which allows us to replace + // it whenever it is used as an extent member of an IterDomain. // // When we ext in all uses above, it affects downstream expressions. For // example we might replace i0 with 0 in (i0 + i1) + i2 to form (0 + i1) + - // i2. However, i0 itself might be used as the extent, start, or stop values - // in an IterDomain, so we register the concretization here so that we can - // replace these values whenever we encounter them. + // i2. However, i0 itself might be used as the extent, start, or stop + // values in an IterDomain, so we register the concretization here so that + // we can replace these values whenever we encounter them. registerConcretization(ext, zero); } } -void DynamicTransformConcretizer::concretizeReshape() { +void DynamicTransformConcretizer::concretizeReshape( + ViewOp* view_op, + AnalyzeViewResult& view_analysis) const { // Concretize each reshape op. for (const auto& [tv_index, view_analysis] : info_->getReshapeTransforms()) { auto incomplete_out_tv = @@ -498,8 +514,8 @@ void DynamicTransformConcretizer::concretizeReshape() { checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv); // Extent expressions often change when concretizing a reshape. Here we - // replace these in all downstream expressions so that the Fusion looks just - // like it would have if we had used a static reshape instead. + // replace these in all downstream expressions so that the Fusion looks + // just like it would have if we had used a static reshape instead. auto old_rfactor = incomplete_out_tv->getMaybeRFactorDomain(); auto new_rfactor = concrete_reshape_out_tv->getMaybeRFactorDomain(); TORCH_INTERNAL_ASSERT( @@ -532,23 +548,19 @@ void DynamicTransformConcretizer::concretizeReshape() { } } -void DynamicTransformConcretizer::concretizeResize() { +void DynamicTransformConcretizer::concretizeResize( + Resize* op, + IterType iter_type) const { // Concretize each resize op. - for (const auto& [id_index, iter_type] : info_->getResizeIterTypes()) { - auto id = info_->initialInfo()->getDynamicResizedIterDomains().at(id_index); - TORCH_CHECK( - id->definition() && id->definition()->isA(), - "Resized IterDomain must have a Resize definition"); - auto def = id->definition()->as(); - auto new_id = IterDomain::resize( - def->in(), - def->leftExpand(), - def->rightExpand(), - id->isRFactorProduct(), - iter_type); - - registerConcretization(id, new_id); - } + auto out_id = op->out()->as(); + auto new_id = IterDomain::resize( + op->in(), + op->leftExpand(), + op->rightExpand(), + out_id->isRFactorProduct(), + iter_type); + + registerConcretization(out_id, new_id); } void DynamicTransformConcretizer::checkConcretizedUses( @@ -575,9 +587,9 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { propagateFromProducerToConsumer(tv); // If no root domain is altered by producer, we don't need to propagate back - // up to rfactor. We could return early, but instead we go ahead and check the - // root to rfactor transforms to be sure we have concretized any intermediate - // IterDomains. + // up to rfactor. We could return early, but instead we go ahead and check + // the root to rfactor transforms to be sure we have concretized any + // intermediate IterDomains. // At this point, there should be no expr beyond rfactor root TORCH_INTERNAL_ASSERT( @@ -614,16 +626,17 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { // padded with constant pad widths (1, 1), in which case although we do // not know the exact extent of the output, we know it is at least as // large as the sum of the pad widths, 2. In such cases, the output - // IterDomain is concrete at definition, since if the extent is >1 we know - // the IterType is Iteration. In these cases, we must continue to + // IterDomain is concrete at definition, since if the extent is >1 we + // know the IterType is Iteration. In these cases, we must continue to // concretize intermediate expressions between the root and R-factor // domain. See test DynamicTransform5_CUDA which demonstrates this // behavior. // NOTE: We also do not assume that if one output ID is symbolic, that // they all must be. See test FusionSliceForNanoGPT3_CUDA for an example - // that does a static split by a factor of 16 of a symbolic input domain. - // The static split in that case results in a concrete IterDomain with - // extent 16 along with a symbolic one (extent ceilDiv(n / 16)). + // that does a static split by a factor of 16 of a symbolic input + // domain. The static split in that case results in a concrete + // IterDomain with extent 16 along with a symbolic one (extent ceilDiv(n + // / 16)). // Determine the output IterType IterType iter_type = IterType::Symbolic; diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h index 0cb1af5f9d3..5ba701c5437 100644 --- a/csrc/dynamic_transform.h +++ b/csrc/dynamic_transform.h @@ -18,6 +18,7 @@ #include #include +#include #include namespace nvfuser { @@ -44,8 +45,7 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { //! given some user input. In either of these cases, concretization may change //! the structure of the Fusion. bool isDynamic() const { - return hasPossibleEmptyTensor() || !dynamic_reshaped_tvs_.empty() || - !dynamic_resized_ids_.empty(); + return hasPossibleEmptyTensor() || !dynamic_exprs_.empty(); } //! Return whether there are any tensors with unknown extent in some @@ -54,6 +54,10 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { return !maybe_zero_extents_.empty(); } + const std::vector& getDynamicExprs() const { + return dynamic_exprs_; + } + //! Return a set of scalars that are inputs or extents of input TensorViews //! and that appear in inputs to dynamic expressions. Any Vals not in this //! list do not affect concretization. @@ -68,18 +72,6 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { return maybe_zero_extents_; } - //! Return a vector of outputs of ViewOp expressions that have dynamic output - //! shapes - const std::vector& getDynamicReshapedTensorViews() const { - return dynamic_reshaped_tvs_; - } - - //! Return a vector of outputs of Resize expressions that have symbolic output - //! IterTypes - const std::vector& getDynamicResizedIterDomains() const { - return dynamic_resized_ids_; - } - std::string toString() const; DynamicTransformInitialInfo clone(IrCloner& ir_cloner) const; @@ -92,7 +84,7 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { } protected: - //! Holds the set of scalar fusion inputs that affect concretization. + //! Holds the set of scalar fusion input positions that affect concretization. std::unordered_set scalar_inputs_affecting_concretization_; private: @@ -101,23 +93,17 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo { private: Fusion* fusion_ = nullptr; - // We hold vectors of the _outputs_ of dynamic ops. The reason we don't hold - // the ops themselves is that during concretization, the ops will actually be - // removed by ir_utils::replaceValInExpr. The outputs will not: their - // definitions will merely be altered. When the ops are replaced, if we had - // referred to them directly here, we would run into segfaults. Referring only - // to the outputs avoids this issue. - std::vector dynamic_reshaped_tvs_; - - std::vector dynamic_resized_ids_; + // This is a vector of dynamic Exprs, in topological order. + std::vector dynamic_exprs_; // This is a minimal set of scalars to check for empty tensors. If any are // zero, we should traverse to find empty tensors. std::unordered_set maybe_zero_extents_set_; - // The set above is populated then used to create this unique vector + // The set above is populated then used to create this unique vector which is + // ordered arbitrarily. std::vector maybe_zero_extents_; - // Root Vals that determine concretization + // Minimal set of Vals that determine all aspects of concretization std::unordered_set root_dynamic_vals_; friend class DynamicTransformInitialInfoBuilder; @@ -135,20 +121,10 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { return empty_extents_; } - //! Return a vector of pairs holding the index of each reshaped TensorView in - //! the vector returned by initialInfo()->getDynamicReshapedTensorViews(), - //! along with an AnalyzeViewResult describing how that reshape operation - //! should be decomposed into split, merge, squeeze, and broadcast transforms. - const std::vector>& getReshapeTransforms() - const { - return reshape_transforms_; - } - - //! Return a vector of pairs holding the index of each resized IterDomain in - //! the vector returned by initialInfo()->getDynamicResizedIterDomains(), - //! along with the IterType it should be concretized to. - const std::vector>& getResizeIterTypes() const { - return resize_itertypes_; + //! Return a vector of descriptors describing how to concretize the dynamic + //! expressions in the initial info. + const auto& getExprConcretizationDescriptors() const { + return concretization_descriptors_; } //! Comparison operator for the purposes of determining cache hits. This does @@ -163,13 +139,13 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { } //! Given an ExpressionEvaluator which already has input scalars bound to it, - //! determine the decomposition of each dynamic reshape operation to use - //! during concretization. - void analyzeReshapes(ExpressionEvaluator* expr_eval); + //! determine the decomposition of a dynamic reshape operation to use during + //! concretization. + void analyze(ViewOp* vop, ExpressionEvaluator* expr_eval); //! Given an ExpressionEvaluator which already has input scalars bound to it, //! determine the concrete IterType of each resized IterDomain. - void analyzeResizes(ExpressionEvaluator* expr_eval); + void analyze(Resize* rop, ExpressionEvaluator* expr_eval); const DynamicTransformInitialInfo* initialInfo() const { return initial_info_; @@ -195,20 +171,23 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo { private: const DynamicTransformInitialInfo* initial_info_ = nullptr; - //! Holds the index of the output TensorView in the vector returned by - //! initial_info_->getDynamicReshapedTensorViews(), and the corresponding - //! result of analyzeView - std::vector> reshape_transforms_; + //! Holds data required to concretize an operation. Entries in this vector + //! correspond to the Exprs in initial_info_->getDynamicExprs(). + //! + //! Each type of Expr requires a different type of information at + //! concretization, so this holds a variant that should enumerate all data + //! types encountered. Each should be copyable and should not include pointers + //! to any Statements. + std::vector> + concretization_descriptors_; //! Holds a vector of indices into initial_info_.getMaybeZeroExtents() which //! evaluate to 0 std::vector empty_extents_; - //! Holds the index of the resized IterDomain (output of the Resize op) in the - //! vector returned by initial_info_->getDynamicResizedIterDomains() along - //! with its concretized IterType - std::vector> resize_itertypes_; - friend class DynamicTransformInfoBuilder; }; From 52fdaf52c170c15604edc7e4d9a843af1256466c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 26 Jul 2023 20:01:57 -0400 Subject: [PATCH 2/6] Compiles --- csrc/dynamic_transform.cpp | 112 +++++++++++++++----------------- test/test_dynamic_transform.cpp | 27 ++++++-- 2 files changed, 75 insertions(+), 64 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 27f55741132..9e9193a0639 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -31,7 +31,7 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( dynamic_exprs_.begin(), dynamic_exprs_.end(), std::back_inserter(cloned_info.dynamic_exprs_), - [&ir_cloner](TensorView* tv) { return ir_cloner.clone(tv); }); + [&ir_cloner](Expr* e) { return ir_cloner.clone(e); }); cloned_info.maybe_zero_extents_set_.reserve(maybe_zero_extents_set_.size()); std::transform( @@ -40,7 +40,7 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( std::inserter( cloned_info.maybe_zero_extents_set_, cloned_info.maybe_zero_extents_set_.begin()), - [&ir_cloner](TensorView* tv) { return ir_cloner.clone(tv); }); + [&ir_cloner](Val* v) { return ir_cloner.clone(v); }); cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size()); std::transform( @@ -49,7 +49,7 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( std::inserter( cloned_info.root_dynamic_vals_, cloned_info.root_dynamic_vals_.begin()), - [&ir_cloner](TensorView* tv) { return ir_cloner.clone(tv); }); + [&ir_cloner](Val* v) { return ir_cloner.clone(v); }); return cloned_info; } @@ -367,8 +367,8 @@ std::string DynamicTransformConcretizationInfo::toString() const { } else if (auto iter_type_ptr = std::get_if(&desc_var)) { ss << indent << indent << "IterType: " << (*iter_type_ptr) << "\n"; } - return ss.str(); } + return ss.str(); } //! Concretize a symbolic fusion with concrete transformation info @@ -388,9 +388,9 @@ class DynamicTransformConcretizer : public OptOutMutator { private: void concretize(); - void concretizeReshape(ViewOp* op, AnalyzeViewResult& view_analysis) const; + void concretizeReshape(ViewOp* op, AnalyzeViewResult& view_analysis); - void concretizeResize(Resize* op, IterType iter_type) const; + void concretizeResize(Resize* op, IterType iter_type); void concretizeEmptyExtents(); @@ -439,8 +439,8 @@ void DynamicTransformConcretizer::concretize() { auto op = info_->initialInfo()->getDynamicExprs().at(i); auto desc_var = info_->getExprConcretizationDescriptors().at(i); if (auto vop = dynamic_cast(op)) { - if (auto analyze_result_ptr = std::get_if(desc_var)) { - concretizeReshape(op, *analyze_result_ptr); + if (auto analyze_result_ptr = std::get_if(&desc_var)) { + concretizeReshape(vop, *analyze_result_ptr); } else { TORCH_CHECK( false, @@ -448,8 +448,8 @@ void DynamicTransformConcretizer::concretize() { desc_var.index()); } } else if (auto rop = dynamic_cast(op)) { - if (auto iter_type_ptr = std::get_if(desc_var)) { - concretizeResize(op, *iter_type_ptr); + if (auto iter_type_ptr = std::get_if(&desc_var)) { + concretizeResize(rop, *iter_type_ptr); } else { TORCH_CHECK( false, @@ -499,58 +499,53 @@ void DynamicTransformConcretizer::concretizeEmptyExtents() { void DynamicTransformConcretizer::concretizeReshape( ViewOp* view_op, - AnalyzeViewResult& view_analysis) const { - // Concretize each reshape op. - for (const auto& [tv_index, view_analysis] : info_->getReshapeTransforms()) { - auto incomplete_out_tv = - info_->initialInfo()->getDynamicReshapedTensorViews().at(tv_index); - auto view_op = incomplete_out_tv->definition()->as(); - auto inp_tv = view_op->in()->as(); - - auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis); - - // We do the replacement directly here, but we must still check that the - // replacement is valid - checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv); - - // Extent expressions often change when concretizing a reshape. Here we - // replace these in all downstream expressions so that the Fusion looks - // just like it would have if we had used a static reshape instead. - auto old_rfactor = incomplete_out_tv->getMaybeRFactorDomain(); - auto new_rfactor = concrete_reshape_out_tv->getMaybeRFactorDomain(); - TORCH_INTERNAL_ASSERT( - old_rfactor.size() == new_rfactor.size(), - "Concretized reshape rfactor size does not match symbolic rfactor"); - for (auto idx : c10::irange(new_rfactor.size())) { - auto old_extent = old_rfactor.at(idx)->extent(); - auto new_extent = new_rfactor.at(idx)->extent(); - // If the old extent did not have a definition, we don't need to replace - // it, since it will get bound whenever this tensor is a segmentation - // edge. - if (old_extent->definition() && !new_extent->sameAs(old_extent)) { - registerConcretization(old_extent, new_extent); - } - } + AnalyzeViewResult& view_analysis) { + auto incomplete_out_tv = view_op->out()->as(); + auto inp_tv = view_op->in()->as(); - // Replace the old tensor with the new concretized tensor - auto uses = incomplete_out_tv->uses(); - for (auto use_of_old_tv : uses) { - ir_utils::replaceValInExpr( - use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); - } + auto concrete_reshape_out_tv = reshape(inp_tv, view_analysis); + + // We do the replacement directly here, but we must still check that the + // replacement is valid + checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv); - if (incomplete_out_tv->isFusionOutput()) { - incomplete_out_tv->fusion()->replaceOutput( - incomplete_out_tv, concrete_reshape_out_tv); + // Extent expressions often change when concretizing a reshape. Here we + // replace these in all downstream expressions so that the Fusion looks + // just like it would have if we had used a static reshape instead. + auto old_rfactor = incomplete_out_tv->getMaybeRFactorDomain(); + auto new_rfactor = concrete_reshape_out_tv->getMaybeRFactorDomain(); + TORCH_INTERNAL_ASSERT( + old_rfactor.size() == new_rfactor.size(), + "Concretized reshape rfactor size does not match symbolic rfactor"); + for (auto idx : c10::irange(new_rfactor.size())) { + auto old_extent = old_rfactor.at(idx)->extent(); + auto new_extent = new_rfactor.at(idx)->extent(); + // If the old extent did not have a definition, we don't need to replace + // it, since it will get bound whenever this tensor is a segmentation + // edge. + if (old_extent->definition() && !new_extent->sameAs(old_extent)) { + registerConcretization(old_extent, new_extent); } + } + + // Replace the old tensor with the new concretized tensor + auto uses = incomplete_out_tv->uses(); + for (auto use_of_old_tv : uses) { + ir_utils::replaceValInExpr( + use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); + } - info_->fusion()->removeVal(incomplete_out_tv); + if (incomplete_out_tv->isFusionOutput()) { + incomplete_out_tv->fusion()->replaceOutput( + incomplete_out_tv, concrete_reshape_out_tv); } + + info_->fusion()->removeVal(incomplete_out_tv); } void DynamicTransformConcretizer::concretizeResize( Resize* op, - IterType iter_type) const { + IterType iter_type) { // Concretize each resize op. auto out_id = op->out()->as(); auto new_id = IterDomain::resize( @@ -815,11 +810,12 @@ void DynamicTransform::concretizeFusion( size_t DynamicTransformConcretizationInfo::hash() const { size_t hash = 0; - for (const auto& [tv, view_result] : getReshapeTransforms()) { - hashCombine(hash, view_result.hash()); - } - for (const auto& [id, iter_type] : getResizeIterTypes()) { - hashCombine(hash, (size_t)iter_type); + for (const auto& desc_var : getExprConcretizationDescriptors()) { + if (auto view_result_ptr = std::get_if(&desc_var)) { + hashCombine(hash, view_result_ptr->hash()); + } else if (auto iter_type_ptr = std::get_if(&desc_var)) { + hashCombine(hash, (size_t)(*iter_type_ptr)); + } } return hash; } diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index fcd3b59f9b9..cf80183544e 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -66,8 +66,15 @@ TEST_F(NVFuserTest, DynamicTransform1_CUDA) { auto initial_info = DynamicTransform::getInitialInfo(&fusion); auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); TORCH_CHECK( - info.getReshapeTransforms().size() == 1, + initial_info.getDynamicExprs().size() == 1 && + initial_info.getDynamicExprs().at(0)->isA(), "Expected to have one reshape transform: ", + initial_info.toString()); + TORCH_CHECK( + info.getExprConcretizationDescriptors().size() == 1 && + std::holds_alternative( + info.getExprConcretizationDescriptors().at(0)), + "Expected to have one AnalyzeViewResult in concretization info: ", info.toString()); } @@ -150,8 +157,15 @@ TEST_F(NVFuserTest, DynamicTransform2_CUDA) { auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval); TORCH_CHECK( - info.getReshapeTransforms().size() == 1, + initial_info.getDynamicExprs().size() == 1 && + initial_info.getDynamicExprs().at(0)->isA(), "Expected to have one reshape transform: ", + initial_info.toString()); + TORCH_CHECK( + info.getExprConcretizationDescriptors().size() == 1 && + std::holds_alternative( + info.getExprConcretizationDescriptors().at(0)), + "Expected to have one AnalyzeViewResult in concretization info: ", info.toString()); } } @@ -541,10 +555,11 @@ TEST_F(NVFuserTest, DynamicTransform9_CUDA) { // There must be only one dynamic reshape entry, and that must be // for tv2. TORCH_CHECK( - info.getReshapeTransforms().size() == 1, - info.getReshapeTransforms().at(0).first == 0, // first and only reshape - "Unexpected dynamic transform info:", - info.toString()); + initial_info.getDynamicExprs().size() == 1 && + initial_info.getDynamicExprs().at(0)->isA() && + initial_info.getDynamicExprs().at(0) == tv2->definition(), + "Expected to have one reshape transform corresponding to tv2: ", + initial_info.toString()); } // Make sure inherited symbolic IDs are concretized through rfactor exprs From 2eba93ef9cbf1278a8680607146de093590354c3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 26 Jul 2023 20:46:12 -0400 Subject: [PATCH 3/6] Fix new bugs --- csrc/dynamic_transform.cpp | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 9e9193a0639..f0af5f06472 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -42,6 +42,15 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( cloned_info.maybe_zero_extents_set_.begin()), [&ir_cloner](Val* v) { return ir_cloner.clone(v); }); + cloned_info.maybe_zero_extents_.reserve(maybe_zero_extents_.size()); + std::transform( + maybe_zero_extents_.begin(), + maybe_zero_extents_.end(), + std::inserter( + cloned_info.maybe_zero_extents_, + cloned_info.maybe_zero_extents_.begin()), + [&ir_cloner](Val* v) { return ir_cloner.clone(v); }); + cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size()); std::transform( root_dynamic_vals_.begin(), @@ -123,10 +132,10 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { void handle(TensorView* tv) override { const auto& rfd = tv->getMaybeRFactorDomain(); for (auto id : rfd) { - if (!id->getMaybeExpandedExtent()->isConstScalar() || - id->getMaybeExpandedExtent()->evaluateInt() == 0) { - info_.maybe_zero_extents_set_.insert(id->getMaybeExpandedExtent()); - leaf_dynamic_vals_.push_back(id->getMaybeExpandedExtent()); + auto ext = id->getMaybeExpandedExtent(); + if (!ext->isConstScalar() || ext->evaluateInt() == 0) { + info_.maybe_zero_extents_set_.insert(ext); + leaf_dynamic_vals_.push_back(ext); } if (!id->definition() || id->getIterType() != IterType::Symbolic) { continue; @@ -332,21 +341,18 @@ bool DynamicTransformConcretizationInfo::operator==( if (concretization_descriptors_.size() != other.concretization_descriptors_.size() || + empty_extents_.size() != other.empty_extents_.size() || !std::equal( concretization_descriptors_.begin(), concretization_descriptors_.end(), - other.concretization_descriptors_.begin())) { + other.concretization_descriptors_.begin()) || + !std::equal( + empty_extents_.begin(), + empty_extents_.end(), + other.empty_extents_.begin())) { return false; } - for (const auto i : c10::irange(empty_extents_.size())) { - const auto& ee = empty_extents_.at(i); - const auto& other_ee = other.empty_extents_.at(i); - if (ee != other_ee) { - return false; - } - } - return true; } @@ -434,7 +440,7 @@ void DynamicTransformConcretizer::concretize() { // these Exprs after this pass, and we do this pass in reverse topological // order, we will not encounter Expr* pointers that have been invalidated // due to `replaceValInExpr` like this. - for (int i = info_->getExprConcretizationDescriptors().size() - 1; i > 0; + for (int i = info_->getExprConcretizationDescriptors().size() - 1; i >= 0; --i) { auto op = info_->initialInfo()->getDynamicExprs().at(i); auto desc_var = info_->getExprConcretizationDescriptors().at(i); From d954caa33714cb5e9ede3b2d74d2c38873802453 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 12:12:14 -0400 Subject: [PATCH 4/6] Silence clang-tidy --- csrc/dynamic_transform.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index f0af5f06472..d4f22d54f44 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -440,9 +440,11 @@ void DynamicTransformConcretizer::concretize() { // these Exprs after this pass, and we do this pass in reverse topological // order, we will not encounter Expr* pointers that have been invalidated // due to `replaceValInExpr` like this. - for (int i = info_->getExprConcretizationDescriptors().size() - 1; i >= 0; + for (int i = (int)info_->getExprConcretizationDescriptors().size() - 1; + i >= 0; --i) { auto op = info_->initialInfo()->getDynamicExprs().at(i); + TORCH_CHECK(op != nullptr, "Dynamic expression must not be nullptr"); auto desc_var = info_->getExprConcretizationDescriptors().at(i); if (auto vop = dynamic_cast(op)) { if (auto analyze_result_ptr = std::get_if(&desc_var)) { From 7211deb70f086280bf037620587aa4d9062d3342 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 22 Aug 2023 12:25:22 -0400 Subject: [PATCH 5/6] Add comment on why process IDs in handle(TV) --- csrc/dynamic_transform.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index d4f22d54f44..c73055ad623 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -129,6 +129,24 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { } //! Detect possibly empty TensorViews and dynamic IterDomain transforms + //! + //! NOTE: we detect IterDomain transforms this way instead of e.g. + //! handle(Resize*) to guarantee that the dynamic_exprs_ vector will be + //! topologically sorted _with respect to TensorView expressions_. To see why + //! this is necessary, consider this example: + //! + //! auto tv1 = reshape(tv0, dyn_shape); + //! auto tv2 = pad(tv1, {s0, s1}); + //! + //! We need to process the reshape op before processing the Resize of the last + //! axis of tv2, since we might define a placeholder extent on the last axis + //! of tv1 which will need to be bound before we handle the Resize. However, + //! IterDomains do not have TensorViews as producers, so a valid topological + //! ordering could handle _all_ IterDomains in the Fusion before processing + //! the TensorViews. By processing IterDomains only when we visit a TensorView + //! and not earlier, we guarantee that we have already processed the + //! TensorView definition, which means we've had a chance to bind placeholder + //! extents if needed. void handle(TensorView* tv) override { const auto& rfd = tv->getMaybeRFactorDomain(); for (auto id : rfd) { From 2f483e3b287cac6667e6cfb282cf3a10c53b9ace Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 6 Sep 2023 10:10:02 -0400 Subject: [PATCH 6/6] Small fixes --- csrc/dynamic_transform.cpp | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index 8b2d66ebb61..0dbaf1274f8 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -225,7 +225,7 @@ DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo( } else if (auto rop = dynamic_cast(expr)) { analyze(rop, expr_eval); } else { - TORCH_CHECK(false, "Unhandled dynamic Expr type: ", expr->toString()); + NVF_CHECK(false, "Unhandled dynamic Expr type: ", expr->toString()); } } @@ -254,7 +254,7 @@ void DynamicTransformConcretizationInfo::analyze( return; } - TORCH_INTERNAL_ASSERT( + NVF_ERROR( out_tv->hasRFactor(), "Unexpected output tv of ViewOp: ", out_tv->toString()); @@ -268,20 +268,20 @@ void DynamicTransformConcretizationInfo::analyze( auto inp_id = inp_dom.at(i); // This should have been validated when initially creating reshape // op, but just in case - TORCH_INTERNAL_ASSERT( + NVF_ERROR( !inp_id->maybePartial(), "Invalid domain to reshape: ", inp_id->toString()); auto extent_val = expr_eval->evaluate(inp_id->extent()); - TORCH_INTERNAL_ASSERT( + NVF_ERROR( extent_val.hasValue(), "Cannot evaluate the extent of an input domain to reshape: ", inp_id->toString()); - TORCH_INTERNAL_ASSERT( + NVF_ERROR( extent_val.is(), "Invalid evaluated value of domain extent: ", inp_id->toString()); - TORCH_INTERNAL_ASSERT( + NVF_ERROR( extent_val.as() > 0, "Invalid input domain extent: ", extent_val.as()); @@ -324,22 +324,22 @@ void DynamicTransformConcretizationInfo::analyze( ExpressionEvaluator* expr_eval) { auto out_id = op->out()->as(); - TORCH_CHECK( + NVF_CHECK( out_id->getIterType() == IterType::Symbolic, "Found non-dynamic Resize in initial concretization info: ", op->toString()); auto extent_val = expr_eval->evaluate(out_id->extent()); - TORCH_INTERNAL_ASSERT( + NVF_ERROR( extent_val.hasValue(), "Cannot evaluate the extent of a resized domain: ", out_id->toString()); - TORCH_INTERNAL_ASSERT( + NVF_ERROR( extent_val.is(), "Invalid evaluated value of resized domain extent: ", out_id->toString()); auto extent_int = extent_val.as(); - TORCH_INTERNAL_ASSERT( + NVF_ERROR( extent_int > 0, "Invalid resized domain extent ", extent_int, @@ -462,13 +462,13 @@ void DynamicTransformConcretizer::concretize() { i >= 0; --i) { auto op = info_->initialInfo()->getDynamicExprs().at(i); - TORCH_CHECK(op != nullptr, "Dynamic expression must not be nullptr"); + NVF_CHECK(op != nullptr, "Dynamic expression must not be nullptr"); auto desc_var = info_->getExprConcretizationDescriptors().at(i); if (auto vop = dynamic_cast(op)) { if (auto analyze_result_ptr = std::get_if(&desc_var)) { concretizeReshape(vop, *analyze_result_ptr); } else { - TORCH_CHECK( + NVF_CHECK( false, "Dynamic ViewOp expects AnalyzeViewResult descriptor but found variant index ", desc_var.index()); @@ -477,13 +477,13 @@ void DynamicTransformConcretizer::concretize() { if (auto iter_type_ptr = std::get_if(&desc_var)) { concretizeResize(rop, *iter_type_ptr); } else { - TORCH_CHECK( + NVF_CHECK( false, "Dynamic Resize expects IterType descriptor but found variant index ", desc_var.index()); } } else { - TORCH_CHECK(false, "Unhandled dynamic Expr type: ", op->toString()); + NVF_CHECK(false, "Unhandled dynamic Expr type: ", op->toString()); } } @@ -551,7 +551,6 @@ void DynamicTransformConcretizer::concretizeReshape( // edge. if (old_extent->definition() && !new_extent->sameAs(old_extent)) { registerConcretization(old_extent, new_extent); - } } } @@ -868,3 +867,4 @@ size_t DynamicTransformConcretizationInfo::hash() const { } } // namespace nvfuser +