From c4b3e1641605ad73e53e408f1aa07d3d7aee0b30 Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Thu, 7 Apr 2022 16:54:17 -0700
Subject: [PATCH 01/17] EmbedLayerNormalization, Attention

---
 python/tvm/relay/frontend/onnx.py | 142 ++++++++++++++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 7dcb9952c7fb..d65a472b80b5 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -836,6 +836,145 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)
 
 
+class EmbedLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int64")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0) if segment_ids else None
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        eps_dtype = infer_type(word_emb).checked_type.dtype
+
+        u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
+        ln = _op.divide(_op.subtract(vec_sum, u), _op.sqrt(_op.add(s, _op.const(attr["epsilon"], dtype=eps_dtype))))
+        ln = _op.multiply(ln, gamma) + beta
+
+        # TODO: actually calculate this
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        breakpoint()
+
+
+class Attention(OnnxOpConverter):
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_heads = attr["num_heads"]
+        assert "qkv_hidden_sizes" not in attr
+        assert "unidirectional" not in attr
+
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # 1. (    batch,              1,        max_seq, max_seq)
+        # 2. (    batch, past_seq + seq,)
+        # 3. (    batch,            seq, past_seq + seq,)
+        # 4. (    batch,)
+        # 5. (2 * batch,)
+        mask_index = inputs[3]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[4]
+
+        # (batch, num_heads, seq, seq)
+        extra_add = inputs[5]
+
+        (batch_size, seq_len, in_hidden) = infer_shape(input_emb)
+        (out_hidden_x3,) = infer_shape(bias)
+        assert out_hidden_x3 % 3 == 0
+        out_hidden = out_hidden_x3 // 3
+        assert out_hidden % num_heads == 0
+        head_size = out_hidden // num_heads
+
+        mask_index_shape = infer_shape(mask_index)
+        assert len(mask_index_shape) == 2
+        assert mask_index_shape[0] == batch_size
+        assert mask_index_shape[1] == seq_len
+
+        assert past is None
+        assert extra_add is None
+
+        # decompose weight into Q, K, V: (in_hidden, out_hidden) and do the matmuls
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
+
+        Q = _op.nn.matmul(input_emb, w_Q)
+        K = _op.nn.matmul(input_emb, w_K)
+        V = _op.nn.matmul(input_emb, w_V)
+
+        # massage tensors in preparation for batched matmul
+        def massage(tensor, is_V=False):
+            axes = [0, 2, 3, 1] if is_V else [0, 2, 1, 3]
+            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+            tensor = _op.transpose(tensor, axes=axes)
+            return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+        Q = massage(Q)
+        K = massage(K)
+        V = massage(V, is_V=True)
+
+        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
+        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
+        present = _op.stack([K_present, V_present], axis=0)
+
+        att_scores = _op.nn.batch_matmul(Q, K)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(att_scores, _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype))
+        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
+
+        # build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+
+        # apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
+
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        C = _op.nn.batch_matmul(att_probs, V)
+        C = _op.reverse_reshape(C, (-1, num_heads, 0, 0))
+        C = _op.transpose(C, axes=[0, 2, 1, 3])
+        C = _op.reshape(C, (0, 0, out_hidden))
+
+        return _expr.TupleWrapper(_expr.Tuple([C, present]), 2)
+
+
 class Gemm(OnnxOpConverter):
     """Operator converter for Gemm."""
 
@@ -4737,6 +4876,9 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
+        "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
+        "Attention": Attention.get_converter(opset),
         "Exp": Renamer("exp"),
         "Greater": Renamer("greater"),
         "GreaterOrEqual": Renamer("greater_equal"),

From 4190bb2ebc5b84e5397386000788955e7fbda537 Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Fri, 8 Apr 2022 13:57:25 -0700
Subject: [PATCH 02/17] fix Attention

---
 python/tvm/relay/frontend/onnx.py | 72 ++++++++++++++++++++-----------
 1 file changed, 46 insertions(+), 26 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index d65a472b80b5..911339c61f97 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -850,6 +850,8 @@ def _impl_v1(cls, inputs, attr, params):
         mask = inputs[7]
         pos_ids = inputs[8]
 
+        eps = attr["epsilon"] if "epsilon" in attr else 0.0
+
         (batch_size, seq_len) = infer_shape(input_ids)
 
         if segment_ids:
@@ -869,7 +871,10 @@ def _impl_v1(cls, inputs, attr, params):
         eps_dtype = infer_type(word_emb).checked_type.dtype
 
         u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
-        ln = _op.divide(_op.subtract(vec_sum, u), _op.sqrt(_op.add(s, _op.const(attr["epsilon"], dtype=eps_dtype))))
+        ln = _op.divide(
+            _op.subtract(vec_sum, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
         ln = _op.multiply(ln, gamma) + beta
 
         # TODO: actually calculate this
@@ -888,8 +893,10 @@ class Attention(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         num_heads = attr["num_heads"]
-        assert "qkv_hidden_sizes" not in attr
-        assert "unidirectional" not in attr
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
+        assert "unidirectional" not in attr, "unidirectional attention not currently supported"
 
         # (batch, seq, in_hidden)
         input_emb = inputs[0]
@@ -905,6 +912,7 @@ def _impl_v1(cls, inputs, attr, params):
         # 3. (    batch,            seq, past_seq + seq,)
         # 4. (    batch,)
         # 5. (2 * batch,)
+        # For now, we only support case 2.
         mask_index = inputs[3]
 
         # (2, batch, num_heads, past_seq, head_size)
@@ -915,44 +923,56 @@ def _impl_v1(cls, inputs, attr, params):
         (batch_size, seq_len, in_hidden) = infer_shape(input_emb)
         (out_hidden_x3,) = infer_shape(bias)
-        assert out_hidden_x3 % 3 == 0
+        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
         out_hidden = out_hidden_x3 // 3
-        assert out_hidden % num_heads == 0
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
         head_size = out_hidden // num_heads
 
         mask_index_shape = infer_shape(mask_index)
-        assert len(mask_index_shape) == 2
-        assert mask_index_shape[0] == batch_size
-        assert mask_index_shape[1] == seq_len
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] == seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
 
-        assert past is None
-        assert extra_add is None
+        assert past is None, "past K, V state is not currently supported"
+        assert extra_add is None, "extra add to QxK not currently supported"
 
-        # decompose weight into Q, K, V: (in_hidden, out_hidden) and do the matmuls
+        # split weight and biases and do the matmuls
         w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
-
-        Q = _op.nn.matmul(input_emb, w_Q)
-        K = _op.nn.matmul(input_emb, w_K)
-        V = _op.nn.matmul(input_emb, w_V)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
+        # need to merge batch dimensions since TVM matmul is 2D
+        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
+        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
+        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
+        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
 
         # massage tensors in preparation for batched matmul
-        def massage(tensor, is_V=False):
-            axes = [0, 2, 3, 1] if is_V else [0, 2, 1, 3]
+        def massage(tensor):
             tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
-            tensor = _op.transpose(tensor, axes=axes)
+
+            # (batch_size, num_heads, seq_len, head_size)
+            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+            # (batch_size * num_heads, seq_len, head_size)
             return _op.reverse_reshape(tensor, (-1, 0, 0))
 
         Q = massage(Q)
         K = massage(K)
-        V = massage(V, is_V=True)
+        V = massage(V)
 
         K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
         V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
         present = _op.stack([K_present, V_present], axis=0)
 
-        att_scores = _op.nn.batch_matmul(Q, K)
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
         score_dtype = infer_type(att_scores).checked_type.dtype
-        att_scores = _op.divide(att_scores, _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype))
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype),
+        )
         att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
 
         # build the attention mask
@@ -967,12 +987,12 @@ def massage(tensor, is_V=False):
 
         att_probs = _op.nn.softmax(att_scores, axis=-1)
 
-        C = _op.nn.batch_matmul(att_probs, V)
-        C = _op.reverse_reshape(C, (-1, num_heads, 0, 0))
-        C = _op.transpose(C, axes=[0, 2, 1, 3])
-        C = _op.reshape(C, (0, 0, out_hidden))
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, out_hidden))
 
-        return _expr.TupleWrapper(_expr.Tuple([C, present]), 2)
+        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
 class Gemm(OnnxOpConverter):
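Note (not part of the patch series): the math the fixed converter emits is standard masked multi-head attention. A plain-NumPy sketch of the same computation, covering only the 2D mask_index case the converter supports; `attention_ref` and `split_heads` are illustrative names, not part of the patch:

import numpy as np

def attention_ref(x, weight, bias, mask, num_heads):
    # x: (batch, seq, hidden), weight: (hidden, 3 * out_hidden), bias: (3 * out_hidden,)
    batch, seq, _ = x.shape
    out_hidden = weight.shape[1] // 3
    head_size = out_hidden // num_heads

    def split_heads(t):
        # (batch, seq, out_hidden) -> (batch, num_heads, seq, head_size)
        return t.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

    q, k, v = (
        split_heads(x @ w + b)
        for w, b in zip(np.split(weight, 3, axis=1), np.split(bias, 3))
    )

    # scaled dot-product scores plus the additive mask the converter builds:
    # positions with mask == 0 get a large negative score before the softmax
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    scores = scores + (1 - mask[:, None, None, :]) * -10000.0

    # numerically stable softmax over the last axis
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)

    # weighted sum of values, heads merged back into one hidden dimension
    return (probs @ v).transpose(0, 2, 1, 3).reshape(batch, seq, out_hidden)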
From 1d3064eb257e87fdf0ab0f1671ab3782e96d70dd Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Fri, 8 Apr 2022 14:28:03 -0700
Subject: [PATCH 03/17] SkipLayerNormalization

---
 python/tvm/relay/frontend/onnx.py          |  33 +++-
 tests/python/frontend/onnx/test_forward.py | 199 +++++++++++++++++++++
 2 files changed, 229 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 911339c61f97..4b142f461a24 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -850,7 +850,7 @@ def _impl_v1(cls, inputs, attr, params):
         mask = inputs[7]
         pos_ids = inputs[8]
 
-        eps = attr["epsilon"] if "epsilon" in attr else 0.0
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12
 
         (batch_size, seq_len) = infer_shape(input_ids)
 
@@ -877,8 +877,10 @@ def _impl_v1(cls, inputs, attr, params):
         )
         ln = _op.multiply(ln, gamma) + beta
 
-        # TODO: actually calculate this
         mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
 
         return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
 
@@ -886,7 +888,32 @@ class SkipLayerNormalization(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        breakpoint()
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        eps_dtype = infer_type(x).checked_type.dtype
+
+        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+        output = _op.divide(
+            _op.subtract(x, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        output = _op.multiply(output, gamma)
+        if beta:
+            output = _op.add(output, beta)
+
+        placeholder = _op.const(0, dtype="float32")
+
+        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
 
 
 class Attention(OnnxOpConverter):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 638b4b8f57eb..3493c87e9be5 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5433,6 +5433,205 @@ def verify_biasgelu(x, bias):
     verify_biasgelu(x, bias)
 
 
+@tvm.testing.parametrize_targets
+def test_embedlayernormalization(target, dev):
+    def verify_embedlayernormalization(
+        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta,
+    ):
+        node = onnx.helper.make_node(
+            "EmbedLayerNormalization",
+            inputs=[
+                "input_ids",
+                "segment_ids",
+                "word_embedding",
+                "position_embedding",
+                "segment_embedding",
+                "gamma",
+                "beta",
+            ],
+            outputs=["output", "mask_index", "embedding_sum"],
+            domain="com.microsoft",
+        )
+
+        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))
+
+        graph = helper.make_graph(
+            [node],
+            "embedlayernormalization_test",
+            inputs=[
+                helper.make_tensor_value_info(
+                    "input_ids", TensorProto.INT32, list(input_ids.shape)
+                ),
+                helper.make_tensor_value_info(
+                    "segment_ids", TensorProto.INT32, list(segment_ids.shape)
+                ),
+                helper.make_tensor_value_info(
+                    "word_embedding", TensorProto.FLOAT, list(word_embedding.shape)
+                ),
+                helper.make_tensor_value_info(
+                    "position_embedding", TensorProto.FLOAT, list(position_embedding.shape)
+                ),
+                helper.make_tensor_value_info(
+                    "segment_embedding", TensorProto.FLOAT, list(segment_embedding.shape)
+                ),
+                helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)),
+                helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)),
+            ],
+            outputs=[
+                helper.make_tensor_value_info(
+                    "output", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size))
+                ),
+                helper.make_tensor_value_info(
+                    "mask_index", TensorProto.INT32, [batch_size]
+                ),
+                helper.make_tensor_value_info(
+                    "embedding_sum", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size))
+                ),
+            ],
+        )
+
+        model = helper.make_model(graph, producer_name="embedlayernormalization_test")
+        verify_with_ort_with_inputs(
+            model,
+            [
+                input_ids,
+                segment_ids,
+                word_embedding,
+                position_embedding,
+                segment_embedding,
+                gamma,
+                beta,
+            ],
+            [(batch_size, sequence_length, hidden_size), batch_size, (batch_size, sequence_length, hidden_size)],
+            target=target,
+            dev=dev,
+            rtol=1e-4,
+            atol=1e-4,
+        )
+
+    hidden_size = 384
+    batch_size = 4
+    sequence_length = 4
+    vocab_size = 5
+
+    input_ids = np.full((batch_size, sequence_length), 3).astype("int32")
+    segment_ids = np.zeros((batch_size, sequence_length)).astype("int32")
+    word_embedding = np.full((vocab_size, hidden_size), 1).astype("float32")
+    position_embedding = np.full((sequence_length, hidden_size), 2).astype("float32")
+    segment_embedding = np.full((vocab_size, hidden_size), 3).astype("float32")
+
+    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype("float32")
+    beta = np.random.randn(hidden_size).astype("float32") * 0.1
+
+    verify_embedlayernormalization(
+        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta
+    )
+
+
+@tvm.testing.parametrize_targets
+def test_attention(target, dev):
+    def verify_attention(input, weight, bias, mask_index, num_heads):
+        node = onnx.helper.make_node(
+            "Attention",
+            inputs=["input", "weight", "bias", "mask_index"],
+            outputs=["output", "present"],
+            domain="com.microsoft",
+            num_heads=num_heads,
+        )
+
+        present_output_shape = (2, batch_size, num_heads, sequence_length, head_size)
+
+        graph = helper.make_graph(
+            [node],
+            "attention_test",
+            inputs=[
+                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)),
+                helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(weight.shape)),
+                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
+                helper.make_tensor_value_info(
+                    "mask_index", TensorProto.INT32, list(mask_index.shape)
+                ),
+            ],
+            outputs=[
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)),
+                helper.make_tensor_value_info(
+                    "present", TensorProto.FLOAT, list(present_output_shape)
+                ),
+            ],
+        )
+
+        model = helper.make_model(graph, producer_name="attention_test")
+
+        # "present" output should be nullptr when the "past" input isn't included,
+        # but ort requires an output shape to be specified?
+        verify_with_ort_with_inputs(
+            model,
+            [input, weight, bias, mask_index],
+            [input.shape, present_output_shape],
+            target=target,
+            dev=dev,
+            rtol=1e-4,
+            atol=1e-4,
+        )
+
+    hidden_size = 384
+    batch_size = 4
+    sequence_length = 4
+    num_heads = 12
+    head_size = 32
+
+    dtype = "float32"
+    input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
+    weight = np.random.normal(size=(hidden_size, 3 * hidden_size)).astype(dtype) * 0.1
+    bias = np.random.randn(3 * hidden_size).astype(dtype)
+    mask_index = np.full((batch_size, sequence_length), 1).astype("int32")
+
+    verify_attention(input, weight, bias, mask_index, num_heads)
+
+
+def test_skiplayernormalization(target, dev):
+    def verify_skiplayernormalization(input, skip, gamma, beta, bias):
+        node = onnx.helper.make_node(
+            "SkipLayerNormalization",
+            inputs=["input", "skip", "gamma", "beta", "bias"],
+            outputs=["output"],
+            domain="com.microsoft",
+        )
+
+        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))
+
+        graph = helper.make_graph(
+            [node],
+            "skiplayernormalization_test",
+            inputs=[
+                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)),
+                helper.make_tensor_value_info("skip", TensorProto.FLOAT, list(skip.shape)),
+                helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)),
+                helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)),
+                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
+            ],
+            outputs=[
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)),
+            ],
+        )
+
+        model = helper.make_model(graph, producer_name="skiplayernormalization_test")
+        verify_with_ort_with_inputs(model, [input, skip, gamma, beta, bias], [input.shape], target=target, dev=dev)
+
+    hidden_size = 384
+    batch_size = 4
+    sequence_length = 4
+
+    dtype = "float32"
+    input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
+    skip = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
+    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype(dtype)
+    beta = np.random.randn(hidden_size).astype(dtype) * 0.1
+    bias = np.random.randn(hidden_size).astype(dtype)
+
+    verify_skiplayernormalization(input, skip, gamma, beta, bias)
+
+
 @tvm.testing.known_failing_targets("cuda")
 @tvm.testing.parametrize_targets
 def test_qlinearconv(target, dev):

From 1927414fa5177b1a630e3fd05d6d438f7ca8a9b7 Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Fri, 8 Apr 2022 14:46:55 -0700
Subject: [PATCH 04/17] fix dtype bug in Gelu

Co-authored-by: An Wang
---
 python/tvm/relay/frontend/onnx.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 4b142f461a24..a5c0f5b1b6e0 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -807,9 +807,10 @@ def _impl_v1(cls, inputs, attr, params):
         x = inputs[0]
 
         # Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)
 
         # Compute gelu
         term1 = _op.multiply(half, x)
From b718d6a7174f33eb92bdab768ad20903d046c42a Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Fri, 8 Apr 2022 14:59:05 -0700
Subject: [PATCH 05/17] missing parameterize_targets

---
 tests/python/frontend/onnx/test_forward.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 3493c87e9be5..57c0c20ac4b7 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5589,6 +5589,7 @@ def verify_attention(input, weight, bias, mask_index, num_heads):
     verify_attention(input, weight, bias, mask_index, num_heads)
 
 
+@tvm.testing.parametrize_targets
 def test_skiplayernormalization(target, dev):
     def verify_skiplayernormalization(input, skip, gamma, beta, bias):
         node = onnx.helper.make_node(

From 90bb12f1aee95fe4f62c901430fedd5d12cbc088 Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Fri, 8 Apr 2022 14:59:40 -0700
Subject: [PATCH 06/17] lint

---
 tests/python/frontend/onnx/test_forward.py | 26 ++++++++++++++++------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 57c0c20ac4b7..1229d5a859fc 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5436,7 +5436,13 @@ def verify_biasgelu(x, bias):
 @tvm.testing.parametrize_targets
 def test_embedlayernormalization(target, dev):
     def verify_embedlayernormalization(
-        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta,
+        input_ids,
+        segment_ids,
+        word_embedding,
+        position_embedding,
+        segment_embedding,
+        gamma,
+        beta,
     ):
         node = onnx.helper.make_node(
             "EmbedLayerNormalization",
@@ -5481,11 +5487,11 @@ def verify_embedlayernormalization(
                 helper.make_tensor_value_info(
                     "output", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size))
                 ),
+                helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size]),
                 helper.make_tensor_value_info(
-                    "mask_index", TensorProto.INT32, [batch_size]
-                ),
-                helper.make_tensor_value_info(
-                    "embedding_sum", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size))
+                    "embedding_sum",
+                    TensorProto.FLOAT,
+                    list((batch_size, sequence_length, hidden_size)),
                 ),
             ],
         )
@@ -5502,7 +5508,11 @@ def verify_embedlayernormalization(
                 gamma,
                 beta,
             ],
-            [(batch_size, sequence_length, hidden_size), batch_size, (batch_size, sequence_length, hidden_size)],
+            [
+                (batch_size, sequence_length, hidden_size),
+                batch_size,
+                (batch_size, sequence_length, hidden_size),
+            ],
             target=target,
             dev=dev,
             rtol=1e-4,
@@ -5617,7 +5627,9 @@ def verify_skiplayernormalization(input, skip, gamma, beta, bias):
         )
 
         model = helper.make_model(graph, producer_name="skiplayernormalization_test")
-        verify_with_ort_with_inputs(model, [input, skip, gamma, beta, bias], [input.shape], target=target, dev=dev)
+        verify_with_ort_with_inputs(
+            model, [input, skip, gamma, beta, bias], [input.shape], target=target, dev=dev
+        )
 
     hidden_size = 384
     batch_size = 4

From 49984926a092160aa0e20c151ef57f719509ebe3 Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Fri, 8 Apr 2022 15:11:41 -0700
Subject: [PATCH 07/17] lint

---
 python/tvm/relay/frontend/onnx.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a5c0f5b1b6e0..b5573d022626 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -838,6 +838,11 @@ def _impl_v1(cls, inputs, attr, params):
 
 
 class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer embeds the input tokens, sums them, and applies layer normalization.
+    """
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         input_ids = inputs[0]
@@ -887,6 +892,12 @@ def _impl_v1(cls, inputs, attr, params):
 
 
 class SkipLayerNormalization(OnnxOpConverter):
+    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer sums the two input tensors (along with optional bias), and applies layer
+    normalization.
+    """
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         data = inputs[0]
@@ -918,6 +929,11 @@ def _impl_v1(cls, inputs, attr, params):
 
 
 class Attention(OnnxOpConverter):
+    """Operator converter for Attention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         num_heads = attr["num_heads"]
@@ -949,7 +965,7 @@ def _impl_v1(cls, inputs, attr, params):
         # (batch, num_heads, seq, seq)
         extra_add = inputs[5]
 
-        (batch_size, seq_len, in_hidden) = infer_shape(input_emb)
+        (batch_size, seq_len, _) = infer_shape(input_emb)
         (out_hidden_x3,) = infer_shape(bias)
         assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
         out_hidden = out_hidden_x3 // 3

From 768e5353bc792b247c8369e0bf720cc59a9e7543 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Fri, 8 Apr 2022 15:27:40 -0700
Subject: [PATCH 08/17] comments

---
 python/tvm/relay/frontend/onnx.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index b5573d022626..126ca7918d3e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -856,7 +856,7 @@ def _impl_v1(cls, inputs, attr, params):
         mask = inputs[7]
         pos_ids = inputs[8]
 
-        eps = attr["epsilon"] if "epsilon" in attr else 1e-12
+        eps = attr.get('epsilon', 1e-12)
 
         (batch_size, seq_len) = infer_shape(input_ids)
 
@@ -4940,6 +4940,9 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        # TODO: We need a better way to handle different domains, in case
+        # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
+        # are in the `com.microsoft` domain.
"EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset), "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset), "Attention": Attention.get_converter(opset), From 29e0c68996410f8f26147a1d10cf77bf1586fe27 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 8 Apr 2022 16:16:32 -0700 Subject: [PATCH 09/17] fix small thing --- python/tvm/relay/frontend/onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 126ca7918d3e..e564f6aba112 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -856,7 +856,7 @@ def _impl_v1(cls, inputs, attr, params): mask = inputs[7] pos_ids = inputs[8] - eps = attr.get('epsilon', 1e-12) + eps = attr.get("epsilon", 1e-12) (batch_size, seq_len) = infer_shape(input_ids) @@ -906,7 +906,7 @@ def _impl_v1(cls, inputs, attr, params): beta = inputs[3] bias = inputs[4] - eps = attr["epsilon"] if "epsilon" in attr else 1e-12 + eps = attr.get("epsilon", 1e-12) x = _op.add(data, skip) if bias is not None: From dbb7df3bbc92a29fdfbb5a89fa576d49bae8f82c Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Sat, 9 Apr 2022 17:21:12 -0700 Subject: [PATCH 10/17] factor out layer norm computation --- python/tvm/relay/frontend/onnx.py | 36 +++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index e564f6aba112..d6ab14ff4e69 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -874,14 +874,7 @@ def _impl_v1(cls, inputs, attr, params): if segment_ids: vec_sum = _op.add(vec_sum, segment_vec) - eps_dtype = infer_type(word_emb).checked_type.dtype - - u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True) - ln = _op.divide( - _op.subtract(vec_sum, u), - _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))), - ) - ln = _op.multiply(ln, gamma) + beta + ln = SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta) mask_index = _op.const(np.zeros((batch_size,), dtype="int64")) if mask: @@ -898,6 +891,21 @@ class SkipLayerNormalization(OnnxOpConverter): normalization. 
""" + @staticmethod + def _compute_layer_norm(x, eps, gamma, beta): + eps_dtype = infer_type(x).checked_type.dtype + + u, s = _op.mean_variance(x, axis=-1, keepdims=True) + output = _op.divide( + _op.subtract(x, u), + _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))), + ) + output = _op.multiply(output, gamma) + if beta is not None: + output = _op.add(output, beta) + + return output + @classmethod def _impl_v1(cls, inputs, attr, params): data = inputs[0] @@ -912,17 +920,9 @@ def _impl_v1(cls, inputs, attr, params): if bias is not None: x = _op.add(x, bias) - eps_dtype = infer_type(x).checked_type.dtype - - u, s = _op.mean_variance(x, axis=-1, keepdims=True) - output = _op.divide( - _op.subtract(x, u), - _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))), - ) - output = _op.multiply(output, gamma) - if beta: - output = _op.add(output, beta) + output = SkipLayerNormalization._compute_layer_norm(x, eps, gamma, beta) + # onnxruntime doesn't compute the other outputs, despite the documentation placeholder = _op.const(0, dtype="float32") return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3) From 265e75312799f3fce4565407e315a9475d7d3b87 Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 11 Apr 2022 10:32:38 -0700 Subject: [PATCH 11/17] layernorm func --- python/tvm/relay/frontend/onnx.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index d6ab14ff4e69..957d848c8591 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -329,6 +329,22 @@ def flatten_to_nd(x, x_shape, nd=3): return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype) +def layer_norm(x, eps, gamma, beta): + """Common function to handle layer norm""" + eps_dtype = infer_type(x).checked_type.dtype + + u, s = _op.mean_variance(x, axis=-1, keepdims=True) + output = _op.divide( + _op.subtract(x, u), + _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))), + ) + output = _op.multiply(output, gamma) + if beta is not None: + output = _op.add(output, beta) + + return output + + class OnnxOpConverter(object): """A helper class for holding onnx op converters.""" @@ -874,7 +890,7 @@ def _impl_v1(cls, inputs, attr, params): if segment_ids: vec_sum = _op.add(vec_sum, segment_vec) - ln = SkipLayerNormalization._compute_layer_norm(vec_sum, eps, gamma, beta) + ln = layer_norm(vec_sum, eps, gamma, beta) mask_index = _op.const(np.zeros((batch_size,), dtype="int64")) if mask: @@ -920,7 +936,7 @@ def _impl_v1(cls, inputs, attr, params): if bias is not None: x = _op.add(x, bias) - output = SkipLayerNormalization._compute_layer_norm(x, eps, gamma, beta) + output = layer_norm(x, eps, gamma, beta) # onnxruntime doesn't compute the other outputs, despite the documentation placeholder = _op.const(0, dtype="float32") From 43296f9c768cc2044b8aba7fc4886dc5e3d3814e Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 11 Apr 2022 12:30:57 -0700 Subject: [PATCH 12/17] add optional args to test --- python/tvm/relay/frontend/onnx.py | 22 +++++++--------------- tests/python/frontend/onnx/test_forward.py | 22 ++++++++++++++++------ 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 957d848c8591..bb863b64f243 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -907,21 +907,6 @@ class SkipLayerNormalization(OnnxOpConverter): normalization. 
""" - @staticmethod - def _compute_layer_norm(x, eps, gamma, beta): - eps_dtype = infer_type(x).checked_type.dtype - - u, s = _op.mean_variance(x, axis=-1, keepdims=True) - output = _op.divide( - _op.subtract(x, u), - _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))), - ) - output = _op.multiply(output, gamma) - if beta is not None: - output = _op.add(output, beta) - - return output - @classmethod def _impl_v1(cls, inputs, attr, params): data = inputs[0] @@ -930,6 +915,10 @@ def _impl_v1(cls, inputs, attr, params): beta = inputs[3] bias = inputs[4] + assert ( + beta is not None and bias is not None + ), "SkipLayerNormalization import currently only supports required beta and bias" + eps = attr.get("epsilon", 1e-12) x = _op.add(data, skip) @@ -990,6 +979,9 @@ def _impl_v1(cls, inputs, attr, params): ), "output hidden size should be divisible by number of attention heads" head_size = out_hidden // num_heads + assert ( + mask_index is not None + ), "Attention import currently only supports required mask_index" mask_index_shape = infer_shape(mask_index) assert ( len(mask_index_shape) == 2 diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1229d5a859fc..1bfc0a43a4d8 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -39,6 +39,10 @@ def get_input_data_shape_dict(graph_def, input_data): shape_dict = {} for i, _ in enumerate(input_data): input_names[i] = graph_def.graph.input[i].name + if input_data[i] is None: + # Skip adding input shape data when the input data is None; + # This is to enable optional arguments for onnx operators. + continue shape_dict[input_names[i]] = input_data[i].shape else: input_names = graph_def.graph.input[0].name @@ -5448,10 +5452,10 @@ def verify_embedlayernormalization( "EmbedLayerNormalization", inputs=[ "input_ids", - "segment_ids", + "" if segment_ids is None else "segment_ids", "word_embedding", "position_embedding", - "segment_embedding", + "" if segment_embedding is None else "segment_embedding", "gamma", "beta", ], @@ -5461,6 +5465,9 @@ def verify_embedlayernormalization( node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4)) + segment_ids_shape = [] if segment_ids is None else segment_ids.shape + segment_embedding_shape = [] if segment_embedding is None else segment_embedding.shape + graph = helper.make_graph( [node], "embedlayernormalization_test", @@ -5468,9 +5475,7 @@ def verify_embedlayernormalization( helper.make_tensor_value_info( "input_ids", TensorProto.INT32, list(input_ids.shape) ), - helper.make_tensor_value_info( - "segment_ids", TensorProto.INT32, list(segment_ids.shape) - ), + helper.make_tensor_value_info("segment_ids", TensorProto.INT32, segment_ids_shape), helper.make_tensor_value_info( "word_embedding", TensorProto.FLOAT, list(word_embedding.shape) ), @@ -5478,7 +5483,7 @@ def verify_embedlayernormalization( "position_embedding", TensorProto.FLOAT, list(position_embedding.shape) ), helper.make_tensor_value_info( - "segment_embedding", TensorProto.FLOAT, list(segment_embedding.shape) + "segment_embedding", TensorProto.FLOAT, segment_embedding_shape ), helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)), helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)), @@ -5537,6 +5542,11 @@ def verify_embedlayernormalization( input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta ) + # Test with undefined segment embedding + 
+    verify_embedlayernormalization(
+        input_ids, None, word_embedding, position_embedding, None, gamma, beta
+    )
+
 
 @tvm.testing.parametrize_targets
 def test_attention(target, dev):

From de7d9406d2e63730b0f4525733beaba64afccb8d Mon Sep 17 00:00:00 2001
From: An Wang
Date: Mon, 11 Apr 2022 14:07:56 -0700
Subject: [PATCH 13/17] upgrade onnxrt version

---
 docker/install/ubuntu_install_onnx.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh
index f94df2d64a17..8b7d7bcf177f 100755
--- a/docker/install/ubuntu_install_onnx.sh
+++ b/docker/install/ubuntu_install_onnx.sh
@@ -28,7 +28,7 @@ set -o pipefail
 # to onnx>=1.9, onnxoptimizer should also be installed.
 pip3 install \
     onnx==1.10.2 \
-    onnxruntime==1.9.0 \
+    onnxruntime==1.11.0 \
     onnxoptimizer==0.2.6
 
 # torch depends on a number of other packages, but unhelpfully, does

From 93aceb21fd23f4ae5e8d907ebd3a3d99869a5726 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Mon, 11 Apr 2022 14:11:58 -0700
Subject: [PATCH 14/17] no upgrade onnx

---
 docker/install/ubuntu_install_onnx.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh
index 8b7d7bcf177f..f94df2d64a17 100755
--- a/docker/install/ubuntu_install_onnx.sh
+++ b/docker/install/ubuntu_install_onnx.sh
@@ -28,7 +28,7 @@ set -o pipefail
 # to onnx>=1.9, onnxoptimizer should also be installed.
 pip3 install \
     onnx==1.10.2 \
-    onnxruntime==1.11.0 \
+    onnxruntime==1.9.0 \
     onnxoptimizer==0.2.6
 
 # torch depends on a number of other packages, but unhelpfully, does

From dfba87e1efc120e13ac7271b33ea6dd7f803de1e Mon Sep 17 00:00:00 2001
From: An Wang
Date: Mon, 11 Apr 2022 14:39:18 -0700
Subject: [PATCH 15/17] fix tests

---
 python/tvm/relay/frontend/onnx.py          |  3 ++-
 tests/python/frontend/onnx/test_forward.py | 17 +++++++----------
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index bb863b64f243..1129794ce83d 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -897,7 +897,8 @@ def _impl_v1(cls, inputs, attr, params):
             # calculate number of words per sentence
             mask_index = _op.sum(mask, axis=1)
 
-        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
+        # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2)
 
 
 class SkipLayerNormalization(OnnxOpConverter):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 1bfc0a43a4d8..7019088579a1 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -39,7 +39,7 @@ def get_input_data_shape_dict(graph_def, input_data):
         shape_dict = {}
         for i, _ in enumerate(input_data):
             input_names[i] = graph_def.graph.input[i].name
-            if input_data[i] is None:
+            if input_data[i] is None or len(input_data[i]) == 0:
                 # Skip adding input shape data when the input data is None;
                 # This is to enable optional arguments for onnx operators.
                 continue
@@ -5459,7 +5459,7 @@ def verify_embedlayernormalization(
                 "gamma",
                 "beta",
             ],
-            outputs=["output", "mask_index", "embedding_sum"],
+            outputs=["output", "mask_index"],
             domain="com.microsoft",
         )
 
@@ -5493,30 +5493,27 @@ def verify_embedlayernormalization(
                     "output", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size))
                 ),
                 helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size]),
-                helper.make_tensor_value_info(
-                    "embedding_sum",
-                    TensorProto.FLOAT,
-                    list((batch_size, sequence_length, hidden_size)),
-                ),
             ],
         )
 
         model = helper.make_model(graph, producer_name="embedlayernormalization_test")
+
+        # TODO(@anwang2009): onnxruntime v1.9.0 requires empty list for optional argument,
+        # but v1.10.0+ requires None instead.
         verify_with_ort_with_inputs(
             model,
             [
                 input_ids,
-                segment_ids,
+                [] if segment_ids is None else segment_ids,
                 word_embedding,
                 position_embedding,
-                segment_embedding,
+                [] if segment_embedding is None else segment_embedding,
                 gamma,
                 beta,
             ],
             [
                 (batch_size, sequence_length, hidden_size),
                 batch_size,
-                (batch_size, sequence_length, hidden_size),
             ],
             target=target,
             dev=dev,

From 989b41249d3ba8585bfd93ed581d9092395e1500 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Mon, 11 Apr 2022 14:56:43 -0700
Subject: [PATCH 16/17] int32

---
 python/tvm/relay/frontend/onnx.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 1129794ce83d..3450d489af70 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -880,7 +880,7 @@ def _impl_v1(cls, inputs, attr, params):
             assert segment_emb
 
         if pos_ids is None:
-            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int64")
+            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int32")
 
         word_vec = _op.take(word_emb, input_ids, axis=0)
         segment_vec = _op.take(segment_emb, segment_ids, axis=0) if segment_ids else None
@@ -892,7 +892,7 @@ def _impl_v1(cls, inputs, attr, params):
 
         ln = layer_norm(vec_sum, eps, gamma, beta)
 
-        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int32"))
         if mask:
             # calculate number of words per sentence
             mask_index = _op.sum(mask, axis=1)

From 16a4d096dcae82ad09484b577291cbe9a6601522 Mon Sep 17 00:00:00 2001
From: An Wang
Date: Tue, 12 Apr 2022 13:32:32 -0700
Subject: [PATCH 17/17] fix tests

---
 tests/python/frontend/onnx/test_forward.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 7019088579a1..9a8647af48c1 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -39,7 +39,7 @@ def get_input_data_shape_dict(graph_def, input_data):
         shape_dict = {}
         for i, _ in enumerate(input_data):
             input_names[i] = graph_def.graph.input[i].name
-            if input_data[i] is None or len(input_data[i]) == 0:
+            if input_data[i] is None or input_data[i].shape == ():
                 # Skip adding input shape data when the input data is None;
                 # This is to enable optional arguments for onnx operators.
                 continue
@@ -5504,10 +5504,10 @@ def verify_embedlayernormalization(
             model,
             [
                 input_ids,
-                [] if segment_ids is None else segment_ids,
+                np.empty(0, dtype="int32") if segment_ids is None else segment_ids,
                 word_embedding,
                 position_embedding,
-                [] if segment_embedding is None else segment_embedding,
+                np.empty(0, dtype="float32") if segment_embedding is None else segment_embedding,
                 gamma,
                 beta,
             ],
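
Note (not part of the patch series): end to end, the new converters are exercised through the regular ONNX frontend. A minimal sketch of importing a graph containing a `com.microsoft` Attention node once these patches are applied; the graph and producer names are illustrative, and the shapes mirror test_attention above:

from onnx import TensorProto, helper
from tvm import relay

batch, seq, hidden, num_heads = 4, 4, 384, 12
head_size = hidden // num_heads

# Build a one-node ONNX graph with the contrib Attention op.
node = helper.make_node(
    "Attention",
    inputs=["input", "weight", "bias", "mask_index"],
    outputs=["output", "present"],
    domain="com.microsoft",
    num_heads=num_heads,
)
graph = helper.make_graph(
    [node],
    "attention_demo",
    inputs=[
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
        helper.make_tensor_value_info("weight", TensorProto.FLOAT, [hidden, 3 * hidden]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3 * hidden]),
        helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch, seq]),
    ],
    outputs=[
        helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden]),
        helper.make_tensor_value_info(
            "present", TensorProto.FLOAT, [2, batch, num_heads, seq, head_size]
        ),
    ],
)
model = helper.make_model(graph, producer_name="attention_demo")

# The frontend maps the node through the Attention converter added above.
mod, params = relay.frontend.from_onnx(
    model,
    shape={
        "input": (batch, seq, hidden),
        "weight": (hidden, 3 * hidden),
        "bias": (3 * hidden,),
        "mask_index": (batch, seq),
    },
)
print(mod)  # Relay IR containing the lowered attention computation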