From 7400b832efd146407e1dc276d32bb6357c40f542 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Thu, 29 Aug 2024 08:50:24 +0800 Subject: [PATCH] support strided_slice --- src/contrib/msc/core/codegen/base_codegen.h | 6 +- src/contrib/msc/core/ir/graph_builder.cc | 13 +++- .../msc/core/transform/bind_named_params.cc | 2 +- src/contrib/msc/core/utils.cc | 67 ++++++++++++++++++- src/contrib/msc/core/utils.h | 54 +++++++++++++-- .../contrib/test_msc/test_graph_build.py | 3 - .../contrib/test_msc/test_translate_relax.py | 4 -- .../test_msc/test_translate_tensorflow.py | 4 -- .../contrib/test_msc/test_translate_torch.py | 3 - 9 files changed, 128 insertions(+), 28 deletions(-) diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index 19d8b524b9e2..acaac896a153 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -179,17 +179,17 @@ class BaseCodeGen { return 1; } if (node->scope.size() == scopes_.top().size()) { - ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top())) + ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top())) << "Scope mismatch, node " << node->scope << " compare to current " << scopes_.top(); return 0; } else if (node->scope.size() == scopes_.top().size() + 1) { - ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size())) + ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size())) << "Scope increase mismatch, node " << node->scope << " compare to current " << scopes_.top(); scopes_.push(node->scope); return 1; } else if (node->scope.size() == scopes_.top().size() - 1) { - ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size())) + ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size())) << "Scope decrease mismatch, node " << node->scope << " compare to current " << scopes_.top(); scopes_.pop(); diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index d35a462579d9..a968df4204a2 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -23,6 +23,7 @@ #include "graph_builder.h" +#include #include namespace tvm { @@ -71,6 +72,13 @@ void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) { for (const auto& arg : op->args) { if (const auto* s_node = arg.as()) { values_.push_back(StringUtils::ToString(s_node->value)); + } else if (const auto* s_node = arg.as()) { + bool all_values = + std::all_of(s_node->fields.begin(), s_node->fields.end(), + [](const relax::Expr& e) { return e->IsInstance(); }); + if (all_values) { + values_.push_back(StringUtils::ToString(s_node->fields)); + } } } } @@ -337,6 +345,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype << " should has special type, get " << input_types; attrs.Set(input_types[i], StringUtils::ToString(s_node->value)); + } else if (input_types[i] != "input" && arg->IsInstance()) { + attrs.Set(input_types[i], StringUtils::ToString(arg)); } } for (size_t i = call->args.size(); i < input_types.size(); i++) { @@ -371,7 +381,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional Array arg_names; if (expr_tensor_map_.count(arg)) { arg_names = expr_tensor_map_[arg]; - } else if (const auto* tuple_node = arg.as()) { + } else if (input_types[i] == "input" && arg->IsInstance()) { + const auto* tuple_node = arg.as(); for (const auto& f : tuple_node->fields) { ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; for (const auto& in_name : expr_tensor_map_[f]) { diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 5ba1ca30eb1c..6256fae05f83 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -84,7 +84,7 @@ std::tuple, Map> NormalizeNamedBindings( if (auto opt = obj.as()) { return opt.value(); } else if (auto opt = obj.as()) { - const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, key->name_hint()); + const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint()); return Constant(opt.value(), StructInfo(), span); } else { LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey() diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index 5fcbe924ae1c..c6e74d42843d 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -280,6 +280,8 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { } } else if (const auto* n = obj.as()) { obj_string = ToString(n->value); + } else if (const auto* n = obj.as()) { + obj_string = ToString(n->fields); } else { std::ostringstream obj_des; obj_des << obj; @@ -288,7 +290,7 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { return obj_string; } -bool StringUtils::CompareArrays(const Array& left, const Array& right, int size) { +bool ArrayUtils::CompareArrays(const Array& left, const Array& right, int size) { if (left.size() == right.size() && left.size() == 0) { return true; } @@ -311,6 +313,37 @@ bool StringUtils::CompareArrays(const Array& left, const Array& return true; } +PrimExpr ArrayUtils::Accumulate(const Array& array, int pos) { + size_t t_pos = pos < 0 ? array.size() + pos + 1 : pos; + PrimExpr accumulate = Integer(1); + for (size_t i = 0; i < t_pos; i++) { + accumulate = accumulate * array[i]; + } + return accumulate; +} + +bool ArrayUtils::Broadcastable(const Array& lhs, const Array& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); i++) { + const auto& lp = lhs[i]; + const auto& rp = rhs[i]; + if (lp->IsInstance() && rp->IsInstance()) { + continue; + } + if (lp->IsInstance() && rp->IsInstance() && + Downcast(lp)->value == Downcast(rp)->value) { + continue; + } + if (lp->IsInstance() && Downcast(lp)->value == 1) { + continue; + } + return false; + } + return true; +} + const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& value) { if (value.size() == 0) { return span; @@ -353,6 +386,10 @@ const Map SpanUtils::GetAttrs(const Span& span) { return attrs; } +const Span SpanUtils::CreateWithAttr(const String& key, const String& value) { + return SetAttr(Span(), key, value); +} + const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs_num, bool as_relax) { Array input_types; @@ -370,6 +407,14 @@ const Array ExprUtils::GetInputTypes(const String& optype, size_t inputs } else if (optype == "full" && as_relax) { input_types.push_back("shape"); input_types.push_back("input"); + } else if (optype == "strided_slice") { + input_types.push_back("input"); + if (inputs_num > 1) { + input_types.push_back("axes"); + input_types.push_back("begin"); + input_types.push_back("end"); + input_types.push_back("strides"); + } } else if (optype == "triu") { input_types.push_back("input"); input_types.push_back("k"); @@ -454,13 +499,31 @@ const Array ExprUtils::GetInputTypes(const RelayCall& call) { return GetInputTypes(optype, call->args.size(), false); } +const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) { + const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName); + if (suffix.size() > 0) { + return name + "_" + suffix; + } + return name; +} + +const Array ExprUtils::GetShape(const Expr& expr) { + const auto& shape_opt = Downcast(relax::GetStructInfo(expr))->GetShape(); + ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr; + return shape_opt.value(); +} + +const DataType ExprUtils::GetDataType(const Expr& expr) { + return Downcast(relax::GetStructInfo(expr))->dtype; +} + TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") .set_body_typed([](const String& key, const String& value) -> Span { - return SpanUtils::SetAttr(Span(), key, value); + return SpanUtils::CreateWithAttr(key, value); }); TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr") diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 6c39a8d0a16a..d7758cc23d8b 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -175,13 +176,6 @@ class StringUtils { * \return The String. */ TVM_DLL static const String ToString(const runtime::ObjectRef& obj); - - /*! - * \brief Compare String arrays. - * \return Whether two array are same. - */ - TVM_DLL static bool CompareArrays(const Array& left, const Array& right, - int size = -1); }; /*! @@ -238,6 +232,10 @@ class ArrayUtils { return new_array; } + /*! + * \brief Product elements in the arrays. + * \return The producted array + */ template TVM_DLL static const Array> Product(const Array>& arrays) { Array> p_arrays; @@ -260,6 +258,24 @@ class ArrayUtils { } return p_arrays; } + + /*! + * \brief Compare String arrays. + * \return Whether two array are same. + */ + TVM_DLL static bool CompareArrays(const Array& left, const Array& right, + int size = -1); + /*! + * \brief Accumulate array. + * \return The accumulate result + */ + TVM_DLL static PrimExpr Accumulate(const Array& array, int pos = -1); + + /*! + * \brief Check if lhs array is broadcastable to rhs. + * \return broadcastable + */ + TVM_DLL static bool Broadcastable(const Array& lhs, const Array& rhs); }; /*! @@ -284,6 +300,12 @@ class SpanUtils { * \return The Attrs Map. */ TVM_DLL static const Map GetAttrs(const Span& span); + + /*! + * \brief Create a span with value. + * \return The created Span. + */ + TVM_DLL static const Span CreateWithAttr(const String& key, const String& value); }; /*! @@ -365,6 +387,24 @@ class ExprUtils { TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i = 0) { return GetScalar(constant->data, i); } + + /*! + * \brief Get name in span. + * \return The name. + */ + TVM_DLL static const String GetSpanName(const Expr& expr, const String& suffix = ""); + + /*! + * \brief Get shape of expr. + * \return The shape. + */ + TVM_DLL static const Array GetShape(const Expr& expr); + + /*! + * \brief Get dtype of expr. + * \return The shape. + */ + TVM_DLL static const DataType GetDataType(const Expr& expr); }; } // namespace msc diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 069ffff53bd7..d02767208206 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -17,8 +17,6 @@ """ Test graph builder && graph. """ -import pytest - import torch from torch import fx from torch.nn import Module @@ -1101,7 +1099,6 @@ def forward(self, data): verify_model(GetAttr1(), input_info, expected) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_getitem(): """test graph builder for getitem""" diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index e8b7149a68a2..66aa90a625ea 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -17,8 +17,6 @@ """ Test translate from relax. """ -import pytest - import torch from torch import fx from torch.nn import Module @@ -57,7 +55,6 @@ def _run_relax(relax_mod): relax_exec = tvm.relax.build(relax_mod, target) vm_runner = tvm.relax.VirtualMachine(relax_exec, dev) res = vm_runner["main"](*args) - return _tvm_runtime_to_np(res) rt_mod = tvm_codegen.to_relax( @@ -629,7 +626,6 @@ def forward(self, data): _verify_model(GetAttr1(), input_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_getitem(): """test relax translator for getitem""" diff --git a/tests/python/contrib/test_msc/test_translate_tensorflow.py b/tests/python/contrib/test_msc/test_translate_tensorflow.py index 61f8ce1a973c..cb4ea3c02e4b 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorflow.py +++ b/tests/python/contrib/test_msc/test_translate_tensorflow.py @@ -18,8 +18,6 @@ """ Test translate from tensorflow. """ -import pytest - from packaging import version as package_version import numpy as np @@ -504,7 +502,6 @@ def _test_stridedslice( verify_model(graph_def, golden, **io_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_stridedslice(): """test tensorflow translator for stridedslice""" @@ -1065,7 +1062,6 @@ def _test_slice_operation_input(input_value, begin_value, size_value): verify_model(graph_def, golden, **io_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_slice(): """test tensorflow translator for slice""" diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py index 60dcbb293a51..f3e01493d96a 100644 --- a/tests/python/contrib/test_msc/test_translate_torch.py +++ b/tests/python/contrib/test_msc/test_translate_torch.py @@ -17,8 +17,6 @@ """ Test translate from torch. """ -import pytest - import numpy as np import torch @@ -589,7 +587,6 @@ def forward(self, data): verify_model(GetAttr1(), input_info) -@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue") def test_getitem(): """test torch translator for getitem"""