diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index 82ca544a9f66..5e93ea1ff0aa 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -199,6 +199,28 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
   return *channels;
 }

+/*!
+ * \brief Check if expr is a single-value tensor (scalar), i.e. every dimension is 1.
+ * \param expr The expression to check.
+ * \return True if expr is a single-value tensor.
+ */
+inline bool IsScalar(const Expr& expr) {
+  if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
+    for (auto dim_index_expr : tensor_type->shape) {
+      if (auto dim_index = dim_index_expr.as<IntImm>()) {
+        if (dim_index->value != 1) {
+          return false;
+        }
+      } else {
+        return false;
+      }
+    }
+  } else {
+    return false;
+  }
+  return true;
+}
+
 /*!
  * \brief Create a Constant with a scalar
  *
diff --git a/src/relay/pass/transform_layout.h b/src/relay/pass/transform_layout.h
index 21a82a603c20..f6c5e9af6d62 100644
--- a/src/relay/pass/transform_layout.h
+++ b/src/relay/pass/transform_layout.h
@@ -119,6 +119,11 @@ class TransformMemorizer : public NodeRef {
     Expr input_expr = raw;
     Layout new_src_layout = src_layout;
     if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
+      // If the operand is a scalar, no layout transformation is needed: a scalar can be
+      // broadcast easily even if the other operand has a transformed layout.
+      if (IsScalar(input_expr)) {
+        return raw;
+      }
       int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
       new_src_layout = src_layout.ExpandPrimal(dst_layout);
       input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 9ab582d5b3e2..3f02e1db625e 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -318,6 +318,70 @@ def expected():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


+
+def test_alter_layout_broadcast_scalar_op():
+    """Test altering the layout of a conv2d.
+    Scalar operands of broadcast ops should not need expand_dims or a layout transform.
+    """
+    def before():
+        x = relay.var("x", shape=(1, 500, 500, 64))
+        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
+        bias = relay.var("bias", shape=(64,))
+        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
+        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')
+
+        y = relay.nn.conv2d(x, kernel,
+                            data_layout='NHWC',
+                            kernel_layout="HWIO",
+                            kernel_size=(3, 3))
+        y = relay.add(bias, y)
+        y = relay.nn.relu(y)
+
+        y = relay.multiply(multiplier1, y)
+        y = relay.multiply(y, multiplier2)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW16c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 500, 500, 64))
+        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
+        bias = relay.var("bias", shape=(64,))
+        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
+        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')
+
+        b = relay.expand_dims(bias, axis=0, num_newaxis=3)
+        b = relay.layout_transform(b, "NHWC", "NCHW16c")
+
+        y = relay.layout_transform(x, "NHWC", "NCHW16c")
+        y = relay.nn.conv2d(y, kernel,
+                            data_layout='NCHW16c',
+                            kernel_layout="HWIO",
+                            kernel_size=(3, 3))
+
+        y = relay.add(b, y)
+        y = relay.nn.relu(y)
+
+        y = relay.multiply(multiplier1, y)
+        y = relay.multiply(y, multiplier2)
+        y = relay.layout_transform(y, "NCHW16c", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = before()
+        a = run_opt_pass(a, [transform.CanonicalizeOps(),
+                             transform.AlterOpLayout()])
+        b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_alter_layout_scalar():
     """Test alternating the layout of a conv2d.
     The layout of broadcast operators and the weight should be changed accordingly.
@@ -980,6 +1044,7 @@ def expected():
     test_alter_layout_dual_path()
     test_alter_layout_resnet()
     test_alter_layout_broadcast_op()
+    test_alter_layout_broadcast_scalar_op()
     test_alter_layout_scalar()
     test_alter_layout_concatenate()
     test_alter_layout_nchw_upsamping_op()