6 changes: 3 additions & 3 deletions src/contrib/msc/core/codegen/base_codegen.h
@@ -179,17 +179,17 @@ class BaseCodeGen {
return 1;
}
if (node->scope.size() == scopes_.top().size()) {
ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top()))
ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top()))
<< "Scope mismatch, node " << node->scope << " compare to current " << scopes_.top();
return 0;
} else if (node->scope.size() == scopes_.top().size() + 1) {
ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size()))
ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size()))
<< "Scope increase mismatch, node " << node->scope << " compare to current "
<< scopes_.top();
scopes_.push(node->scope);
return 1;
} else if (node->scope.size() == scopes_.top().size() - 1) {
ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size()))
ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size()))
<< "Scope decrease mismatch, node " << node->scope << " compare to current "
<< scopes_.top();
scopes_.pop();
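A minimal usage sketch for the relocated comparison helper that the scope checks above rely on (not part of the PR; the tvm::contrib::msc namespace is assumed from the file layout, and the optional size argument appears to limit the comparison to the first size entries, matching how the scope checks call it):

// Hedged sketch: prefix comparison of scope arrays with the relocated helper.
using tvm::runtime::Array;
using tvm::runtime::String;
using tvm::contrib::msc::ArrayUtils;  // namespace assumed from src/contrib/msc

Array<String> outer{"block", "loop"};
Array<String> inner{"block", "loop", "body"};
bool same_scope = ArrayUtils::CompareArrays(outer, outer);    // expected: true
bool is_prefix = ArrayUtils::CompareArrays(inner, outer, 2);  // expected: true, only the first 2 entries are compared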
13 changes: 12 additions & 1 deletion src/contrib/msc/core/ir/graph_builder.cc
@@ -23,6 +23,7 @@

#include "graph_builder.h"

#include <algorithm>
#include <set>

namespace tvm {
@@ -71,6 +72,13 @@ void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) {
for (const auto& arg : op->args) {
if (const auto* s_node = arg.as<relax::PrimValueNode>()) {
values_.push_back(StringUtils::ToString(s_node->value));
} else if (const auto* s_node = arg.as<relax::TupleNode>()) {
bool all_values =
std::all_of(s_node->fields.begin(), s_node->fields.end(),
[](const relax::Expr& e) { return e->IsInstance<relax::PrimValueNode>(); });
if (all_values) {
values_.push_back(StringUtils::ToString(s_node->fields));
}
}
}
}
@@ -337,6 +345,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional<Expr>
ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype
<< " should has special type, get " << input_types;
attrs.Set(input_types[i], StringUtils::ToString(s_node->value));
} else if (input_types[i] != "input" && arg->IsInstance<relax::TupleNode>()) {
attrs.Set(input_types[i], StringUtils::ToString(arg));
}
}
for (size_t i = call->args.size(); i < input_types.size(); i++) {
@@ -371,7 +381,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional<Expr>
Array<String> arg_names;
if (expr_tensor_map_.count(arg)) {
arg_names = expr_tensor_map_[arg];
} else if (const auto* tuple_node = arg.as<relax::TupleNode>()) {
} else if (input_types[i] == "input" && arg->IsInstance<relax::TupleNode>()) {
const auto* tuple_node = arg.as<relax::TupleNode>();
for (const auto& f : tuple_node->fields) {
ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
for (const auto& in_name : expr_tensor_map_[f]) {
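The builder changes above special-case relax Tuples whose fields are all PrimValues (for example the axes/begin/end/strides arguments of strided_slice), recording them as string attributes instead of trying to resolve them as tensor inputs. A standalone, hedged restatement of the detection test (the IsPrimValueTuple helper is hypothetical, not part of the PR):

#include <algorithm>
#include <tvm/relax/expr.h>

// Returns true when every field of the tuple is a PrimValue, mirroring the
// std::all_of check added to RelaxFuncValueGetter::VisitExpr_ above.
bool IsPrimValueTuple(const tvm::relax::Tuple& tuple) {
  return std::all_of(tuple->fields.begin(), tuple->fields.end(),
                     [](const tvm::relax::Expr& e) {
                       return e->IsInstance<tvm::relax::PrimValueNode>();
                     });
}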
2 changes: 1 addition & 1 deletion src/contrib/msc/core/transform/bind_named_params.cc
@@ -84,7 +84,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeNamedBindings(
if (auto opt = obj.as<relax::Expr>()) {
return opt.value();
} else if (auto opt = obj.as<runtime::NDArray>()) {
const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, key->name_hint());
const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint());
return Constant(opt.value(), StructInfo(), span);
} else {
LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
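SpanUtils::CreateWithAttr, introduced by this PR (see utils.cc below), is shorthand for attaching a single attribute to a fresh Span. A minimal sketch of the equivalence with the call it replaces, namespaces assumed from the file layout:

using namespace tvm::contrib::msc;  // namespace assumed from src/contrib/msc

const tvm::Span a = SpanUtils::CreateWithAttr(msc_attr::kName, "weight");
const tvm::Span b = SpanUtils::SetAttr(tvm::Span(), msc_attr::kName, "weight");
// a and b carry the same name attribute; CreateWithAttr only removes the empty Span() argument.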
67 changes: 65 additions & 2 deletions src/contrib/msc/core/utils.cc
@@ -280,6 +280,8 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) {
}
} else if (const auto* n = obj.as<relax::PrimValueNode>()) {
obj_string = ToString(n->value);
} else if (const auto* n = obj.as<relax::TupleNode>()) {
obj_string = ToString(n->fields);
} else {
std::ostringstream obj_des;
obj_des << obj;
@@ -288,7 +290,7 @@
return obj_string;
}

bool StringUtils::CompareArrays(const Array<String>& left, const Array<String>& right, int size) {
bool ArrayUtils::CompareArrays(const Array<String>& left, const Array<String>& right, int size) {
if (left.size() == right.size() && left.size() == 0) {
return true;
}
@@ -311,6 +313,37 @@ bool StringUtils::CompareArrays(const Array<String>& left, const Array<String>&
return true;
}

PrimExpr ArrayUtils::Accumulate(const Array<PrimExpr>& array, int pos) {
size_t t_pos = pos < 0 ? array.size() + pos + 1 : pos;
PrimExpr accumulate = Integer(1);
for (size_t i = 0; i < t_pos; i++) {
accumulate = accumulate * array[i];
}
return accumulate;
}

bool ArrayUtils::Broadcastable(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (size_t i = 0; i < lhs.size(); i++) {
const auto& lp = lhs[i];
const auto& rp = rhs[i];
if (lp->IsInstance<tvm::tir::VarNode>() && rp->IsInstance<tvm::tir::VarNode>()) {
continue;
}
if (lp->IsInstance<IntImmNode>() && rp->IsInstance<IntImmNode>() &&
Downcast<Integer>(lp)->value == Downcast<Integer>(rp)->value) {
continue;
}
if (lp->IsInstance<IntImmNode>() && Downcast<Integer>(lp)->value == 1) {
continue;
}
return false;
}
return true;
}

const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& value) {
if (value.size() == 0) {
return span;
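A usage sketch for the two helpers defined above (not part of the PR; the tvm::contrib::msc namespace is assumed from the file layout). Accumulate multiplies the first pos entries, or all of them when pos is -1; Broadcastable accepts equal constant dims, pairs of symbolic dims, or a constant 1 on the left:

using tvm::Integer;
using tvm::PrimExpr;
using tvm::runtime::Array;
namespace msc = tvm::contrib::msc;  // namespace assumed from src/contrib/msc

Array<PrimExpr> shape{Integer(2), Integer(3), Integer(4)};
PrimExpr total = msc::ArrayUtils::Accumulate(shape);      // pos = -1 -> 2 * 3 * 4
PrimExpr lead = msc::ArrayUtils::Accumulate(shape, 2);    // first two dims -> 2 * 3

Array<PrimExpr> lhs{Integer(1), Integer(3)};
Array<PrimExpr> rhs{Integer(5), Integer(3)};
bool ok = msc::ArrayUtils::Broadcastable(lhs, rhs);       // true: 1 broadcasts, 3 == 3
bool bad = msc::ArrayUtils::Broadcastable(rhs, lhs);      // false: 5 cannot broadcast to 1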
@@ -353,6 +386,10 @@ const Map<String, String> SpanUtils::GetAttrs(const Span& span) {
return attrs;
}

const Span SpanUtils::CreateWithAttr(const String& key, const String& value) {
return SetAttr(Span(), key, value);
}

const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t inputs_num,
bool as_relax) {
Array<String> input_types;
@@ -370,6 +407,14 @@ const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t inputs
} else if (optype == "full" && as_relax) {
input_types.push_back("shape");
input_types.push_back("input");
} else if (optype == "strided_slice") {
input_types.push_back("input");
if (inputs_num > 1) {
input_types.push_back("axes");
input_types.push_back("begin");
input_types.push_back("end");
input_types.push_back("strides");
}
} else if (optype == "triu") {
input_types.push_back("input");
input_types.push_back("k");
@@ -454,13 +499,31 @@ const Array<String> ExprUtils::GetInputTypes(const RelayCall& call) {
return GetInputTypes(optype, call->args.size(), false);
}

const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) {
const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName);
if (suffix.size() > 0) {
return name + "_" + suffix;
}
return name;
}

const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) {
const auto& shape_opt = Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape();
ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr;
return shape_opt.value();
}

const DataType ExprUtils::GetDataType(const Expr& expr) {
return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype;
}

TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr);

TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs);

TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr")
.set_body_typed([](const String& key, const String& value) -> Span {
return SpanUtils::SetAttr(Span(), key, value);
return SpanUtils::CreateWithAttr(key, value);
});

TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr")
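The new ExprUtils accessors read the tensor struct info and the span name of a relax expression. A hedged sketch using a hand-built Var; the relax::Var and relax::TensorStructInfo constructor calls are assumptions about the relax API, not something this PR adds:

using namespace tvm;
namespace msc = contrib::msc;  // namespace assumed from src/contrib/msc

relax::Var x("x", relax::TensorStructInfo(relax::ShapeExpr({2, 3}), DataType::Float(32)));
Array<PrimExpr> shape = msc::ExprUtils::GetShape(x);   // {2, 3}
DataType dtype = msc::ExprUtils::GetDataType(x);       // float32
String name = msc::ExprUtils::GetSpanName(x, "bias");  // span name + "_bias"; empty name if the span has no kName attr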
54 changes: 47 additions & 7 deletions src/contrib/msc/core/utils.h
@@ -26,6 +26,7 @@

#include <tvm/ir/source_map.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relay/expr.h>

#include <string>
@@ -175,13 +176,6 @@ class StringUtils {
* \return The String.
*/
TVM_DLL static const String ToString(const runtime::ObjectRef& obj);

/*!
* \brief Compare String arrays.
* \return Whether two array are same.
*/
TVM_DLL static bool CompareArrays(const Array<String>& left, const Array<String>& right,
int size = -1);
};

/*!
@@ -238,6 +232,10 @@ class ArrayUtils {
return new_array;
}

/*!
 * \brief Compute the product of the arrays.
 * \return The product array.
*/
template <typename T>
TVM_DLL static const Array<Array<T>> Product(const Array<Array<T>>& arrays) {
Array<Array<T>> p_arrays;
@@ -260,6 +258,24 @@
}
return p_arrays;
}

/*!
* \brief Compare String arrays.
* \return Whether two array are same.
*/
TVM_DLL static bool CompareArrays(const Array<String>& left, const Array<String>& right,
int size = -1);
/*!
* \brief Accumulate array.
 * \return The accumulated result.
*/
TVM_DLL static PrimExpr Accumulate(const Array<PrimExpr>& array, int pos = -1);

/*!
* \brief Check if lhs array is broadcastable to rhs.
 * \return Whether lhs is broadcastable to rhs.
*/
TVM_DLL static bool Broadcastable(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs);
};

/*!
Expand All @@ -284,6 +300,12 @@ class SpanUtils {
* \return The Attrs Map.
*/
TVM_DLL static const Map<String, String> GetAttrs(const Span& span);

/*!
* \brief Create a span with <key>value</key>.
* \return The created Span.
*/
TVM_DLL static const Span CreateWithAttr(const String& key, const String& value);
};

/*!
@@ -365,6 +387,24 @@ class ExprUtils {
TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i = 0) {
return GetScalar<T>(constant->data, i);
}

/*!
* \brief Get name in span.
* \return The name.
*/
TVM_DLL static const String GetSpanName(const Expr& expr, const String& suffix = "");

/*!
* \brief Get shape of expr.
* \return The shape.
*/
TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr);

/*!
* \brief Get dtype of expr.
 * \return The dtype.
*/
TVM_DLL static const DataType GetDataType(const Expr& expr);
};

} // namespace msc
3 changes: 0 additions & 3 deletions tests/python/contrib/test_msc/test_graph_build.py
@@ -17,8 +17,6 @@

""" Test graph builder && graph. """

import pytest

import torch
from torch import fx
from torch.nn import Module
@@ -1101,7 +1099,6 @@ def forward(self, data):
verify_model(GetAttr1(), input_info, expected)


@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
def test_getitem():
"""test graph builder for getitem"""

4 changes: 0 additions & 4 deletions tests/python/contrib/test_msc/test_translate_relax.py
@@ -17,8 +17,6 @@

""" Test translate from relax. """

import pytest

import torch
from torch import fx
from torch.nn import Module
@@ -57,7 +55,6 @@ def _run_relax(relax_mod):
relax_exec = tvm.relax.build(relax_mod, target)
vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
res = vm_runner["main"](*args)

return _tvm_runtime_to_np(res)

rt_mod = tvm_codegen.to_relax(
@@ -629,7 +626,6 @@ def forward(self, data):
_verify_model(GetAttr1(), input_info)


@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
def test_getitem():
"""test relax translator for getitem"""

4 changes: 0 additions & 4 deletions tests/python/contrib/test_msc/test_translate_tensorflow.py
@@ -18,8 +18,6 @@

""" Test translate from tensorflow. """

import pytest

from packaging import version as package_version
import numpy as np

@@ -504,7 +502,6 @@ def _test_stridedslice(
verify_model(graph_def, golden, **io_info)


@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
def test_stridedslice():
"""test tensorflow translator for stridedslice"""

@@ -1065,7 +1062,6 @@ def _test_slice_operation_input(input_value, begin_value, size_value):
verify_model(graph_def, golden, **io_info)


@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
def test_slice():
"""test tensorflow translator for slice"""

3 changes: 0 additions & 3 deletions tests/python/contrib/test_msc/test_translate_torch.py
@@ -17,8 +17,6 @@

""" Test translate from torch. """

import pytest

import numpy as np

import torch
@@ -589,7 +587,6 @@ def forward(self, data):
verify_model(GetAttr1(), input_info)


@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
def test_getitem():
"""test torch translator for getitem"""
