From 8d1dbea8fc81b1e5a310fd553aec611c00873d68 Mon Sep 17 00:00:00 2001 From: xutianming Date: Mon, 22 Feb 2021 14:53:09 +0800 Subject: [PATCH 1/4] Make onnx gemm tensor C optional --- python/tvm/relay/frontend/onnx.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 109e80c99783..800196eb3785 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -513,7 +513,8 @@ class Gemm(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs)) + assert len(inputs) == 3 or len(inputs) == 2, \ + "Gemm op take 2 or 3 inputs, {} given".format(len(inputs)) # Y = alpha * A * B + beta * C alpha = float(attr.get("alpha", 1.0)) beta = float(attr.get("beta", 1.0)) @@ -531,9 +532,12 @@ def _impl_v1(cls, inputs, attr, params): inputs[0] *= _expr.const(alpha) out = _op.nn.dense(inputs[0], inputs[1], units=channels) - # skip (beta * C) if zero - C_array = params[inputs[2].name_hint].asnumpy() - if (beta == 0.0) or np.array_equal(C_array, np.array([0])): + if len(inputs) == 3: + # skip (beta * C) if zero + C_array = params[inputs[2].name_hint].asnumpy() + if (beta == 0.0) or np.array_equal(C_array, np.array([0])): + return out + else: return out return _op.nn.bias_add(out, _expr.const(beta) * inputs[2]) From 2e87198c046b1065eae7f0574806e2d6488d3b5a Mon Sep 17 00:00:00 2001 From: xutianming Date: Mon, 22 Feb 2021 14:59:10 +0800 Subject: [PATCH 2/4] fix codestyle --- python/tvm/relay/frontend/onnx.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 800196eb3785..510c7eebaf46 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -513,8 +513,9 @@ class Gemm(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): - assert len(inputs) == 3 or len(inputs) == 2, \ - "Gemm op take 2 or 3 inputs, {} given".format(len(inputs)) + assert len(inputs) == 3 or len(inputs) == 2, "Gemm op take 2 or 3 inputs, {} given".format( + len(inputs) + ) # Y = alpha * A * B + beta * C alpha = float(attr.get("alpha", 1.0)) beta = float(attr.get("beta", 1.0)) From a80d060aa021b14d69ef80f94c21b5c13bc0bd19 Mon Sep 17 00:00:00 2001 From: xutianming Date: Tue, 23 Feb 2021 14:03:30 +0800 Subject: [PATCH 3/4] add tests --- tests/python/frontend/onnx/test_forward.py | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 59ecffe829df..9430861d1059 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1007,6 +1007,29 @@ def test_onehot(): ) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu +def test_gemm(): + a_shape = (4, 3) + b_shape = (3, 4) + out_shape = [a_shape[0], b_shape[1]] + + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") + + gemm_node = helper.make_node("Gemm", ["a", "b"], ["out"]) + + graph = helper.make_graph( + [gemm_node], + "gemm_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="gemm_test") + verify_with_ort_with_inputs(model, [a_array, b_array]) @tvm.testing.uses_gpu def test_matmul(): @@ -4065,6 +4088,7 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): test_clip() test_clip_min_max_as_inputs() test_onehot() + test_gemm() test_matmul() test_gather() test_gatherelements() From cce01089db23b0d4da92d8bd214120a19e56832f Mon Sep 17 00:00:00 2001 From: xutianming Date: Tue, 23 Feb 2021 15:28:31 +0800 Subject: [PATCH 4/4] fix codestyle --- tests/python/frontend/onnx/test_forward.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 9430861d1059..80e32b9a893f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1007,6 +1007,7 @@ def test_onehot(): ) tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) + @tvm.testing.uses_gpu def test_gemm(): a_shape = (4, 3) @@ -1031,6 +1032,7 @@ def test_gemm(): model = helper.make_model(graph, producer_name="gemm_test") verify_with_ort_with_inputs(model, [a_array, b_array]) + @tvm.testing.uses_gpu def test_matmul(): a_shape = (4, 3)