From 98139b37090b30b5e512b17e38ec6d3aa4c43dd8 Mon Sep 17 00:00:00 2001
From: princek
Date: Tue, 22 Oct 2024 08:11:02 +0000
Subject: [PATCH] [Marvell BYOC]: global_max_pool2d and squeeze op support

---
 python/tvm/relay/op/contrib/mrvl.py         |  51 +++++++++-
 src/relay/backend/contrib/mrvl/codegen.cc   |  99 ++++++++++++++++++
 tests/python/contrib/test_mrvl/test_mrvl.py | 106 ++++++++++++++++++++
 3 files changed, 255 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/contrib/mrvl.py b/python/tvm/relay/op/contrib/mrvl.py
index 75041fbc8c44..b13cf3d9533d 100644
--- a/python/tvm/relay/op/contrib/mrvl.py
+++ b/python/tvm/relay/op/contrib/mrvl.py
@@ -535,7 +535,6 @@ def avgpool2d_base_pattern(pattern):
 
     def globalavgpool2d_pattern():
         """Create a globalavgpool2d pattern.
-        review tvm/tests/python/relay/test_dataflow_pattern.py for examples
         Returns
         -------
         pattern : dataflow_pattern.AltPattern
@@ -544,6 +543,16 @@ def globalavgpool2d_pattern():
         pattern = is_op("nn.global_avg_pool2d")(wildcard())
         return pattern
 
+    def globalmaxpool2d_pattern():
+        """Create a globalmaxpool2d pattern.
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the globalmaxpool2d pattern.
+        """
+        pattern = is_op("nn.global_max_pool2d")(wildcard())
+        return pattern
+
     def reshape_pattern():
         pattern = is_op("reshape")(wildcard())
         return pattern
@@ -552,6 +562,10 @@ def batch_flatten_pattern():
         pattern = is_op("nn.batch_flatten")(wildcard())
         return pattern
 
+    def squeeze_pattern():
+        pattern = is_op("squeeze")(wildcard())
+        return pattern
+
     def layout_transform_nchw2nhwc_pattern():
         pattern = is_op("layout_transform")(is_var(), wildcard(), wildcard()).has_attr(
             {"src_layout": "NCHW", "dst_layout": "NHWC"}
@@ -596,6 +610,13 @@ def check_globalavgpool2d(extract):
             call = call.args[0]
         return globalavgpool2d_nhwc2nhwc(call)
 
+    def check_globalmaxpool2d(extract):
+        """Check globalmaxpool2d pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "nn.global_max_pool2d":
+            call = call.args[0]
+        return globalmaxpool2d_nhwc2nhwc(call)
+
     def check_reshape(extract):
         call = extract
         while call.op.name != "reshape":
@@ -608,6 +629,12 @@ def check_batch_flatten(extract):
             call = call.args[0]
         return batch_flatten_mrvl(call)
 
+    def check_squeeze(extract):
+        call = extract
+        while call.op.name != "squeeze":
+            call = call.args[0]
+        return squeeze_mrvl(call)
+
     def check_layout_transform_nchw2nhwc(extract):
         call = extract
         while call.op.name != "layout_transform":
@@ -634,6 +661,7 @@ def check_concat(extract):
         ("mrvl.maxpool2d_nhwc2nhwc", maxpool2d_pattern(), check_maxpool2d),
         ("mrvl.avgpool2d_nhwc2nhwc", avgpool2d_pattern(), check_avgpool2d),
         ("mrvl.globalavgpool2d_nhwc2nhwc", globalavgpool2d_pattern(), check_globalavgpool2d),
+        ("mrvl.globalmaxpool2d_nhwc2nhwc", globalmaxpool2d_pattern(), check_globalmaxpool2d),
         ("mrvl.sum", sum_pattern(), check_sum),
         ("mrvl.concat", concat_pattern(), check_concat),
         (
@@ -643,6 +671,7 @@ def check_concat(extract):
         ),
         ("mrvl.reshape", reshape_pattern(), check_reshape),
         ("mrvl.batch_flatten", batch_flatten_pattern(), check_batch_flatten),
+        ("mrvl.squeeze", squeeze_pattern(), check_squeeze),
     ]
 
 
@@ -813,6 +842,19 @@ def globalavgpool2d_nhwc2nhwc(expr):
     return True
 
 
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.mrvl")
+def globalmaxpool2d_nhwc2nhwc(expr):
+    """Check if the external Mrvl codegen for globalmaxpool2d_nhwc2nhwc should be used."""
+    attrs, args = expr.attrs, expr.args
+    if attrs.layout != "NHWC":
+        return False
+    data_type = args[0].checked_type
+    if len(data_type.shape) != 4 or data_type.dtype not in ["float32"]:
+        return False
+    return True
+
+
 @tvm.ir.register_op_attr("reshape", "target.mrvl")
 def reshape_mrvl(expr):
     """Check if the external Mrvl codegen for reshape should be used."""
@@ -846,6 +890,14 @@ def batch_flatten_mrvl(expr):
     return True
 
 
+@tvm.ir.register_op_attr("squeeze", "target.mrvl")
+def squeeze_mrvl(expr):
+    """Check if the external Mrvl codegen for squeeze should be used."""
+    if expr.op.name != "squeeze":
+        return False
+    return True
+
+
 # register a helper function to indicate that the given operator can be supported by Mrvl.
 @tvm.ir.register_op_attr("layout_transform", "target.mrvl")
 def layout_transform_nchw2nhwc(expr):
diff --git a/src/relay/backend/contrib/mrvl/codegen.cc b/src/relay/backend/contrib/mrvl/codegen.cc
index 6d7e593b9b04..96121e4b4b69 100644
--- a/src/relay/backend/contrib/mrvl/codegen.cc
+++ b/src/relay/backend/contrib/mrvl/codegen.cc
@@ -225,6 +225,13 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     const CallNode* batch_flatten = nullptr;
   };
 
+  /*!
+   * \brief A series of operators that form a Squeeze node.
+   */
+  struct CompositeSqueezeNode {
+    const CallNode* squeeze = nullptr;
+  };
+
   /*!
    * \brief A series of operators that form a composite
    * fc layer. Supports both nn.fc_ni2no and qnn.fc_ni2no.
@@ -278,6 +285,8 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
       json_kernel_node = CreateCompositeMrvlAvgpool2DLayer(cn);
     } else if (name == "mrvl.globalavgpool2d_nhwc2nhwc") {
       json_kernel_node = CreateCompositeMrvlGlobalAvgpool2DLayer(cn);
+    } else if (name == "mrvl.globalmaxpool2d_nhwc2nhwc") {
+      json_kernel_node = CreateCompositeMrvlGlobalMaxpool2DLayer(cn);
     } else if (name == "mrvl.sum") {
       json_kernel_node = CreateCompositeMrvlSumLayer(cn);
     } else if (name == "mrvl.concat") {
@@ -286,6 +295,8 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
       json_kernel_node = CreateMrvlReshapeLayer(cn);
     } else if (name == "mrvl.batch_flatten") {
       json_kernel_node = CreateMrvlBatchFlattenLayer(cn);
+    } else if (name == "mrvl.squeeze") {
+      json_kernel_node = CreateMrvlSqueezeLayer(cn);
     } else {
       LOG(FATAL) << "Unrecognized Mrvl pattern: " << name;
     }
@@ -511,6 +522,22 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     return nodes;
   }
 
+  /*!
+   * \brief Extract squeeze nodes from a composite function.
+   * \param call The call node of the composite function.
+   * \return Extracted composite squeeze nodes.
+   */
+  CompositeSqueezeNode UnpackCompositeSqueeze(const CallNode* call) {
+    CompositeSqueezeNode nodes{};
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+    const auto* current_call = fn->body.as<CallNode>();
+    ICHECK(backend::IsOp(current_call, "squeeze"))
+        << "Marvell-Compiler-ERROR-Internal::squeeze missing.";
+    nodes.squeeze = current_call;
+    return nodes;
+  }
+
   /*!
    * \brief Extract maxpool nodes from a composite function.
    *
@@ -533,6 +560,11 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
           << "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
       ICHECK(backend::IsOp(current_call, "nn.avg_pool2d"))
           << "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
+    } else if (mrvlLayerName == "GlobalMaxpool2D") {
+      ICHECK(mrvlLayerName == "GlobalMaxpool2D")
+          << "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op missing.";
+      ICHECK(backend::IsOp(current_call, "nn.global_max_pool2d"))
+          << "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op missing.";
     } else {
       ICHECK(mrvlLayerName == "GlobalAvgpool2D")
           << "Marvell-Compiler-ERROR-Internal::nn.global_avg_pool2d Op missing.";
@@ -1115,6 +1147,33 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     return json_node;
   }
 
+  /*!
+   * \brief Create a JSON representation of a composite Squeeze.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateMrvlSqueezeLayer(const CallNode* cn) {
+    CompositeSqueezeNode nodes = UnpackCompositeSqueeze(cn);
+    std::vector<JSONGraphNodeEntry> inputs;
+    std::string name = "squeeze";
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    std::vector<int64_t> layout_vec;
+    GetInputTensorShapeViaArgN(nodes.squeeze, &layout_vec);
+    std::string data_layout;
+    if (layout_vec.size() == 4) {
+      data_layout = "NHWC";
+    } else {
+      data_layout = "NC";
+    }
+    layout_vec.clear();
+    std::string out_layout = "NC";
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout,
+                            "" /* no kernel_layout */, out_layout);
+    return json_node;
+  }
+
   /*!
    * \brief Create a JSON representation of a composite concat.
    *
@@ -1304,6 +1364,46 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     return json_node;
   }
 
+  /*!
+   * \brief Create a JSON representation of a composite globalmaxpooling operator.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlGlobalMaxpool2DLayer(const CallNode* cn) {
+    std::string mrvlLayerName = "GlobalMaxpool2D";
+    std::string name = "nn.globalmaxpool2d_nhwc2nhwc";
+    CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName);
+
+    const auto* globalmaxpool_attr = nodes.pool->attrs.as<GlobalPool2DAttrs>();
+    ICHECK(globalmaxpool_attr)
+        << "Marvell-Compiler-ERROR-Internal::Downcast to GlobalPool2DAttrs failed.";
+    ICHECK(globalmaxpool_attr->layout == "NHWC")
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "Layout must be NHWC, has the module been pre-processed correctly?";
+
+    std::string data_layout = globalmaxpool_attr->layout;
+    std::string out_layout = globalmaxpool_attr->layout;
+    std::vector<JSONGraphNodeEntry> inputs;
+    std::vector<int64_t> kernel_layout_vec;
+    std::vector<int64_t> data_layout_vec;
+    GetInputTensorShapeViaArgN(cn, &data_layout_vec);
+    ICHECK(data_layout_vec.size() == 4);
+    kernel_layout_vec.push_back(data_layout_vec[1]);
+    kernel_layout_vec.push_back(data_layout_vec[2]);
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+
+    // op_type_ is "kernel"
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.pool);
+    JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec);
+    if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);
+
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "HW",
+                            out_layout);
+    return json_node;
+  }
+
   /*!
    * \brief Create a JSON representation of an OpNode layer.
    *
diff --git a/tests/python/contrib/test_mrvl/test_mrvl.py b/tests/python/contrib/test_mrvl/test_mrvl.py
index 26956c97c5c1..cd3f343c2d03 100644
--- a/tests/python/contrib/test_mrvl/test_mrvl.py
+++ b/tests/python/contrib/test_mrvl/test_mrvl.py
@@ -181,7 +181,113 @@ def get_graph():
     run_and_verify_func(get_graph())
 
 
+@requires_mrvl
+def test_maxpool2d():
+    """Test maxpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.max_pool2d(y)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.maxpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_avgpool2d():
+    """Test avgpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.avg_pool2d(y)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.avgpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_globalavgpool2d():
+    """Test globalavgpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.global_avg_pool2d(y)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.globalavgpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_globalmaxpool2d():
+    """Test globalmaxpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.global_max_pool2d(y)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.globalmaxpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_squeeze():
+    """Test squeeze operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.reshape(y, newshape=(1, 1, 16, 112, 112))
+        y = relay.squeeze(y, axis=[0, 1])
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=3, contains="mrvl.squeeze")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
 if __name__ == "__main__":
     test_mrvl_fuse()
     test_conv2d()
     test_dense()
+    test_maxpool2d()
+    test_avgpool2d()
+    test_globalavgpool2d()
+    test_globalmaxpool2d()
+    test_squeeze()
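
Usage sketch (illustrative, not part of the patch): the snippet below is a minimal
way to exercise the two new patterns end to end. It assumes the existing
partition_for_mrvl() entry point in python/tvm/relay/op/contrib/mrvl.py and a TVM
build with the Mrvl BYOC backend enabled; treat it as a sketch, not as part of the
change itself.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import mrvl

    # Build the same conv2d -> global_max_pool2d -> squeeze chain the tests use.
    x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
    w = relay.const(np.random.rand(16, 3, 3, 3).astype("float32"))
    y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
    y = relay.nn.global_max_pool2d(y)  # matched by globalmaxpool2d_pattern()
    y = relay.squeeze(y, axis=[2, 3])  # matched by squeeze_pattern()
    mod = tvm.IRModule.from_expr(relay.Function([x], y))

    # partition_for_mrvl() pre-processes the module (including the NCHW -> NHWC
    # layout conversion the attribute checkers above expect) and groups matched
    # regions into "mrvl"-tagged composite functions.
    mod = mrvl.partition_for_mrvl(mod)
    print(mod)  # expect mrvl.globalmaxpool2d_nhwc2nhwc and mrvl.squeeze composites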