diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
index c5a83608c617..45366f3ad55a 100644
--- a/include/tvm/packed_func_ext.h
+++ b/include/tvm/packed_func_ext.h
@@ -35,6 +35,8 @@ struct NodeTypeChecker {
     // It can be turned off, but will make non strict checking.
     // TODO(tqchen) possibly find alternative to turn of RTTI
     using ContainerType = typename T::ContainerType;
+    // always allow nullptr.
+    if (sptr == nullptr) return true;
     return sptr->derived_from<ContainerType>();
   }
   static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
@@ -46,7 +48,7 @@ struct NodeTypeChecker {
 template<typename T>
 struct NodeTypeChecker<Array<T> > {
   static inline bool Check(Node* sptr) {
-    if (sptr == nullptr) return false;
+    if (sptr == nullptr) return true;
     if (!sptr->is_type<ArrayNode>()) return false;
     ArrayNode* n = static_cast<ArrayNode*>(sptr);
     for (const auto& p : n->data) {
@@ -64,7 +66,7 @@ struct NodeTypeChecker<Array<T> > {
 template<typename V>
 struct NodeTypeChecker<Map<std::string, V> > {
   static inline bool Check(Node* sptr) {
-    if (sptr == nullptr) return false;
+    if (sptr == nullptr) return true;
     if (!sptr->is_type<StrMapNode>()) return false;
     StrMapNode* n = static_cast<StrMapNode*>(sptr);
     for (const auto& kv : n->data) {
@@ -83,7 +85,7 @@ struct NodeTypeChecker<Map<std::string, V> > {
 template<typename K, typename V>
 struct NodeTypeChecker<Map<K, V> > {
   static inline bool Check(Node* sptr) {
-    if (sptr == nullptr) return false;
+    if (sptr == nullptr) return true;
     if (!sptr->is_type<MapNode>()) return false;
     MapNode* n = static_cast<MapNode*>(sptr);
     for (const auto& kv : n->data) {
diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc
index 4826aed54ba5..96e805b5af2f 100644
--- a/src/relay/ir/op.cc
+++ b/src/relay/ir/op.cc
@@ -150,5 +150,10 @@ TVM_REGISTER_NODE_TYPE(OpNode)
     return static_cast<const OpNode*>(n)->name;
   });
 
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<OpNode>([](const OpNode* node, tvm::IRPrinter* p) {
+    p->stream << "Op(" << node->name << ")";
+  });
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index b573a2981c39..8e1d9db50e7e 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -21,7 +21,6 @@ bool Conv2DRel(const Array<Type>& types,
   const auto* data = types[0].as<TensorTypeNode>();
   const auto* weight = types[1].as<TensorTypeNode>();
   if (data == nullptr) return false;
-
   static const Layout kNCHW("NCHW");
   static const Layout kOIHW("OIHW");
 
@@ -42,14 +41,17 @@ bool Conv2DRel(const Array<Type>& types,
     << "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout; + std::vector dshape_nchw = ConvertLayout( + data->shape, in_layout, kNCHW); + IndexExpr channels, dilated_ksize_y, dilated_ksize_x; // infer weight if the kernel_size and channels are defined if (param->kernel_size.defined() && param->channels.defined()) { CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->dilation.size(), 2); std::vector wshape( - {param->channels / param->groups, - data->shape[1] / param->groups, + {param->channels / param->groups, + dshape_nchw[1] / param->groups, param->kernel_size[0], param->kernel_size[1]}); wshape = ConvertLayout(wshape, kOIHW, kernel_layout); @@ -78,16 +80,16 @@ bool Conv2DRel(const Array& types, << " channels=" << param->channels << " wshape=" << Array(wshape); } - CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1])); + CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1])); channels = wshape[0]; dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; } // dilation - std::vector oshape({data->shape[0], channels, 0, 0}); + std::vector oshape({dshape_nchw[0], channels, 0, 0}); - oshape[2] = (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1; - oshape[3] = (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1; + oshape[2] = (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1; + oshape[3] = (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1; DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { out_dtype = data->dtype; @@ -183,7 +185,9 @@ bool Conv2DTransposeRel(const Array& types, << " But got "<< kernel_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - const auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW); + + auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW); + // infer weight if the kernel_size and channels are defined if (param->kernel_size.defined() && param->channels.defined()) { CHECK_EQ(param->kernel_size.size(), 2); diff --git a/src/relay/op/nn/layout.h b/src/relay/op/nn/layout.h index b1dc4a71af1c..d9eb59d6e31c 100644 --- a/src/relay/op/nn/layout.h +++ b/src/relay/op/nn/layout.h @@ -495,9 +495,7 @@ inline std::vector ConvertLayout( IndexExpr src_dim_size = src[i]; if (src_minor_pos >= 0) { - const int64_t* minor_size = as_const_int(src[src_minor_pos]); - CHECK(minor_size == nullptr && - src_factor == minor_size[0]) + CHECK(is_const_int(src[src_minor_pos], src_factor)) << "src shape " << Array(src) << " does not agree with layout " << src_layout; diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 2f32b316924a..9dd2491289f2 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -32,9 +32,9 @@ def test_conv2d_infer_type(): # Infer with a different layout n, c, h, w = 4, 32, 224, 224 - x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) - w = relay.var("w") - y = relay.nn.conv2d(x, w, + x = relay.var("x", relay.TensorType((n//4, c//4, h, w, 4, 4), "int8")) + wt = relay.var("w") + y = relay.nn.conv2d(x, wt, kernel_size=(3, 3), padding=(1, 1), channels=16, @@ -47,6 +47,21 @@ def test_conv2d_infer_type(): assert yy.args[1].checked_type == relay.TensorType( (4, 8, 3, 3, 4, 4), "int8") + # Infer with NHWC + n, c, h, w = 4, 32, 224, 224 + x = relay.var("x", relay.TensorType((n, h, w, c), "int8")) + wt = relay.var("w") + y = relay.nn.conv2d(x, wt, + 
+                        kernel_size=(3, 3),
+                        padding=(1, 1),
+                        channels=16,
+                        data_layout="NHWC",
+                        out_dtype="int32")
+    yy = relay.ir_pass.infer_type(y)
+    assert yy.checked_type == relay.TensorType(
+        (n, h, w, 16), "int32")
+
+
 def test_conv2d_transpose_infer_type():
     # symbolic in batch dimension
     n, c, h, w = tvm.var("n"), 10, 10, 12
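
Note on the Conv2DRel change above: instead of indexing data->shape as if it were NCHW, the relation now reads the batch, channel, and spatial extents from dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW), so any data_layout convertible to NCHW (such as NHWC) type-checks; the output shape is still computed in NCHW order and converted back to the requested layout outside this hunk, which is how the new test infers (n, h, w, 16). Below is a minimal Python sketch of that normalization step for plain (non-split) layout strings; convert_layout is an illustrative stand-in written for this note, not the C++ helper in layout.h, which also handles split axes such as NCHW4n4c.

    def convert_layout(shape, src_layout, dst_layout="NCHW"):
        # Map each axis letter of the source layout to its extent,
        # then emit the extents in destination-layout order.
        assert len(shape) == len(src_layout)
        dims = dict(zip(src_layout, shape))
        return [dims[axis] for axis in dst_layout]

    # NHWC input from the new test case, viewed as NCHW for shape inference:
    print(convert_layout([4, 224, 224, 32], "NHWC"))  # [4, 32, 224, 224]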