From 253c9699f283fbe4e806a4fee10faf9dc9c76122 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 4 Feb 2021 22:36:09 +0000 Subject: [PATCH 1/3] Add support for MXNet GroupNorm --- python/tvm/relay/frontend/mxnet.py | 14 +++++++++++ tests/python/frontend/mxnet/test_forward.py | 27 +++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b272ead9737d..0c9d2c4381ac 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -495,6 +495,19 @@ def _mx_layer_norm(inputs, attrs): return _op.nn.layer_norm(*inputs, **new_attrs) +def _mx_group_norm(inputs, attrs): + assert len(inputs) == 3 + if attrs.get_bool("output_mean_var", False): + raise tvm.error.OpAttributeUnimplemented( + 'Attribute "output_mean_var" is not supported for operator Group Norm.' + ) + new_attrs = {} + new_attrs["axis"] = 1 + new_attrs["num_groups"] = attrs.get_int("num_groups", 1) + new_attrs["epsilon"] = attrs.get_float("eps", 1e-5) + return _op.nn.group_norm(*inputs, **new_attrs) + + def _mx_slice(inputs, attrs): new_attrs = {} begin = list(attrs.get_int_tuple("begin", None)) @@ -2599,6 +2612,7 @@ def _mx_npi_where_rscalar(inputs, attrs): "_contrib_SyncBatchNorm": _mx_batch_norm, "InstanceNorm": _mx_instance_norm, "LayerNorm": _mx_layer_norm, + "GroupNorm": _mx_group_norm, "LRN": _mx_lrn, "L2Normalization": _mx_l2_normalize, "slice": _mx_slice, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 537349e073e1..38e820639b6a 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1263,6 +1263,33 @@ def verify(shape, axis=-1): verify((2, 5, 6)) +@tvm.testing.uses_gpu +def test_forward_group_norm(): + def verify(shape, num_groups=1): + x = np.random.uniform(size=shape).astype("float32") + gamma = np.random.uniform(size=(shape[1])).astype("float32") + beta = np.random.uniform(size=(shape[1])).astype("float32") + ref_res = mx.nd.GroupNorm(data=mx.nd.array(x), gamma=mx.nd.array(gamma), beta=mx.nd.array(beta), num_groups=num_groups) + mx_sym = mx.sym.GroupNorm( + mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), num_groups=num_groups + ) + shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x, gamma, beta) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5 + ) + + verify((1, 4, 2), num_groups=4) + # TODO(trevmorr): MXNet GroupNorm implementation is bugged for cases when num_groups != num_channels + # https://github.com/apache/incubator-mxnet/pull/18199 + # verify((1, 4, 2, 3), num_groups=2) + # verify((1, 4, 2, 3)) + + @tvm.testing.uses_gpu def test_forward_one_hot(): def verify(indices_shape, depth, on_value, off_value, dtype): From e32a61e56a8a7283450d2dbac66ade2412528947 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 8 Feb 2021 23:21:21 +0000 Subject: [PATCH 2/3] Fix python lint --- tests/python/frontend/mxnet/test_forward.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 38e820639b6a..59d7449ee40e 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1269,7 +1269,12 @@ def verify(shape, num_groups=1): x = np.random.uniform(size=shape).astype("float32") gamma = np.random.uniform(size=(shape[1])).astype("float32") beta = np.random.uniform(size=(shape[1])).astype("float32") - ref_res = mx.nd.GroupNorm(data=mx.nd.array(x), gamma=mx.nd.array(gamma), beta=mx.nd.array(beta), num_groups=num_groups) + ref_res = mx.nd.GroupNorm( + data=mx.nd.array(x), + gamma=mx.nd.array(gamma), + beta=mx.nd.array(beta), + num_groups=num_groups + ) mx_sym = mx.sym.GroupNorm( mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), num_groups=num_groups ) From 2afbbe6f5d39861adcb6ebffb4c8cbb645341e41 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 10 Feb 2021 17:33:05 +0000 Subject: [PATCH 3/3] Fix lint --- tests/python/frontend/mxnet/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 59d7449ee40e..3e652cfc69e3 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1273,7 +1273,7 @@ def verify(shape, num_groups=1): data=mx.nd.array(x), gamma=mx.nd.array(gamma), beta=mx.nd.array(beta), - num_groups=num_groups + num_groups=num_groups, ) mx_sym = mx.sym.GroupNorm( mx.sym.var("x"), mx.sym.var("gamma"), mx.sym.var("beta"), num_groups=num_groups