From 0fbf4d062c1d8348b56f19021b3c3051cd25f952 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 3 Sep 2021 18:44:17 +0900 Subject: [PATCH 1/4] Unify dense input layout to NC --- python/tvm/relay/op/nn/nn.py | 2 +- python/tvm/topi/x86/dense_alter_op.py | 2 +- src/relay/op/nn/nn.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index e882bcf7e271..9285772eb3dd 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1548,7 +1548,7 @@ def dense(data, weight, units=None, out_dtype=""): return _make.dense(data, weight, units, out_dtype) -def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype=""): +def contrib_dense_pack(data, weight, weight_layout="NC", units=None, out_dtype=""): """Dense operator. Applies a linear transformation with packed weight diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index cb2f1929d395..8db84497f82d 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -47,7 +47,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): if cfg.is_fallback: _default_dense_pack_config(cfg, M, N, K) packw_bn = cfg["tile_x"].size[-1] - weight_layout = "NK%dn" % packw_bn + weight_layout = "NC%dn" % packw_bn new_weight = te.placeholder( (N // packw_bn, K, packw_bn), dtype=weight_tensor.dtype, diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index a05e460dc680..f334361874a3 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -210,7 +210,7 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, const Array& old_in_types) { - return InferCorrectLayoutOutput({"NC", "NK"}, {"NC"}, attrs); + return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs); } TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense); From 8dc9548ac2f3720ffb271b2dc4842683f23f77cd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 3 Sep 2021 19:07:47 +0900 Subject: [PATCH 2/4] add test --- tests/python/frontend/pytorch/test_forward.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index bae7c1b5498c..7b1cd8f53e8b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1576,6 +1576,13 @@ class LinearNoBias(Module): def forward(self, input, weight): return F.linear(input, weight) + class LinearNested(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + return F.linear(x, F.linear(y, z)) + input2d = torch.rand([2, 2]).float() input3d = torch.rand([4, 3, 2]).float() weight1d = torch.rand([2]).float() @@ -1595,6 +1602,9 @@ def forward(self, input, weight): verify_model(LinearNoBias(), input_data=[input2d, weight1d]) # 3D input, 2D weight, no bias verify_model(LinearNoBias(), input_data=[input3d, weight3x2]) + + verify_model(LinearNested(), input_data=[torch.randn(10, 10) for _ in range(3)]) + # TODO: Add the following cases when matmul(1D, _) is supported by TVM # 1D input, 2D weight, 1D bias # 1D input, 2D weight, no bias From d48817a641324ab18a363ab6aae2d3e3e665fff3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 3 Sep 2021 19:09:09 +0900 Subject: [PATCH 3/4] update existing tests --- tests/python/relay/test_pass_alter_op_layout.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index b5702a1542a9..69041fa4c8d8 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1317,8 +1317,8 @@ def before(): def expected(): x = relay.var("x", shape=(32, 64)) weight = relay.var("weight", shape=(48, 64)) - target_layout = "NK16n" - weight_transform = relay.layout_transform(weight, "NK", target_layout) + target_layout = "NC16n" + weight_transform = relay.layout_transform(weight, "NC", target_layout) y = relay.nn.contrib_dense_pack( x, weight_transform, target_layout, units=None, out_dtype="float32" ) @@ -1387,8 +1387,8 @@ def expected(): squeeze = relay.squeeze(pool, axis=[2, 3]) dense = relay.nn.contrib_dense_pack( relay.layout_transform(squeeze, "NC8c", "NC"), - relay.layout_transform(dense_weight, "NK", "NK16n"), - "NK16n", + relay.layout_transform(dense_weight, "NC", "NC16n"), + "NC16n", out_dtype="float32", ) return relay.Function(analysis.free_vars(dense), dense) From d9abe20258bf32b6aa65653c6b821f0c0cc51fad Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 3 Sep 2021 19:10:44 +0900 Subject: [PATCH 4/4] update doc --- include/tvm/relay/attrs/nn.h | 4 ++-- python/tvm/relay/op/nn/nn.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index d28044c3845d..de60deb9cccb 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1017,8 +1017,8 @@ struct DensePackAttrs : public tvm::AttrsNode { .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); TVM_ATTR_FIELD(weight_layout) - .set_default("NK") - .describe("Dimension ordering of weight. Packed layouts, such as NK8n, are possible."); + .set_default("NC") + .describe("Dimension ordering of weight. Packed layouts, such as NC8n, are possible."); } }; diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 9285772eb3dd..5a17db745b3e 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1567,7 +1567,7 @@ def contrib_dense_pack(data, weight, weight_layout="NC", units=None, out_dtype=" of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`. weight_layout: str - The layout of weight, such as "NK" or "NK8n". + The layout of weight, such as "NC" or "NC8n". units : int, optional Number of hidden units of the dense transformation.