From 98f470fa2cbce9cacbc39e5a3c5c2faee337e95f Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Mon, 7 Aug 2023 20:17:40 +0900
Subject: [PATCH 01/11] add handling logic for aten::copy_

---
 python/tvm/relay/frontend/pytorch.py | 69 ++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index eadd0a3c464f..bea781f9c033 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3743,6 +3743,54 @@ def weight_norm(self, inputs, input_types):
             reci_order,
         )
         return weight_g * (weight_v / norm_v)
+
+    def inplace_copy(self, inputs, input_types):
+        source = inputs[0]
+        slice_and_select_calls = []
+        while True:
+            if isinstance(source, _expr.Call) and source.op.name in [
+                "strided_slice",
+                "take",
+                "expand_dims",
+            ]:
+                slice_and_select_calls.append(source)
+                source = source.args[0]
+            else:
+                break
+        slice_and_select_calls = slice_and_select_calls[::-1]
+        source_shape = _infer_shape(source)
+        indices = [[j for j in range(source_shape[i])] for i in range(len(source_shape))]
+
+        for call in slice_and_select_calls:
+            if call.op.name == "strided_slice":
+                """
+                spec. of strided_slice:
+                https://tvm.apache.org/docs/reference/api/python/relay/index.html
+                """
+                axes = call.attrs.axes
+                if axes is None:
+                    axes = [i for i in range(len(source_shape))]
+                begins = call.attrs.begin
+                ends = call.attrs.end
+                for axis, begin, end in zip(axes, begins, ends):
+                    if begin < 0:
+                        begin = source_shape[axis] + begin
+                    if end < 0:
+                        end = source_shape[axis] + end
+                    indices[axis] = [v for v in indices[axis] if begin <= v < end]
+            elif call.op.name == "take":
+                axis = call.attrs.axis.value
+                idx = call.args[1]
+                assert isinstance(idx, _expr.Constant)
+                idx = idx.data.numpy().item()
+                indices[axis] = [idx]
+            else:
+                pass
+        indices = tuple([_expr.const(i) for i in indices])
+        return self.index_put(
+            [source, indices, inputs[1], inputs[2]],
+            [input_types[0], "int64", input_types[1], input_types[2]],
+        )
 
     # Operator mappings
     def create_convert_map(self):
@@ -4015,6 +4063,7 @@ def create_convert_map(self):
             "aten::__rshift__": self.make_elemwise("right_shift"),
             "aten::multinomial": self.multinomial,
             "aten::_weight_norm": self.weight_norm,
+            "aten::copy_": self.inplace_copy,
         }
 
     def update_convert_map(self, custom_map):
@@ -4459,6 +4508,25 @@ def _run_jit_passes(graph, enable_lower_all_tuples=True):
         torch._C._jit_pass_lower_all_tuples(graph)
 
 
+def _redirect_inplace_output(graph):
+    for node in graph.nodes():
+        if node.kind() == "aten::copy_":
+            node_inputs = list(node.inputs())
+            src_node = node_inputs[0].node()
+            slice_and_select_nodes = []
+            while True:
+                if src_node.kind() in ["aten::slice", "aten::select", "aten::unsqueeze"]:
+                    src_node = list(src_node.inputs())[0].node()
+                    slice_and_select_nodes.append(src_node)
+                else:
+                    break
+            if src_node.kind() == "prim::Param":
+                # First one is "self"
+                src_value = list(src_node.outputs())[1]
+            else:
+                src_value = src_node.output()
+            src_value.replaceAllUsesAfterNodeWith(node, node.output())
+
 def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
@@ -4960,6 +5028,7 @@ def from_pytorch(
             break
 
     _run_jit_passes(graph, enable_lower_all_tuples)
+    _redirect_inplace_output(graph)
 
     if custom_convert_map:
         converter.update_convert_map(custom_convert_map)

From 5c2263516d0cbbd3f604fed126b1eec0b3acffda Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Mon, 7 Aug 2023 20:35:55 +0900
Subject: [PATCH 02/11] lint

---
 python/tvm/relay/frontend/pytorch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index bea781f9c033..63b2ee2d8248 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3743,7 +3743,7 @@ def weight_norm(self, inputs, input_types):
             reci_order,
         )
         return weight_g * (weight_v / norm_v)
-    
+
     def inplace_copy(self, inputs, input_types):
         source = inputs[0]
         slice_and_select_calls = []
@@ -4527,6 +4527,7 @@ def _redirect_inplace_output(graph):
                 src_value = src_node.output()
             src_value.replaceAllUsesAfterNodeWith(node, node.output())
 
+
 def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)

From f37050e1605e130f835d3515447d59ff024698e4 Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Mon, 7 Aug 2023 21:14:09 +0900
Subject: [PATCH 03/11] add test case

---
 tests/python/frontend/pytorch/test_forward.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index cb49e837fe6e..7c998c3eab54 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5330,6 +5330,19 @@ def forward(self, *args):
     assert "%aten::_convolution_0" in graph
 
 
+def test_inplace_copy():
+    class InplaceCopy(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, x):
+            x[:5, 0] = x[:5, 0] + 1
+            return x
+
+    inputs = torch.randn(10, 10)
+    verify_model(InplaceCopy(), [inputs])
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""

From aed8863015a311085f2829c05dcf48cbd44d1e1e Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Tue, 8 Aug 2023 13:23:59 +0900
Subject: [PATCH 04/11] lint

---
 python/tvm/relay/frontend/pytorch.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 63b2ee2d8248..15a89f54760b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3759,17 +3759,13 @@ def inplace_copy(self, inputs, input_types):
                 break
         slice_and_select_calls = slice_and_select_calls[::-1]
         source_shape = _infer_shape(source)
-        indices = [[j for j in range(source_shape[i])] for i in range(len(source_shape))]
+        indices = [list(range(source_shape[i])) for i in range(len(source_shape))]
 
         for call in slice_and_select_calls:
             if call.op.name == "strided_slice":
-                """
-                spec. of strided_slice:
-                https://tvm.apache.org/docs/reference/api/python/relay/index.html
-                """
                 axes = call.attrs.axes
                 if axes is None:
-                    axes = [i for i in range(len(source_shape))]
+                    axes = list(range(len(source_shape)))
                 begins = call.attrs.begin
                 ends = call.attrs.end
                 for axis, begin, end in zip(axes, begins, ends):

From 3210e94c8363d9714e7368f6fa6b495b393dfd7e Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Tue, 8 Aug 2023 14:19:16 +0900
Subject: [PATCH 05/11] remove __init__

---
 tests/python/frontend/pytorch/test_forward.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 7c998c3eab54..f5d87d83cbd4 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5332,9 +5332,6 @@ def test_inplace_copy():
     class InplaceCopy(torch.nn.Module):
-        def __init__(self) -> None:
-            super().__init__()
-
         def forward(self, x):
             x[:5, 0] = x[:5, 0] + 1
             return x

From 4aac330466a9acd749ad9457282f50227f78d178 Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Wed, 11 Oct 2023 18:06:06 +0900
Subject: [PATCH 06/11] fix logic

---
 python/tvm/relay/frontend/pytorch.py | 63 +++++++++++++++++++++++-----
 1 file changed, 53 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 15a89f54760b..b4150209097d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3746,12 +3746,19 @@ def weight_norm(self, inputs, input_types):
 
     def inplace_copy(self, inputs, input_types):
         source = inputs[0]
+        values = inputs[1]
+        accumulate = inputs[2]
+        if not accumulate:
+            mode = "update"
+        else:
+            mode = "add"
+
+        # Track slice and select calls
         slice_and_select_calls = []
         while True:
             if isinstance(source, _expr.Call) and source.op.name in [
                 "strided_slice",
                 "take",
-                "expand_dims",
             ]:
                 slice_and_select_calls.append(source)
                 source = source.args[0]
             else:
                 break
         slice_and_select_calls = slice_and_select_calls[::-1]
         source_shape = _infer_shape(source)
 
+        # Create index map
+        index_map = {}
+        squeezed_axes = []
         for call in slice_and_select_calls:
             if call.op.name == "strided_slice":
                 axes = call.attrs.axes
                 if axes is None:
                     axes = list(range(len(source_shape)))
                 begins = call.attrs.begin
                 ends = call.attrs.end
                 for axis, begin, end in zip(axes, begins, ends):
+                    num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
+                    axis += num_squeezed_axis
+                    # Set range
                     if begin < 0:
                         begin = source_shape[axis] + begin
                     if end < 0:
                         end = source_shape[axis] + end
+                    if begin == 0 and end == source_shape[axis]:
+                        continue
+                    index_map[axis] = (begin.value, end.value)
             elif call.op.name == "take":
+                num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
+                axis = call.attrs.axis.value + num_squeezed_axis
                 idx = call.args[1]
                 assert isinstance(idx, _expr.Constant)
                 idx = idx.data.numpy().item()
+                index_map[axis] = (idx, idx + 1)
+                values = _op.expand_dims(values, axis)
+                squeezed_axes.append(axis)
             else:
                 pass
+        last_index_dim = np.max([k for k in index_map]).item()
+        for axis in range(last_index_dim + 1):
+            if axis not in index_map:
+                index_map[axis] = (0, source_shape[axis])
+
+        # Create indices
+        nelem = 1
+        for (begin, end) in index_map.values():
+            nelem *= end - begin
+        chunk_sizes = [nelem]
+        for i in range(1, last_index_dim + 1):
+            begin, end = index_map[i - 1]
+            chunk_sizes.append(chunk_sizes[-1] // (end - begin))
+        indices = []
+        for axis in range(last_index_dim + 1):
+            chunk_size = chunk_sizes[axis]
+            repeat = nelem // chunk_size
+            begin, end = index_map[axis]
+            step_size = chunk_size // (end - begin)
+            chunk = np.repeat(np.arange(begin, end), step_size)
+            chunk = np.concatenate([chunk] * repeat)
+            indices.append(chunk)
+        indices = np.stack(indices, axis=0).astype(np.int64)
+        new_shape = [indices.shape[0]] + [
+            index_map[i][1] - index_map[i][0] for i in range(last_index_dim + 1)
+        ]
+        indices = np.resize(indices, new_shape)
+        indices = _expr.const(indices)
+
+        # Return
+        return _op.scatter_nd(source, indices, values, mode)
 
     # Operator mappings
     def create_convert_map(self):

From ccec5c06f5f66744aca40352efffd72b26bfd425 Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Wed, 11 Oct 2023 18:58:38 +0900
Subject: [PATCH 07/11] lint

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b4150209097d..e4310d53fd78 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3803,7 +3803,7 @@ def inplace_copy(self, inputs, input_types):
         for axis in range(last_index_dim + 1):
             if axis not in index_map:
                 index_map[axis] = (0, source_shape[axis])
-        
+
         # Create indices
         nelem = 1
         for (begin, end) in index_map.values():

From 3c661a9ef3a6535018092a3b2f0d5fb0f2a4af24 Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Wed, 11 Oct 2023 20:20:35 +0900
Subject: [PATCH 08/11] lint

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index e4310d53fd78..43e57edd4d85 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3802,7 +3802,7 @@ def inplace_copy(self, inputs, input_types):
         last_index_dim = np.max([k for k in index_map]).item()
         for axis in range(last_index_dim + 1):
             if axis not in index_map:
-                index_map[axis] = (0, source_shape[axis])
+                index_map[axis] = 0, source_shape[axis]
 
         # Create indices
         nelem = 1

From b8b3aae07600036234837afa4e85d7fe7656e7c9 Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Wed, 11 Oct 2023 20:47:48 +0900
Subject: [PATCH 09/11] lint

---
 python/tvm/relay/frontend/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 43e57edd4d85..8385da85b0f1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3799,7 +3799,7 @@ def inplace_copy(self, inputs, input_types):
                 squeezed_axes.append(axis)
             else:
                 pass
-        last_index_dim = np.max([k for k in index_map]).item()
+        last_index_dim = np.max(list(index_map)).item()
         for axis in range(last_index_dim + 1):
             if axis not in index_map:
                 index_map[axis] = 0, source_shape[axis]

From 8f1735ebc57e811b7034cc89b5a31b2cf0758957 Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Tue, 17 Oct 2023 16:30:54 +0900
Subject: [PATCH 10/11] feedback

---
 python/tvm/relay/frontend/pytorch.py          | 15 ++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 24 +++++++++++++++----
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 61c3607526f5..080107d8931f 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3797,6 +3797,8 @@ def inplace_copy(self, inputs, input_types):
                 idx = call.args[1]
                 assert isinstance(idx, _expr.Constant)
                 idx = idx.data.numpy().item()
+                if idx < 0:
+                    idx = source_shape[axis] + idx
                 index_map[axis] = (idx, idx + 1)
                 values = _op.expand_dims(values, axis)
                 squeezed_axes.append(axis)
@@ -4559,6 +4561,19 @@ def _run_jit_passes(graph, enable_lower_all_tuples=True):
 
 
 def _redirect_inplace_output(graph):
+    """
+    This pass redirects the output node of the in-place op i.e. aten::copy_.
+    Before:
+        %1: ...
+        %2: ...
+        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+        return (%input)
+    After:
+        %1: ...
+        %2: ...
+        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+        return (%3)
+    """
     for node in graph.nodes():
         if node.kind() == "aten::copy_":
             node_inputs = list(node.inputs())
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 90d0485cd9cc..eb4f2ab41782 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5356,13 +5356,29 @@ def forward(self, *args):
 
 
 def test_inplace_copy():
-    class InplaceCopy(torch.nn.Module):
+    class SimpleInplaceCopy(torch.nn.Module):
         def forward(self, x):
-            x[:5, 0] = x[:5, 0] + 1
+            x[:5, 0, 5:] = x[:5, 0, 5:] + 1
+            return x
+
+    class NegativeSliceInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[5:-1, -1, :] = x[5:-1, -1, :] + 1
+            return x
+
+    class PartialDimensionInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[:5] = x[:5] + 1
+            x[0:5, ...] = x[0:5, ...] + 1
+            x[0:5, ..., -1] = x[0:5, ..., -1] + 1
             return x
 
-    inputs = torch.randn(10, 10)
-    verify_model(InplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(SimpleInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(NegativeSliceInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(PartialDimensionInplaceCopy(), [inputs])

From 7d7da85936d028a9d6c43ea176f06fd8275daa19 Mon Sep 17 00:00:00 2001
From: jhlee525
Date: Tue, 17 Oct 2023 19:06:18 +0900
Subject: [PATCH 11/11] lint

---
 python/tvm/relay/frontend/pytorch.py          | 4 ++--
 tests/python/frontend/pytorch/test_forward.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 080107d8931f..81392a08ecd1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4565,12 +4565,12 @@ def _redirect_inplace_output(graph):
     This pass redirects the output node of the in-place op i.e. aten::copy_.
     Before:
         %1: ...
-        %2: ... 
+        %2: ...
        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
         return (%input)
     After:
         %1: ...
-        %2: ... 
+        %2: ...
        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
         return (%3)
     """
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index eb4f2ab41782..abdbda8e4005 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5360,7 +5360,7 @@ class SimpleInplaceCopy(torch.nn.Module):
         def forward(self, x):
             x[:5, 0, 5:] = x[:5, 0, 5:] + 1
             return x
-    
+
     class NegativeSliceInplaceCopy(torch.nn.Module):
         def forward(self, x):
             x[5:-1, -1, :] = x[5:-1, -1, :] + 1
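
Note for reviewers: the graph shape that _redirect_inplace_output walks can be inspected directly from TorchScript. A minimal sketch, assuming only stock PyTorch; the module and input shape mirror the original test case, and nothing here is part of the patches themselves:

    import torch

    class InplaceCopy(torch.nn.Module):
        def forward(self, x):
            x[:5, 0] = x[:5, 0] + 1
            return x

    # Scripting/tracing turns the in-place indexed assignment into a chain of
    # aten::slice / aten::select nodes feeding aten::copy_; the pass above
    # follows that chain back to the source value and rewires its later uses
    # to the aten::copy_ output so the mutation stays visible downstream.
    traced = torch.jit.trace(InplaceCopy(), torch.randn(10, 10))
    print(traced.graph)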
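The chunk_sizes/repeat arithmetic introduced in patch 06 enumerates every destination coordinate in row-major order before handing them to scatter_nd. The standalone numpy sketch below re-derives those indices for the x[:5, 0, 5:] pattern used by SimpleInplaceCopy and checks them against a plain in-place assignment; scatter_nd's "update" mode is emulated with fancy indexing, and the (10, 10, 10) shape is assumed from the test:

    import numpy as np

    # index_map as inplace_copy would build it: axis 0 keeps [0, 5),
    # axis 1 is the selected index 0 (re-expanded via expand_dims),
    # axis 2 keeps [5, 10).
    index_map = {0: (0, 5), 1: (0, 1), 2: (5, 10)}
    last_index_dim = max(index_map)

    nelem = 1
    for begin, end in index_map.values():
        nelem *= end - begin
    chunk_sizes = [nelem]
    for i in range(1, last_index_dim + 1):
        begin, end = index_map[i - 1]
        chunk_sizes.append(chunk_sizes[-1] // (end - begin))

    indices = []
    for axis in range(last_index_dim + 1):
        begin, end = index_map[axis]
        step_size = chunk_sizes[axis] // (end - begin)
        chunk = np.repeat(np.arange(begin, end), step_size)
        indices.append(np.concatenate([chunk] * (nelem // chunk_sizes[axis])))
    indices = np.stack(indices, axis=0)  # (3, 25): one column per destination

    # scatter_nd with mode="update" writes values.ravel()[k] to
    # out[indices[:, k]]; that must match the eager in-place assignment.
    source = np.random.randn(10, 10, 10)
    values = np.random.randn(5, 1, 5)
    expected = source.copy()
    expected[:5, 0, 5:] = values[:, 0, :]
    out = source.copy()
    out[tuple(indices)] = values.ravel()
    assert np.allclose(out, expected)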