Skip to content

Commit b24fbe7

Browse files
author
liyuchao
authored
Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair (#7845)
* [AutoScheduler] Fix incorrect array context device and hide info at the beginning * Lint fix * Lint fix * update repo * Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair * update measure.py * Lint fix * fix bug && add ut for pytorch matmul * update ut * Lint fix * update commit * Lint fix
1 parent 36b7dd9 commit b24fbe7

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1580,7 +1580,7 @@ def matmul(self, inputs, input_types):
15801580
b_shape = self.infer_shape_with_prelude(inputs_1)
15811581

15821582
# When performing a batch matmul, we need to properly handle N-dim shapes.
1583-
if len(a_shape) > 2 or len(b_shape) > 2:
1583+
if len(a_shape) > 2 and len(b_shape) > 2:
15841584
# Convert a into a 3 dimensional tensors.
15851585
need_reshape_output = False
15861586
if len(a_shape) != 3:
@@ -1606,18 +1606,32 @@ def matmul(self, inputs, input_types):
16061606
if need_reshape_output:
16071607
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
16081608
return output
1609+
elif len(a_shape) > 2:
1610+
inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
16091611

1610-
# Otherwise a simple dense op will get the job done.
1611-
if len(b_shape) == 1:
1612-
input_1 = _op.expand_dims(inputs_1, 0, 1)
1613-
else:
1612+
if len(b_shape) > 2:
1613+
trans_axes = list(range(len(b_shape)))
1614+
trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
1615+
input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
1616+
elif len(b_shape) == 2:
16141617
input_1 = _op.transpose(inputs_1, axes=(1, 0))
1618+
elif len(b_shape) == 1:
1619+
input_1 = _op.expand_dims(inputs_1, 0, 1)
16151620

16161621
out = _op.nn.dense(inputs_0, input_1)
16171622

16181623
if len(b_shape) == 1:
16191624
out = _op.squeeze(out, axis=[-1])
16201625

1626+
# Reshape output into a N dimensional tensor when a or b dim > 2
1627+
if len(a_shape) > 2:
1628+
out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
1629+
elif len(b_shape) > 2:
1630+
out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
1631+
out = _op.reshape(
1632+
_op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
1633+
)
1634+
16211635
return out
16221636

16231637
def expand(self, inputs, input_types):

tests/python/frontend/pytorch/test_forward.py

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -162,7 +162,9 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
162162
return est
163163

164164

165-
def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
165+
def verify_model(
166+
model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]
167+
):
166168
"""Assert that the output of a compiled model matches with that of its
167169
baseline."""
168170
if isinstance(model_name, str):
@@ -219,6 +221,20 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
219221

220222
assert_shapes_match(baseline_output, compiled_output)
221223
tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
224+
225+
if expected_ops:
226+
227+
def visit(op):
228+
if isinstance(op, tvm.ir.op.Op):
229+
if op.name in expected_ops:
230+
expected_ops.remove(op.name)
231+
232+
tvm.relay.analysis.post_order_visit(mod["main"].body, visit)
233+
234+
if expected_ops:
235+
msg = "TVM Relay do not contain expected ops {}"
236+
raise AssertionError(msg.format(expected_ops))
237+
222238
del model_name
223239
del baseline_model
224240
torch.cuda.empty_cache()
@@ -3304,17 +3320,24 @@ def forward(self, *args):
33043320
# matrix x matrix
33053321
tensor1 = torch.randn(10, 4)
33063322
tensor2 = torch.randn(4, 10)
3307-
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
3323+
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
33083324

33093325
# batched matrix x batched matrix
33103326
tensor1 = torch.randn(10, 3, 4)
33113327
tensor2 = torch.randn(10, 4, 5)
3312-
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
3328+
verify_model(
3329+
MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
3330+
)
33133331

33143332
# batched matrix x broadcasted matrix
33153333
tensor1 = torch.randn(10, 3, 4)
33163334
tensor2 = torch.randn(4, 5)
3317-
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
3335+
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
3336+
3337+
# broadcasted matrix x batched matrix
3338+
tensor1 = torch.randn(10, 4)
3339+
tensor2 = torch.randn(3, 4, 5)
3340+
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
33183341

33193342
# batched matrix x batched matrix
33203343
tensor1 = torch.randn(1, 12, 14, 64)

0 commit comments

Comments (0)