Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,23 @@ def _pad(self, node: fx.Node) -> relax.Var:

return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))

def _constant_pad_nd(self, node: fx.Node) -> relax.Var:
    """Convert ``aten.constant_pad_nd`` to ``relax.op.nn.pad``.

    PyTorch's flat ``pad`` list holds (before, after) pairs starting from the
    LAST input dimension; relax expects pairs ordered from the FIRST
    dimension.  The pairs are therefore reversed and right-aligned into a
    full-rank pad-width list, leaving leading (unlisted) dimensions unpadded.
    """
    x = self.env[node.args[0]]
    pad = node.args[1]
    value = node.args[2] if len(node.args) > 2 else node.kwargs.get("value", 0.0)
    # aten may pass an explicit None for the fill value; treat it as 0.0.
    value = 0.0 if value is None else value

    input_ndim = x.struct_info.ndim
    pad_width = [0] * (input_ndim * 2)
    # Group torch's flat pad list into (before, after) pairs, then reverse the
    # pair order so it matches the input-dimension order relax expects.
    pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
    flattened = [v for pair in reversed(pad_pairs) for v in pair]
    # Right-align into pad_width using a non-negative start index: with an
    # empty pad list, the negative form `pad_width[-0:] = []` would have
    # cleared the whole list instead of leaving it all zeros.
    if flattened:
        pad_width[len(pad_width) - len(flattened) :] = flattened

    return self.block_builder.emit(relax.op.nn.pad(x, pad_width, "constant", value))

def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
data = self.env[node.args[0]]
upscale_factor = node.args[1]
Expand Down Expand Up @@ -1665,8 +1682,37 @@ def _index_put(self, node: fx.Node) -> relax.Var:

def _index_tensor(self, node: fx.Node) -> relax.Var:
    """Convert ``aten.index.Tensor``, supporting ``None`` entries in the index list.

    In PyTorch's ``aten.index.Tensor``, a ``None`` index means "select all
    elements" along that dimension.

    NOTE: the stray early ``return`` that indexed ``args[0]`` directly (a
    leftover from the pre-rewrite code) was removed — it made every statement
    below it unreachable.
    """
    args = self.retrieve_args(node)
    data = args[0]
    indices = args[1]

    # Collect the (axis, tensor) pairs that actually select elements.
    non_none_indices = [(i, idx) for i, idx in enumerate(indices) if idx is not None]

    # Special case: a single non-None index is a plain gather along one axis.
    if len(non_none_indices) == 1:
        axis, index_tensor = non_none_indices[0]
        return self.block_builder.emit(relax.op.take(data, index_tensor, axis=axis))

    # General case: multiple non-None indices require advanced indexing.
    # Replace each None with an explicit arange over that dimension's extent
    # so relax.op.index_tensor receives a concrete index for every axis.
    processed_indices = []
    data_shape = self.shape_of(data)

    for i, idx in enumerate(indices):
        if idx is None:
            dim_size = data_shape[i]
            arange_idx = self.block_builder.emit(
                relax.op.arange(
                    start=relax.PrimValue(0),
                    end=dim_size,
                    step=relax.PrimValue(1),
                    dtype="int64",
                )
            )
            processed_indices.append(arange_idx)
        else:
            processed_indices.append(idx)

    return self.block_builder.emit(relax.op.index_tensor(data, processed_indices))

def _meshgrid(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,8 @@ def create_convert_map(
"_log_softmax.default": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
"pad.default": self._pad,
"constant_pad_nd.default": self._constant_pad_nd,
"copy.default": self._copy_,
"pixel_shuffle.default": self._pixel_shuffle,
"prelu.default": self._prelu,
"reciprocal.default": self._reciprocal,
Expand Down
201 changes: 179 additions & 22 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,13 +2715,25 @@ def main(
x: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
x,
pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
pad_mode="reflect",
pad_value=0.0,
lv: R.Tensor((14,), dtype="int64") = R.arange(
R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
)
gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
lv1: R.Tensor((14,), dtype="int64") = R.abs(lv)
lv2: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv1)
lv3: R.Tensor((14,), dtype="int64") = R.abs(lv2)
lv4: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv3)
lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv4, axis=2, mode="fast")
lv6: R.Tensor((12,), dtype="int64") = R.arange(
R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
)
lv7: R.Tensor((12,), dtype="int64") = R.abs(lv6)
lv8: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv7)
lv9: R.Tensor((12,), dtype="int64") = R.abs(lv8)
lv10: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv9)
lv11: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
lv5, lv10, axis=3, mode="fast"
)
gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv11,)
R.output(gv)
return gv

Expand All @@ -2732,13 +2744,19 @@ def main(
x: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
x,
pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
pad_mode="replicate",
pad_value=0.0,
lv: R.Tensor((14,), dtype="int64") = R.arange(
R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
)
gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
lv1: R.Tensor((14,), dtype="int64") = R.clip(lv, R.prim_value(0), R.prim_value(9))
lv2: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv1, axis=2, mode="fast")
lv3: R.Tensor((12,), dtype="int64") = R.arange(
R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
)
lv4: R.Tensor((12,), dtype="int64") = R.clip(lv3, R.prim_value(0), R.prim_value(9))
lv5: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
lv2, lv4, axis=3, mode="fast"
)
gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv5,)
R.output(gv)
return gv

Expand All @@ -2749,21 +2767,160 @@ def main(
x: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.zeros(
R.shape([1, 3, 14, 12]), dtype="float32"
)
lv1: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
lv,
(R.prim_value(3),),
(R.prim_value(1),),
(R.prim_value(11),),
(R.prim_value(1),),
assume_inbound=False,
)
lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
x,
pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
pad_mode="circular",
pad_value=0.0,
(R.prim_value(3),),
(R.prim_value(0),),
(R.prim_value(10),),
(R.prim_value(1),),
assume_inbound=False,
)
gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
lv1,
(R.prim_value(2),),
(R.prim_value(2),),
(R.prim_value(12),),
(R.prim_value(1),),
assume_inbound=False,
)
lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
lv2,
(R.prim_value(2),),
(R.prim_value(0),),
(R.prim_value(10),),
(R.prim_value(1),),
assume_inbound=False,
)
lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
lv,
(R.prim_value(3),),
(R.prim_value(1),),
(R.prim_value(11),),
(R.prim_value(1),),
assume_inbound=False,
)
lv6: R.Tensor((1, 3, 14, 10), dtype="float32") = R.slice_scatter(
lv5, lv4, R.prim_value(2), R.prim_value(12), R.prim_value(1), axis=2
)
lv7: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
lv, lv6, R.prim_value(1), R.prim_value(11), R.prim_value(1), axis=3
)
lv8: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
lv7,
(R.prim_value(3),),
(R.prim_value(0),),
(R.prim_value(1),),
(R.prim_value(1),),
assume_inbound=False,
)
lv9: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
lv7,
(R.prim_value(3),),
(R.prim_value(10),),
(R.prim_value(11),),
(R.prim_value(1),),
assume_inbound=False,
)
lv10: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
lv7, lv9, R.prim_value(0), R.prim_value(1), R.prim_value(1), axis=3
)
lv11: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
lv10,
(R.prim_value(3),),
(R.prim_value(11),),
(R.prim_value(12),),
(R.prim_value(1),),
assume_inbound=False,
)
lv12: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
lv10,
(R.prim_value(3),),
(R.prim_value(1),),
(R.prim_value(2),),
(R.prim_value(1),),
assume_inbound=False,
)
lv13: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
lv10, lv12, R.prim_value(11), R.prim_value(12), R.prim_value(1), axis=3
)
lv14: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
lv13,
(R.prim_value(2),),
(R.prim_value(0),),
(R.prim_value(2),),
(R.prim_value(1),),
assume_inbound=False,
)
lv15: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
lv13,
(R.prim_value(2),),
(R.prim_value(10),),
(R.prim_value(12),),
(R.prim_value(1),),
assume_inbound=False,
)
lv16: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
lv13, lv15, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=2
)
lv17: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
lv16,
(R.prim_value(2),),
(R.prim_value(12),),
(R.prim_value(14),),
(R.prim_value(1),),
assume_inbound=False,
)
lv18: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
lv16,
(R.prim_value(2),),
(R.prim_value(2),),
(R.prim_value(4),),
(R.prim_value(1),),
assume_inbound=False,
)
lv19: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
lv16, lv18, R.prim_value(12), R.prim_value(14), R.prim_value(1), axis=2
)
gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv19,)
R.output(gv)
return gv

example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant)
verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect)
verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate)
verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular)
verify_model(
PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant, run_ep_decomposition=True
)
verify_model(
PadModel(pad=[1, 1, 2, 2], mode="reflect"),
example_args,
{},
expected_reflect,
run_ep_decomposition=True,
)
verify_model(
PadModel(pad=[1, 1, 2, 2], mode="replicate"),
example_args,
{},
expected_replicate,
run_ep_decomposition=True,
)
verify_model(
PadModel(pad=[1, 1, 2, 2], mode="circular"),
example_args,
{},
expected_circular,
run_ep_decomposition=True,
)


def test_pixel_shuffle():
Expand Down Expand Up @@ -5949,7 +6106,7 @@ def main(
) -> R.Tuple(R.Tensor((3,), dtype="float32")):
with R.dataflow():
lv: R.Tensor((5,), dtype="float32") = R.reshape(data, R.shape([5]))
lv1: R.Tensor((3,), dtype="float32") = R.index_tensor(lv, (indices,))
lv1: R.Tensor((3,), dtype="float32") = R.take(lv, indices, axis=0, mode="fast")
gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
R.output(gv)
return gv
Expand Down
Loading