124 changes: 124 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -3747,6 +3747,95 @@ def weight_norm(self, inputs, input_types):
    )
    return weight_g * (weight_v / norm_v)

def inplace_copy(self, inputs, input_types):
    """Convert aten::copy_, lowering an in-place slice/select assignment to scatter_nd."""
    source = inputs[0]
    values = inputs[1]
    accumulate = inputs[2]
    if not accumulate:
        mode = "update"
    else:
        mode = "add"

    # Track slice and select calls
    slice_and_select_calls = []
    while True:
        if isinstance(source, _expr.Call) and source.op.name in [
            "strided_slice",
            "take",
        ]:
            slice_and_select_calls.append(source)
            source = source.args[0]
        else:
            break
    slice_and_select_calls = slice_and_select_calls[::-1]
    source_shape = _infer_shape(source)

    # Create index map
    index_map = {}
    squeezed_axes = []
    for call in slice_and_select_calls:
        if call.op.name == "strided_slice":
            axes = call.attrs.axes
            if axes is None:
                axes = list(range(len(source_shape)))
            begins = call.attrs.begin
            ends = call.attrs.end
            for axis, begin, end in zip(axes, begins, ends):
                num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
                axis += num_squeezed_axis
                # Set range
                if begin < 0:
                    begin = source_shape[axis] + begin
                if end < 0:
                    end = source_shape[axis] + end
                if begin == 0 and end == source_shape[axis]:
                    continue
                index_map[axis] = (begin.value, end.value)
        elif call.op.name == "take":
            axis = call.attrs.axis.value
            num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
            axis += num_squeezed_axis
            idx = call.args[1]
            assert isinstance(idx, _expr.Constant)
            idx = idx.data.numpy().item()
            if idx < 0:
                idx = source_shape[axis] + idx
            index_map[axis] = (idx, idx + 1)
            values = _op.expand_dims(values, axis)
            squeezed_axes.append(axis)

    last_index_dim = np.max(list(index_map)).item()
    for axis in range(last_index_dim + 1):
        if axis not in index_map:
            index_map[axis] = (0, source_shape[axis])

    # Create indices
    nelem = 1
    for begin, end in index_map.values():
        nelem *= end - begin
    chunk_sizes = [nelem]
    for i in range(1, last_index_dim + 1):
        begin, end = index_map[i - 1]
        chunk_sizes.append(chunk_sizes[-1] // (end - begin))
    indices = []
    for axis in range(last_index_dim + 1):
        chunk_size = chunk_sizes[axis]
        repeat = nelem // chunk_size
        begin, end = index_map[axis]
        step_size = chunk_size // (end - begin)
        chunk = np.repeat(np.arange(begin, end), step_size)
        chunk = np.concatenate([chunk] * repeat)
        indices.append(chunk)
    indices = np.stack(indices, axis=0).astype(np.int64)
    new_shape = [indices.shape[0]] + [
        index_map[i][1] - index_map[i][0] for i in range(last_index_dim + 1)
    ]
    indices = np.resize(indices, new_shape)
    indices = _expr.const(indices)

    # Return
    return _op.scatter_nd(source, indices, values, mode)
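For readers following the index construction above, here is a standalone NumPy sketch of what `inplace_copy` computes for `x[:5, 0, 5:] = x[:5, 0, 5:] + 1` on a `(10, 10, 10)` tensor. The concrete `index_map` and shapes are illustrative assumptions, not part of the patch, and the final fancy-indexing assignment mimics `scatter_nd` in "update" mode:

```python
import numpy as np

# Illustrative index_map for x[:5, 0, 5:] on a (10, 10, 10) tensor:
# axis -> (begin, end); the select on axis 1 becomes the range (0, 1).
index_map = {0: (0, 5), 1: (0, 1), 2: (5, 10)}
last_index_dim = max(index_map)

nelem = 1
for begin, end in index_map.values():
    nelem *= end - begin  # 5 * 1 * 5 = 25 updated elements

chunk_sizes = [nelem]
for i in range(1, last_index_dim + 1):
    begin, end = index_map[i - 1]
    chunk_sizes.append(chunk_sizes[-1] // (end - begin))  # [25, 5, 5]

indices = []
for axis in range(last_index_dim + 1):
    chunk_size = chunk_sizes[axis]
    repeat = nelem // chunk_size
    begin, end = index_map[axis]
    step_size = chunk_size // (end - begin)
    chunk = np.repeat(np.arange(begin, end), step_size)
    indices.append(np.concatenate([chunk] * repeat))

# One coordinate row per axis, reshaped onto the update region.
indices = np.stack(indices).reshape(3, 5, 1, 5)

# scatter_nd(..., mode="update") written as a NumPy fancy-indexing assignment.
x = np.random.randn(10, 10, 10).astype("float32")
values = x[:5, 0:1, 5:] + 1  # expand_dims keeps the selected axis
out = x.copy()
out[tuple(indices)] = values
assert np.allclose(out[:5, 0, 5:], x[:5, 0, 5:] + 1)
```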

# Operator mappings
def create_convert_map(self):
    self.convert_map = {
@@ -4018,6 +4107,7 @@ def create_convert_map(self):
"aten::__rshift__": self.make_elemwise("right_shift"),
"aten::multinomial": self.multinomial,
"aten::_weight_norm": self.weight_norm,
"aten::copy_": self.inplace_copy,
}
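A minimal sketch of how the new mapping is reached from the public API; the module and input shape here are made-up examples, not part of the patch:

```python
import torch
from tvm import relay

class InplaceAdd(torch.nn.Module):
    def forward(self, x):
        # Traced as aten::slice / aten::select views feeding aten::copy_
        x[:5, 0, 5:] = x[:5, 0, 5:] + 1
        return x

inp = torch.randn(10, 10, 10)
traced = torch.jit.trace(InplaceAdd(), inp)
# aten::copy_ is dispatched to inplace_copy, which emits relay scatter_nd
mod, params = relay.frontend.from_pytorch(traced, [("x", inp.shape)])
```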

def update_convert_map(self, custom_map):
@@ -4470,6 +4560,39 @@ def _run_jit_passes(graph, enable_lower_all_tuples=True):
        torch._C._jit_pass_lower_all_tuples(graph)


def _redirect_inplace_output(graph):
[Review, Member]: Please give an example of what this pass does, by documenting IR before / after this pass.
[Reply, Contributor Author]: Ok, an example added.
"""
This pass redirects the output node of the in-place op i.e. aten::copy_.
Before:
%1: ...
%2: ...
%3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
return (%input)
After:
%1: ...
%2: ...
%3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
return (%3)
"""
    for node in graph.nodes():
        if node.kind() == "aten::copy_":
            node_inputs = list(node.inputs())
            # Walk past any slice/select/unsqueeze views to the true source
            src_node = node_inputs[0].node()
            while src_node.kind() in ["aten::slice", "aten::select", "aten::unsqueeze"]:
                src_node = list(src_node.inputs())[0].node()
            if src_node.kind() == "prim::Param":
                # The first output is "self"
                src_value = list(src_node.outputs())[1]
            else:
                src_value = src_node.output()
            src_value.replaceAllUsesAfterNodeWith(node, node.output())
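To see the rewrite in action, a small sketch that prints the traced graph before and after the pass; the module is hypothetical, and the helper is imported here as a private function of this module:

```python
import torch
from tvm.relay.frontend.pytorch import _redirect_inplace_output

class InplaceAdd(torch.nn.Module):
    def forward(self, x):
        x[0] = x[0] + 1  # traced as aten::select followed by aten::copy_
        return x

graph = torch.jit.trace(InplaceAdd(), torch.randn(3, 3)).graph
print(graph)  # before: the graph output still names the mutated input
_redirect_inplace_output(graph)
print(graph)  # after: the return refers to the aten::copy_ output
```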


def _get_tensor_and_var(torch_tensor, name):
    tensor = tvm.nd.array(torch_tensor.cpu().numpy())
    var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
@@ -4971,6 +5094,7 @@ def from_pytorch(
            break

    _run_jit_passes(graph, enable_lower_all_tuples)
    _redirect_inplace_output(graph)

    if custom_convert_map:
        converter.update_convert_map(custom_convert_map)
26 changes: 26 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -5355,6 +5355,32 @@ def forward(self, *args):
assert "%aten::_convolution_0" in graph


def test_inplace_copy():
    class SimpleInplaceCopy(torch.nn.Module):
        def forward(self, x):
            x[:5, 0, 5:] = x[:5, 0, 5:] + 1
            return x

    class NegativeSliceInplaceCopy(torch.nn.Module):
        def forward(self, x):
            x[5:-1, -1, :] = x[5:-1, -1, :] + 1
            return x

    class PartialDimensionInplaceCopy(torch.nn.Module):
        def forward(self, x):
            x[:5] = x[:5] + 1
            x[0:5, ...] = x[0:5, ...] + 1
            x[0:5, ..., -1] = x[0:5, ..., -1] + 1
            return x

    inputs = torch.randn(10, 10, 10)
    verify_model(SimpleInplaceCopy(), [inputs])
    inputs = torch.randn(10, 10, 10)
    verify_model(NegativeSliceInplaceCopy(), [inputs])
    inputs = torch.randn(10, 10, 10)
    verify_model(PartialDimensionInplaceCopy(), [inputs])


class TestSetSpan:
"""test structural equal between translated / hand-crafted relay IR with span tagged."""
