Skip to content

Commit 81f7da8

Browse files
authored
[Relax][PyTorch] Support prod, std and var ops for ExportedProgram importer (#17785)
* Update exported_program_translator.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py * Update test_frontend_from_exported_program.py
1 parent 95cbdaa commit 81f7da8

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,10 @@ def create_convert_map(
353353
"upsample_nearest2d.vec": self._upsample_nearest2d,
354354
# statistical
355355
"mean.dim": self._mean,
356+
"prod.default": self._prod,
357+
"std.correction": self._std,
356358
"sum.dim_IntList": self._sum,
359+
"var.correction": self._var,
357360
# search
358361
"argmax.default": self._argmax_argmin(relax.op.argmax),
359362
"argmin.default": self._argmax_argmin(relax.op.argmin),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3704,5 +3704,68 @@ def main(
37043704
verify_model(Take(), example_args, {}, Expected)
37053705

37063706

3707+
def test_std():
3708+
class Std(Module):
3709+
def forward(self, x):
3710+
return torch.std(x)
3711+
3712+
@tvm.script.ir_module
3713+
class Expected:
3714+
@R.function
3715+
def main(
3716+
inp_0: R.Tensor((5, 3), dtype="float32"),
3717+
) -> R.Tuple(R.Tensor((), dtype="float32")):
3718+
with R.dataflow():
3719+
lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False)
3720+
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
3721+
R.output(gv)
3722+
return gv
3723+
3724+
example_args = (torch.randn(5, 3, dtype=torch.float32),)
3725+
verify_model(Std(), example_args, {}, Expected)
3726+
3727+
3728+
def test_var():
3729+
class Var(Module):
3730+
def forward(self, x):
3731+
return torch.var(x)
3732+
3733+
@tvm.script.ir_module
3734+
class Expected:
3735+
@R.function
3736+
def main(
3737+
inp_0: R.Tensor((5, 3), dtype="float32"),
3738+
) -> R.Tuple(R.Tensor((), dtype="float32")):
3739+
with R.dataflow():
3740+
lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False)
3741+
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
3742+
R.output(gv)
3743+
return gv
3744+
3745+
example_args = (torch.randn(5, 3, dtype=torch.float32),)
3746+
verify_model(Var(), example_args, {}, Expected)
3747+
3748+
3749+
def test_prod():
3750+
class Prod(Module):
3751+
def forward(self, x):
3752+
return torch.prod(x)
3753+
3754+
@tvm.script.ir_module
3755+
class Expected:
3756+
@R.function
3757+
def main(
3758+
inp_0: R.Tensor((5, 3), dtype="float32"),
3759+
) -> R.Tuple(R.Tensor((), dtype="float32")):
3760+
with R.dataflow():
3761+
lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False)
3762+
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
3763+
R.output(gv)
3764+
return gv
3765+
3766+
example_args = (torch.randn(5, 3, dtype=torch.float32),)
3767+
verify_model(Prod(), example_args, {}, Expected)
3768+
3769+
37073770
if __name__ == "__main__":
37083771
tvm.testing.main()

0 commit comments

Comments
 (0)