diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 109e80c99783..510c7eebaf46 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -513,7 +513,9 @@ class Gemm(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs))
+        assert len(inputs) == 3 or len(inputs) == 2, "Gemm op take 2 or 3 inputs, {} given".format(
+            len(inputs)
+        )
         # Y = alpha * A * B + beta * C
         alpha = float(attr.get("alpha", 1.0))
         beta = float(attr.get("beta", 1.0))
@@ -531,9 +533,12 @@ def _impl_v1(cls, inputs, attr, params):
             inputs[0] *= _expr.const(alpha)
         out = _op.nn.dense(inputs[0], inputs[1], units=channels)
 
-        # skip (beta * C) if zero
-        C_array = params[inputs[2].name_hint].asnumpy()
-        if (beta == 0.0) or np.array_equal(C_array, np.array([0])):
+        if len(inputs) == 3:
+            # skip (beta * C) if zero
+            C_array = params[inputs[2].name_hint].asnumpy()
+            if (beta == 0.0) or np.array_equal(C_array, np.array([0])):
+                return out
+        else:
             return out
         return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 59ecffe829df..80e32b9a893f 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1008,6 +1008,31 @@ def test_onehot():
     tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.uses_gpu
+def test_gemm():
+    a_shape = (4, 3)
+    b_shape = (3, 4)
+    out_shape = [a_shape[0], b_shape[1]]
+
+    a_array = np.random.uniform(size=a_shape).astype("float32")
+    b_array = np.random.uniform(size=b_shape).astype("float32")
+
+    gemm_node = helper.make_node("Gemm", ["a", "b"], ["out"])
+
+    graph = helper.make_graph(
+        [gemm_node],
+        "gemm_test",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
+            helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="gemm_test")
+    verify_with_ort_with_inputs(model, [a_array, b_array])
+
+
 @tvm.testing.uses_gpu
 def test_matmul():
     a_shape = (4, 3)
@@ -4065,6 +4090,7 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
     test_clip()
     test_clip_min_max_as_inputs()
     test_onehot()
+    test_gemm()
     test_matmul()
     test_gather()
     test_gatherelements()
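
Note (not part of the patch): a minimal standalone sketch of how the new two-input Gemm path can be exercised through the Relay ONNX frontend and checked against NumPy, assuming an LLVM-enabled TVM build. The model name, the alpha value, and the tolerances below are illustrative choices, and the relay.create_executor call may need small adjustments across TVM versions.

import numpy as np
from onnx import TensorProto, helper

import tvm
from tvm import relay

a_shape, b_shape = (4, 3), (3, 4)

# Gemm with only A and B: Y = alpha * A * B; the optional C (bias) input is omitted.
node = helper.make_node("Gemm", ["a", "b"], ["out"], alpha=2.0)
graph = helper.make_graph(
    [node],
    "gemm_two_input_sketch",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
    ],
    outputs=[
        helper.make_tensor_value_info("out", TensorProto.FLOAT, [a_shape[0], b_shape[1]])
    ],
)
model = helper.make_model(graph, producer_name="gemm_two_input_sketch")

# Import into Relay; with no C initializer the returned params dict is empty.
mod, params = relay.frontend.from_onnx(model, shape={"a": a_shape, "b": b_shape})

a = np.random.uniform(size=a_shape).astype("float32")
b = np.random.uniform(size=b_shape).astype("float32")

# Compile for CPU and run through the graph executor.
result = relay.create_executor("graph", mod, tvm.cpu(0), "llvm").evaluate()(a, b, **params)

np.testing.assert_allclose(result.asnumpy(), 2.0 * a.dot(b), rtol=1e-5, atol=1e-5)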