From 86e59e059f0d693257bc7a96a370f80658310617 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 25 Jan 2023 07:50:24 -0800 Subject: [PATCH] [CLML][RUNTIME] Enable more ops in CLML runtime Enable the DepthToSpace and Resize bilinear operator in CLML runtime and bug fix in concat layer --- python/tvm/relay/op/contrib/clml.py | 16 +++- src/runtime/contrib/clml/clml_runtime.cc | 67 +++++++++++++- tests/python/contrib/test_clml/test_ops.py | 102 +++++++++++++++++++++ 3 files changed, 183 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py index e6e535edc068..2ec782098553 100644 --- a/python/tvm/relay/op/contrib/clml.py +++ b/python/tvm/relay/op/contrib/clml.py @@ -316,6 +316,18 @@ def check_softmax_op(extract): return False return True + def check_upsampling_op(extract): + call = extract + if call.attrs["method"] != "bilinear": + return False + return True + + def check_concat_op(extract): + call = extract + if call.attrs["axis"] != 1: + return False + return True + def check_default_op(extract): return True @@ -324,7 +336,7 @@ def check_default_op(extract): ("clml.conv2d", conv_pattern(), check_conv), ("clml.dense", dense_pattern(), check_default_op), ("clml.pad", pad_pattern(), check_pad_op), - ("clml.concat", concat_pattern(), check_default_op), + ("clml.concat", concat_pattern(), check_concat_op), ("clml.batch_norm", batch_norm_pattern(), check_default_op), ("clml.add", is_op("add")(wildcard(), wildcard()), check_binary_op), ("clml.subtract", is_op("subtract")(wildcard(), wildcard()), check_binary_op), @@ -341,6 +353,8 @@ def check_default_op(extract): ("clml.relu", is_op("nn.relu")(wildcard()), check_default_op), ("clml.clip", is_op("clip")(wildcard()), check_default_op), ("clml.batch_flatten", is_op("nn.batch_flatten")(wildcard()), check_default_op), + ("clml.depth_to_space", is_op("nn.depth_to_space")(wildcard()), check_default_op), + ("clml.upsampling", is_op("nn.upsampling")(wildcard()), 
check_upsampling_op), ] diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 1fb694a91201..0987eefdc9c0 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -430,6 +430,14 @@ class CLMLRuntime : public JSONRuntimeBase { auto out = CreateBinaryLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); this->layer_.func_outs.push_back(out); + } else if ("nn.depth_to_space" == op_name) { + auto out = CreateDepthToSpaceLayer(&layer_, node); + this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); + this->layer_.func_outs.push_back(out); + } else if ("nn.upsampling" == op_name) { + auto out = CreateResizeLayer(&layer_, node); + this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); + this->layer_.func_outs.push_back(out); } else { LOG(FATAL) << "Unsupported op: " << op_name; } @@ -1151,13 +1159,14 @@ class CLMLRuntime : public JSONRuntimeBase { cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype); int inputSize = input_.size(); auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); + cl_uint axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]); cl_ml_tensor_qcom* concatInputs = new cl_ml_tensor_qcom[inputSize]; for (int i = 0; i < inputSize; i++) { auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[i], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); concatInputs[i] = input->tensor; } - cl_ml_op_concat_desc_qcom concatDesc = {1, (cl_uint)inputSize, cl_arithmetic_mode}; + cl_ml_op_concat_desc_qcom concatDesc = {axis, (cl_uint)inputSize, cl_arithmetic_mode}; result = h_ClmlIntf->clCreateMLOpConcatQCOM(workspace->context, 0, &concatDesc, concatInputs, output->tensor, &op, tuning_cache); @@ -1301,6 +1310,62 @@ class CLMLRuntime : public JSONRuntimeBase { return output; } + /*! + * \brief Create a DepthToSpace(X) layer. + * + * \param layer The CLML layer to build. 
Containing inputs, outputs and the CLML output. + * \param node The JSON representation of the operator. + */ + std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateDepthToSpaceLayer( + CachedLayer* layer, const JSONGraphNode& node) { + cl_int result = 0; + cl_ml_op_qcom op = NULL; + DLDataType tvm_dtype = node.GetOpDataType()[0]; + cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); + cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype); + auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, + cl_dtype); + auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); + cl_uint block_size = std::stoi(node.GetAttr<std::vector<std::string>>("block_size")[0]); + + cl_ml_op_depthtospace_desc_qcom dtos_desc = {block_size, cl_arithmetic_mode}; + result = h_ClmlIntf->clCreateMLOpDepthToSpaceQCOM( + workspace->context, 0, &dtos_desc, input->tensor, output->tensor, &op, tuning_cache); + ICHECK(op && result == CL_SUCCESS) << "DepthToSpace Layer Error:" << result; + + layer_.func_ins.push_back(input); + layer->function.push_back(op); + return output; + } + + /*! + * \brief Create a Resize(X) layer. + * + * \param layer The CLML layer to build. Containing inputs, outputs and the CLML output. + * \param node The JSON representation of the operator. 
+ */ + std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateResizeLayer(CachedLayer* layer, + const JSONGraphNode& node) { + cl_int result = 0; + cl_ml_op_qcom op = NULL; + DLDataType tvm_dtype = node.GetOpDataType()[0]; + cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); + cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype); + auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, + cl_dtype); + auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); + cl_bool align_corners = std::stoi(node.GetAttr<std::vector<std::string>>("align_corners")[0]); + + cl_ml_op_resize_bilinear_desc_qcom resize_desc = {align_corners, false, cl_arithmetic_mode}; + result = h_ClmlIntf->clCreateMLOpResizeBilinearQCOM( + workspace->context, 0, &resize_desc, input->tensor, output->tensor, &op, tuning_cache); + ICHECK(op && result == CL_SUCCESS) << "Resize Layer Error:" << result; + + layer_.func_ins.push_back(input); + layer->function.push_back(op); + return output; + } + /*! * \brief The network layers represented by acl functions. * \note Currently only supports a single layer. 
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py index c4ec2603249b..b8177435a0dc 100644 --- a/tests/python/contrib/test_clml/test_ops.py +++ b/tests/python/contrib/test_clml/test_ops.py @@ -574,5 +574,107 @@ def _verify(out, params, inputs): _verify(*(_get_model((1, 16), relay.nn.relu))) +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +@tvm.testing.requires_openclml +def test_depth_to_space(device, dtype): + def _get_model(a_shape, block_size): + a = relay.var("a", shape=(a_shape), dtype=dtype) + out = relay.nn.depth_to_space(a, block_size) + inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))} + params = {} + return out, params, inputs + + def _verify(out, params, inputs): + mod = IRModule.from_expr(out) + opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0] + clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0] + tvm.testing.assert_allclose( + clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3 + ) + + # Check to make sure these ops are offloaded to CLML instead of TVM. 
+ exp_codegen = [ + { + "attrs": { + "dtype": [[dtype]], + "shape": [[list(inputs["a"].shape)]], + }, + "name": "", + "op": "input", + }, + { + "attrs": { + "block_size": [[str(int(out.attrs.block_size))]], + "layout": [["NCHW"]], + "mode": [["DCR"]], + "dtype": [[dtype]], + "num_inputs": "1", + "num_outputs": "1", + "shape": [[list(clml_out[0].shape)]], + }, + "inputs": [[0, 0, 0]], + "name": "nn.depth_to_space", + "op": "kernel", + }, + ] + verify_codegen(out, exp_codegen, device, params) + + _verify(*(_get_model((1, 64, 8, 8), 4))) + _verify(*(_get_model((1, 64, 8, 8), 8))) + + +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +@tvm.testing.requires_openclml +def test_resize_bilinear(device, dtype): + def _get_model(a_shape, scale, align_corners): + a = relay.var("a", shape=(a_shape), dtype=dtype) + out = relay.nn.upsampling( + a, scale_h=scale[0], scale_w=scale[1], method="bilinear", align_corners=align_corners + ) + inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))} + params = {} + return out, params, inputs + + def _verify(out, params, inputs): + mod = IRModule.from_expr(out) + opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0] + clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0] + tvm.testing.assert_allclose( + clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3 + ) + + # Check to make sure these ops are offloaded to CLML instead of TVM. 
+ exp_codegen = [ + { + "attrs": { + "dtype": [[dtype]], + "shape": [[list(inputs["a"].shape)]], + }, + "name": "", + "op": "input", + }, + { + "attrs": { + "scale_h": [[str(int(out.attrs.scale_h))]], + "scale_w": [[str(int(out.attrs.scale_w))]], + "layout": [["NCHW"]], + "method": [[out.attrs.method]], + "align_corners": [[str(out.attrs.align_corners)]], + "dtype": [[dtype]], + "num_inputs": "1", + "num_outputs": "1", + "shape": [[list(clml_out[0].shape)]], + }, + "inputs": [[0, 0, 0]], + "name": "nn.upsampling", + "op": "kernel", + }, + ] + verify_codegen(out, exp_codegen, device, params) + + _verify(*(_get_model((1, 16, 8, 8), (2, 2), False))) + _verify(*(_get_model((1, 16, 7, 7), (2, 2), True))) + + if __name__ == "__main__": tvm.testing.main()