48 changes: 48 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1313,6 +1313,54 @@ def _max_pool3d(self, node: fx.Node) -> relax.Var:
ceil_mode = args[5] if len(args) > 5 else False
return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

def _max_pool1d_with_indices(self, node: fx.Node) -> relax.Var:
# max_pool1d_with_indices returns (output, indices)
# We only compute the output and create a placeholder for indices
args = self.retrieve_args(node)
x = args[0]
kernel_size = args[1]
stride = args[2] if len(args) > 2 else None
padding = args[3] if len(args) > 3 else 0
dilation = args[4] if len(args) > 4 else 1
ceil_mode = args[5] if len(args) > 5 else False

output = self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
# Create a placeholder for indices (zero-filled tensor with the same shape as the output)
indices = relax.op.zeros_like(output)
return self.block_builder.emit(relax.Tuple([output, indices]))

def _max_pool2d_with_indices(self, node: fx.Node) -> relax.Var:
# max_pool2d_with_indices returns (output, indices)
# We only compute the output and create a placeholder for indices
args = self.retrieve_args(node)
x = args[0]
kernel_size = args[1]
stride = args[2] if len(args) > 2 else None
padding = args[3] if len(args) > 3 else 0
dilation = args[4] if len(args) > 4 else 1
ceil_mode = args[5] if len(args) > 5 else False

output = self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
# Create a placeholder for indices (zero-filled tensor with the same shape as the output)
indices = relax.op.zeros_like(output)
return self.block_builder.emit(relax.Tuple([output, indices]))

def _max_pool3d_with_indices(self, node: fx.Node) -> relax.Var:
# max_pool3d_with_indices returns (output, indices)
# We only compute the output and create a placeholder for indices
args = self.retrieve_args(node)
x = args[0]
kernel_size = args[1]
stride = args[2] if len(args) > 2 else None
padding = args[3] if len(args) > 3 else 0
dilation = args[4] if len(args) > 4 else 1
ceil_mode = args[5] if len(args) > 5 else False

output = self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
# Create a placeholder for indices (zero-filled tensor with the same shape as the output)
indices = relax.op.zeros_like(output)
return self.block_builder.emit(relax.Tuple([output, indices]))

def _pad(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
pad = node.args[1]
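
For context on what the new converters model: the PyTorch *_with_indices pooling ops return a (values, indices) pair, whereas Relax's pooling ops produce only the pooled values, so each converter pairs the real output with a zeros_like placeholder. A minimal sketch of the operator behavior being matched (illustrative only; the shapes and the F alias are assumptions, not taken from the patch):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 10)
# return_indices=True makes max_pool2d behave like aten::max_pool2d_with_indices:
# it returns the pooled values together with the argmax index of each window.
values, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
assert values.shape == indices.shape == (1, 3, 5, 5)
# The Relax lowering keeps the same tuple arity but fills the indices slot with
# zeros, which is only safe while the graph reads element 0 (the values) and
# discards the indices, as the decomposed max_pool graphs below do.
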
@@ -986,7 +986,9 @@ def create_convert_map(
"gru.input": self._gru,
"max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
"max_pool2d_with_indices.default": self._max_pool2d_with_indices,
"max_pool3d.default": self._max_pool3d,
"max_pool3d_with_indices.default": self._max_pool3d_with_indices,
"scaled_dot_product_attention.default": self._scaled_dot_product_attention,
"unbind.int": self._unbind,
"upsample_bilinear2d.vec": self._upsample_bilinear2d,
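
These *.default keys correspond to the ATen targets that appear once an ExportedProgram is run through decomposition: nn.MaxPool2d and nn.MaxPool3d lower to max_pool2d_with_indices.default and max_pool3d_with_indices.default followed by a getitem that selects element 0, which is why the converters emit a two-element tuple. A hedged way to inspect this locally (the module and shapes are illustrative, and the exact printed graph depends on the PyTorch version):

import torch

class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        return self.pool(x)

# Export the module, then apply the default decompositions; the decomposed graph
# is expected to call torch.ops.aten.max_pool2d_with_indices.default and then
# take element 0 of its result.
ep = torch.export.export(Pool(), (torch.randn(1, 3, 10, 10),))
print(ep.run_decompositions().graph_module.code)
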
152 changes: 105 additions & 47 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -3163,16 +3163,24 @@ def main(
input_1: R.Tensor((1, 3, 8), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
with R.dataflow():
lv = R.nn.max_pool1d(
input_1,
pool_size=[2],
strides=[2],
dilation=[1],
padding=[0, 0],
layout="NCW",
out_layout="NCW",
lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2])
lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
lv,
pool_size=[1, 2],
strides=[1, 2],
dilation=[1, 1],
padding=[0, 0, 0, 0],
layout="NCHW",
out_layout="NCHW",
)
gv = (lv,)
lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
lv3: R.Tuple(
R.Tensor((1, 3, 1, 4), dtype="float32"),
R.Tensor((1, 3, 1, 4), dtype="float32"),
) = (lv1, lv2)
lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
R.output(gv)
return gv

@@ -3183,16 +3191,24 @@ def main(
input_1: R.Tensor((1, 3, 8), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
with R.dataflow():
lv = R.nn.max_pool1d(
input_1,
pool_size=[2],
strides=[2],
dilation=[1],
padding=[0, 0],
layout="NCW",
out_layout="NCW",
lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2])
lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
lv,
pool_size=[1, 2],
strides=[1, 2],
dilation=[1, 1],
padding=[0, 0, 0, 0],
layout="NCHW",
out_layout="NCHW",
)
gv = (lv,)
lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
lv3: R.Tuple(
R.Tensor((1, 3, 1, 4), dtype="float32"),
R.Tensor((1, 3, 1, 4), dtype="float32"),
) = (lv1, lv2)
lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
R.output(gv)
return gv

@@ -3203,16 +3219,24 @@ def main(
input_1: R.Tensor((1, 3, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
with R.dataflow():
lv = R.nn.max_pool1d(
input_1,
pool_size=[3],
strides=[2],
dilation=[1],
padding=[0, 0],
layout="NCW",
out_layout="NCW",
lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2])
lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
lv,
pool_size=[1, 3],
strides=[1, 2],
dilation=[1, 1],
padding=[0, 0, 0, 0],
layout="NCHW",
out_layout="NCHW",
)
gv = (lv,)
lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
lv3: R.Tuple(
R.Tensor((1, 3, 1, 4), dtype="float32"),
R.Tensor((1, 3, 1, 4), dtype="float32"),
) = (lv1, lv2)
lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
R.output(gv)
return gv

@@ -3222,9 +3246,9 @@ def main(
example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),)

# Verify the models
verify_model(MaxPool1d(), example_args1, {}, expected1)
verify_model(MaxPool1d_functional(), example_args2, {}, expected2)
verify_model(MaxPool1d2(), example_args3, {}, expected3)
verify_model(MaxPool1d(), example_args1, {}, expected1, run_ep_decomposition=True)
verify_model(MaxPool1d_functional(), example_args2, {}, expected2, run_ep_decomposition=True)
verify_model(MaxPool1d2(), example_args3, {}, expected3, run_ep_decomposition=True)
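
The run_ep_decomposition=True flag exercises the decomposed export path reflected in the expected IR above: max_pool1d is decomposed into an unsqueeze to 4-D, max_pool2d_with_indices, a getitem of element 0, and a squeeze back to 3-D. A small illustrative check of that 1-D/2-D equivalence (shapes chosen to mirror example_args1; not part of the test file):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8)
ref = F.max_pool1d(x, kernel_size=2, stride=2)
# The same pooling expressed through max_pool2d over an inserted unit height
# axis, mirroring the expand_dims / max_pool2d / squeeze chain in the expected IR.
via_2d = F.max_pool2d(x.unsqueeze(-2), kernel_size=(1, 2), stride=(1, 2)).squeeze(-2)
assert torch.equal(ref, via_2d)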


def test_maxpool2d():
@@ -3260,7 +3284,13 @@ def main(
layout="NCHW",
out_layout="NCHW",
)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros_like(lv)
lv2: R.Tuple(
R.Tensor((1, 3, 10, 10), dtype="float32"),
R.Tensor((1, 3, 10, 10), dtype="float32"),
) = (lv, lv1)
lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2[0]
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@@ -3289,7 +3319,12 @@ def main(
layout="NCHW",
out_layout="NCHW",
)
gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,)
lv1: R.Tensor((1, 3, 4, 4), dtype="float32") = R.zeros_like(lv)
lv2: R.Tuple(
R.Tensor((1, 3, 4, 4), dtype="float32"), R.Tensor((1, 3, 4, 4), dtype="float32")
) = (lv, lv1)
lv3: R.Tensor((1, 3, 4, 4), dtype="float32") = lv2[0]
gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@@ -3318,15 +3353,20 @@ def main(
layout="NCHW",
out_layout="NCHW",
)
gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,)
lv1: R.Tensor((1, 3, 6, 6), dtype="float32") = R.zeros_like(lv)
lv2: R.Tuple(
R.Tensor((1, 3, 6, 6), dtype="float32"), R.Tensor((1, 3, 6, 6), dtype="float32")
) = (lv, lv1)
lv3: R.Tensor((1, 3, 6, 6), dtype="float32") = lv2[0]
gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv3,)
R.output(gv)
return gv

example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
verify_model(MaxPool2d(), example_args, {}, expected1)
verify_model(MaxPool2d_functional(), example_args, {}, expected1)
verify_model(MaxPool2d2(), example_args, {}, expected2)
verify_model(MaxPool2d3(), example_args, {}, expected3)
verify_model(MaxPool2d(), example_args, {}, expected1, run_ep_decomposition=True)
verify_model(MaxPool2d_functional(), example_args, {}, expected1, run_ep_decomposition=True)
verify_model(MaxPool2d2(), example_args, {}, expected2, run_ep_decomposition=True)
verify_model(MaxPool2d3(), example_args, {}, expected3, run_ep_decomposition=True)


def test_maxpool3d():
@@ -3352,7 +3392,7 @@ def main(
input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
with R.dataflow():
lv = R.nn.max_pool3d(
lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.max_pool3d(
input_1,
pool_size=[1, 1, 1],
strides=[1, 1, 1],
@@ -3361,7 +3401,13 @@ def main(
layout="NCDHW",
out_layout="NCDHW",
)
gv = (lv,)
lv1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.zeros_like(lv)
lv2: R.Tuple(
R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
) = (lv, lv1)
lv3: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = lv2[0]
gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@@ -3380,7 +3426,7 @@ def main(
input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")):
with R.dataflow():
lv = R.nn.max_pool3d(
lv: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.nn.max_pool3d(
input_1,
pool_size=[2, 2, 2],
strides=[2, 2, 2],
@@ -3389,7 +3435,13 @@ def main(
layout="NCDHW",
out_layout="NCDHW",
)
gv = (lv,)
lv1: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.zeros_like(lv)
lv2: R.Tuple(
R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
) = (lv, lv1)
lv3: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = lv2[0]
gv: R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@@ -3408,7 +3460,7 @@ def main(
input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")):
with R.dataflow():
lv = R.nn.max_pool3d(
lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d(
input_1,
pool_size=[3, 3, 3],
strides=[2, 2, 2],
@@ -3417,7 +3469,13 @@ def main(
layout="NCDHW",
out_layout="NCDHW",
)
gv = (lv,)
lv1: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.zeros_like(lv)
lv2: R.Tuple(
R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
) = (lv, lv1)
lv3: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv2[0]
gv: R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")) = (lv3,)
R.output(gv)
return gv

@@ -3427,10 +3485,10 @@ def main(
example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)

# Verify the models with expected IR modules
verify_model(MaxPool3d(), example_args1, {}, expected1)
verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
verify_model(MaxPool3d2(), example_args2, {}, expected2)
verify_model(MaxPool3d3(), example_args3, {}, expected3)
verify_model(MaxPool3d(), example_args1, {}, expected1, run_ep_decomposition=True)
verify_model(MaxPool3d_functional(), example_args1, {}, expected1, run_ep_decomposition=True)
verify_model(MaxPool3d2(), example_args2, {}, expected2, run_ep_decomposition=True)
verify_model(MaxPool3d3(), example_args3, {}, expected3, run_ep_decomposition=True)


def test_scaled_dot_product_attention():