Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions apps/cpp_rtvm/tvm_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ namespace runtime {
/*!
* \brief various meta information related to the compiled TVM model.
*/
typedef struct {
public:
typedef struct _TVMMetaInfo {
int n_inputs;
int n_outputs;
std::map<std::string, std::pair<std::vector<int>, std::string>> input_info;
Expand Down
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_VULKAN="${USE_VULKAN}"
TVM_INFO_USE_CLML="${USE_CLML}"
TVM_INFO_USE_CLML_GRAPH_EXECUTOR="${USE_CLML_GRAPH_EXECUTOR}"
TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}"
TVM_INFO_USE_UMA="${USE_UMA}"
TVM_INFO_USE_VERILATOR="${USE_VERILATOR}"
TVM_INFO_USE_CCACHE="${USE_CCACHE}"
Expand Down
45 changes: 44 additions & 1 deletion python/tvm/relay/op/contrib/clml.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
def clml_sdk_version():
"""Utility function to get clml version version"""

return tvm.support.libinfo().get("TVM_CLML_VERSION", 2)
return int(tvm.support.libinfo().get("TVM_CLML_VERSION", 2))


def is_clml_runtime_enabled():
Expand Down Expand Up @@ -155,6 +155,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
seq = tvm.transform.Sequential(
[
transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}),
transform.ConvertLayout({"nn.conv2d_transpose": ["NCHW", "OIHW"]}),
transform.AlterOpLayout(),
transform.FoldConstant(),
]
Expand Down Expand Up @@ -203,6 +204,22 @@ def conv_pattern():
pattern = pattern.optional(is_op("clip"))
return pattern

def conv_transpose_pattern():
"""Create a transposed convolution pattern."""
pattern = is_op("nn.conv2d_transpose")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
pattern = pattern.optional(
lambda x: is_tuple_get_item(
is_op("nn.batch_norm")(
x, is_constant(), is_constant(), is_constant(), is_constant()
)
)
)
pattern = pattern.optional(is_op("nn.relu"))
pattern = pattern.optional(is_op("clip"))
return pattern

def pad_conv_pattern():
"""Create a pad with convolution pattern."""
pattern = is_op("nn.pad")(wildcard(), is_constant())
Expand Down Expand Up @@ -300,6 +317,31 @@ def check_conv(extract):
return False
return True

def check_conv_transpose(extract):
"""Check transposed conv pattern is supported by CLML."""
call = extract
if isinstance(call, tvm.relay.expr.TupleGetItem):
call = call.tuple_value
elif call.op.name == "nn.relu":
call = call.args[0]
if isinstance(call, tvm.relay.expr.TupleGetItem):
call = call.tuple_value
elif call.op.name == "clip":
if call.attrs["a_min"] != 0.0 or call.attrs["a_max"] != 6.0:
return False
call = call.args[0]
if isinstance(call, tvm.relay.expr.TupleGetItem):
call = call.tuple_value

while call.op.name != "nn.conv2d_transpose":
call = call.args[0]

attrs = call.attrs
if attrs.data_layout != "NCHW":
return False

return True

def check_binary_op(extract):
call = extract
if len(call.args[1].checked_type.shape) > 0:
Expand Down Expand Up @@ -340,6 +382,7 @@ def check_default_op(extract):
return [
("clml.pad_conv2d", pad_conv_pattern(), check_conv),
("clml.conv2d", conv_pattern(), check_conv),
("clml.conv2d_transpose", conv_transpose_pattern(), check_conv_transpose),
("clml.dense", dense_pattern(), check_default_op),
("clml.pad", pad_pattern(), check_pad_op),
("clml.concat", concat_pattern(), check_concat_op),
Expand Down
40 changes: 24 additions & 16 deletions src/relay/backend/contrib/clml/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
ICHECK(comp.defined()) << "CLML JSON runtime only supports composite functions.";
const std::string name = comp.value();
std::shared_ptr<JSONGraphNode> json_node;
if (name == "clml.conv2d" || name == "clml.pad_conv2d") {
if (name == "clml.conv2d" || name == "clml.pad_conv2d" || name == "clml.conv2d_transpose") {
json_node = CreateCompositeConvJSONNode(cn);
} else if (name == "clml.batch_norm") {
json_node = CreateBatchNormJSONNode(cn);
Expand Down Expand Up @@ -169,7 +169,10 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
current_call = current_call->args[0].as<CallNode>();
}
// Enforce a convolution node exists at this point during traversal
ICHECK(backend::IsOp(current_call, "nn.conv2d"));
if (!backend::IsOp(current_call, "nn.conv2d") &&
!backend::IsOp(current_call, "nn.conv2d_transpose")) {
LOG(FATAL) << "Can't find primary op in Convolution node";
}
nodes.conv = current_call;
if (!current_call->args.empty() && current_call->args[0]->IsInstance<CallNode>()) {
current_call = current_call->args[0].as<CallNode>();
Expand All @@ -189,22 +192,27 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
CompositeConvNode nodes = UnpackCompositeConvolution(cn);

const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
ICHECK(conv_attr);

std::string name;
std::string name_prefix = "nn";

// Distinguish between normal and depth-wise convolution
if (conv_attr->channels.defined() &&
tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
conv_attr->groups != 1) {
name = "depthwise_conv2d";
ICHECK(conv_attr->kernel_layout == "IOHW")
<< "Kernel layout must be IHWO, has the module been pre-processed correctly?";
} else {
name = "conv2d";
ICHECK(conv_attr->kernel_layout == "OIHW")
if (backend::IsOp(nodes.conv, "nn.conv2d")) {
const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
ICHECK(conv_attr);
if (conv_attr->channels.defined() &&
tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
conv_attr->groups != 1) {
name = "depthwise_conv2d";
ICHECK(conv_attr->kernel_layout == "IOHW")
<< "Kernel layout must be IHWO, has the module been pre-processed correctly?";
} else {
name = "conv2d";
ICHECK(conv_attr->kernel_layout == "OIHW")
<< "Kernel layout must be OHWI, has the module been pre-processed correctly?";
}
} else if (backend::IsOp(nodes.conv, "nn.conv2d_transpose")) {
name = "conv2d_transpose";
const auto* conv_transpose_attr = nodes.conv->attrs.as<Conv2DTransposeAttrs>();
ICHECK(conv_transpose_attr);
ICHECK(conv_transpose_attr->kernel_layout == "OIHW")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to check kernel_layout here? You have already done it in python/tvm/relay/op/contrib/clml.py. If it is not necessary, then probably the same checks can be removed for convolution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove this check in python frontend. Codegen is the final component to make sure we don't generate incompatible code.

<< "Kernel layout must be OHWI, has the module been pre-processed correctly?";
}

Expand Down
Loading