From a9ebcc68764bc8c850c895a941165277eea8ba87 Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov
Date: Fri, 19 Mar 2021 16:56:43 -0700
Subject: [PATCH] [torch] Use try_infer_value for clamp min/max

---
 python/tvm/relay/frontend/pytorch.py          | 16 ++++++++++++++--
 tests/python/frontend/pytorch/test_forward.py |  7 +++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index fd0a07e35c15..8ae1e862ffd5 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1688,8 +1688,20 @@ def pad(inputs, input_types):

     def clamp(self, inputs, input_types):
         data = inputs[0]
-        amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
-        amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
+
+        def get_v(v, default_v):
+            if isinstance(v, _expr.Constant):
+                return float(v.data.asnumpy())
+            if isinstance(v, _expr.Expr):
+                infer_v, success = try_infer_value(v, lambda ret: float(ret))
+                if success:
+                    return infer_v
+            if v is not None:
+                return v
+            return default_v
+
+        amin = get_v(inputs[1], np.finfo(np.float32).min)
+        amax = get_v(inputs[2], np.finfo(np.float32).max)
         return _op.clip(data, amin, amax)

     def to(self, inputs, input_types):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 83c1698799c7..d0edfd9c8036 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -2622,10 +2622,17 @@ class Clamp3(Module):
         def forward(self, *args):
             return torch.clamp(args[0], max=1.0)

+    class Clamp_MinExpr_MaxConstant(Module):
+        def forward(self, *args):
+            h, w = args[0].shape[2:]
+            amin = h / 100.0
+            return torch.clamp(args[0], min=amin, max=w)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Clamp1().float().eval(), input_data=input_data)
     verify_model(Clamp2().float().eval(), input_data=input_data)
     verify_model(Clamp3().float().eval(), input_data=input_data)
+    verify_model(Clamp_MinExpr_MaxConstant().float().eval(), input_data=input_data)

 @tvm.testing.uses_gpu