diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index 68bf767557d5..1afc6e28775e 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -83,14 +83,13 @@ def convert_flatten(self, op): def convert_eltwise(self, op): """Convert Eltwise layer""" inputs = op.bottom - assert len(inputs) == 2, "input tensors length should be 2" + assert len(inputs) >= 2, "input tensors length should be larger than 2" + # gethering initial 2 input expressions lhs_expr = self.exp_tab.get_expr(inputs[0]) rhs_expr = self.exp_tab.get_expr(inputs[1]) - lhs_shape = _infer_shape(lhs_expr) rhs_shape = _infer_shape(rhs_expr) - assert lhs_shape == rhs_shape, "input tensors shape should be equal" eltwise_params = op.eltwise_param @@ -100,6 +99,11 @@ def convert_eltwise(self, op): if eltwise_type_dict[eltwise_type] == "PROD": out = _op.multiply(lhs_expr, rhs_expr) + # for rest inputs + for i in range(len(inputs) - 2): + extra_expr = self.exp_tab.get_expr(inputs[i + 2]) + assert _infer_shape(out) == _infer_shape(extra_expr) + out = _op.multiply(out, extra_expr) elif eltwise_type_dict[eltwise_type] == "SUM": if coeff: left_coeff_expr = self.exp_tab.new_const(np.asarray(coeff[0], np.float32)) @@ -109,8 +113,23 @@ def convert_eltwise(self, op): out = _op.add(lhs_expr_scale, rhs_expr_scale) else: out = _op.add(lhs_expr, rhs_expr) + # for rest inputs + for i in range(len(inputs) - 2): + extra_expr = self.exp_tab.get_expr(inputs[i + 2]) + assert _infer_shape(out) == _infer_shape(extra_expr) + if coeff: + coeff_expr = self.exp_tab.new_const(np.asarray(coeff[i + 2], np.float32)) + extra_expr_scale = _op.multiply(extra_expr, coeff_expr) + out = _op.add(out, extra_expr_scale) + else: + out = _op.add(out, extra_expr) elif eltwise_type_dict[eltwise_type] == "MAX": out = _op.maximum(lhs_expr, rhs_expr) + # for rest inputs + for i in range(len(inputs) - 2): + extra_expr = self.exp_tab.get_expr(inputs[i + 2]) + assert _infer_shape(out) == _infer_shape(extra_expr) + out = _op.maximum(out, extra_expr) else: raise tvm.error.OpNotImplemented( "eltwise_type {} is not supported for frontend Caffe.".format(eltwise_type) diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py index 0027a6b41736..004306867196 100644 --- a/tests/python/frontend/caffe/test_forward.py +++ b/tests/python/frontend/caffe/test_forward.py @@ -511,6 +511,45 @@ def test_forward_Eltwise(): operation=1, coeff=[0.5, 1], ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=0, + ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=1, + ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=2, + ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=1, + coeff=[0.5, 1, 0.2, 1.8, 3.1, 0.1], + ) #######################################################################