Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct NodeTypeChecker {
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn off RTTI
using ContainerType = typename T::ContainerType;
// always allow nullptr.
if (sptr == nullptr) return true;
return sptr->derived_from<ContainerType>();
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
Expand All @@ -46,7 +48,7 @@ struct NodeTypeChecker {
template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (sptr == nullptr) return true;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
Expand All @@ -64,7 +66,7 @@ struct NodeTypeChecker<Array<T> > {
template<typename V>
struct NodeTypeChecker<Map<std::string, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (sptr == nullptr) return true;
if (!sptr->is_type<StrMapNode>()) return false;
StrMapNode* n = static_cast<StrMapNode*>(sptr);
for (const auto& kv : n->data) {
Expand All @@ -83,7 +85,7 @@ struct NodeTypeChecker<Map<std::string, V> > {
template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (sptr == nullptr) return true;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
Expand Down
5 changes: 5 additions & 0 deletions src/relay/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,10 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name;
});

// Register a pretty-print handler for OpNode in the IRPrinter dispatch
// vtable, so operator references render as "Op(<name>)" in IR dumps.
// NOTE(review): assumes the IRPrinter vtable dispatches on the node's
// runtime type, as with other set_dispatch registrations — confirm
// against the TVM_STATIC_IR_FUNCTOR_REGISTER macro definition.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<OpNode>([](const OpNode* node, tvm::IRPrinter* p) {
    p->stream << "Op(" << node->name << ")";
  });

} // namespace relay
} // namespace tvm
20 changes: 12 additions & 8 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ bool Conv2DRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;

static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");

Expand All @@ -42,14 +41,17 @@ bool Conv2DRel(const Array<Type>& types,
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;

std::vector<IndexExpr> dshape_nchw = ConvertLayout(
data->shape, in_layout, kNCHW);

IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape(
{param->channels / param->groups,
data->shape[1] / param->groups,
{param->channels / param->groups,
dshape_nchw[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
Expand Down Expand Up @@ -78,16 +80,16 @@ bool Conv2DRel(const Array<Type>& types,
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
channels = wshape[0];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
std::vector<IndexExpr> oshape({data->shape[0], channels, 0, 0});
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});

oshape[2] = (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
oshape[3] = (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
Expand Down Expand Up @@ -183,7 +185,9 @@ bool Conv2DTransposeRel(const Array<Type>& types,
<< " But got "<< kernel_layout;

IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
const auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);

auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);

// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
Expand Down
4 changes: 1 addition & 3 deletions src/relay/op/nn/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,7 @@ inline std::vector<IndexExpr> ConvertLayout(
IndexExpr src_dim_size = src[i];

if (src_minor_pos >= 0) {
const int64_t* minor_size = as_const_int(src[src_minor_pos]);
CHECK(minor_size == nullptr &&
src_factor == minor_size[0])
CHECK(is_const_int(src[src_minor_pos], src_factor))
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
Expand Down
21 changes: 18 additions & 3 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def test_conv2d_infer_type():

# Infer with a different layout
n, c, h, w = 4, 32, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
w = relay.var("w")
y = relay.nn.conv2d(x, w,
x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8"))
wt = relay.var("w")
y = relay.nn.conv2d(x, wt,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
Expand All @@ -47,6 +47,21 @@ def test_conv2d_infer_type():
assert yy.args[1].checked_type == relay.TensorType(
(4, 8, 3, 3, 4, 4), "int8")

# Infer with NHWC
n, c, h, w = 4, 32, 224, 224
x = relay.var("x", relay.TensorType((n, h, w, c), "int8"))
wt = relay.var("w")
y = relay.nn.conv2d(x, wt,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
data_layout="NHWC",
out_dtype="int32")
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, h, w, 16), "int32")


def test_conv2d_transpose_infer_type():
# symbolic in batch dimension
n, c, h, w = tvm.var("n"), 10, 10, 12
Expand Down