diff --git a/cmake/modules/contrib/CLML.cmake b/cmake/modules/contrib/CLML.cmake
index 30e60423b03b..2fde0de65b4b 100644
--- a/cmake/modules/contrib/CLML.cmake
+++ b/cmake/modules/contrib/CLML.cmake
@@ -22,7 +22,21 @@ if(USE_CLML)
     if(NOT USE_CLML_GRAPH_EXECUTOR)
       list(APPEND COMPILER_SRCS ${CLML_RUNTIME_MODULE})
     endif()
-    message(STATUS "Build with CLML support...")
+    message(STATUS "Build with CLML support: " ${USE_CLML})
+    if(NOT USE_CLML STREQUAL "ON")
+      set(CLML_VERSION_HEADER "${USE_CLML}/CL/cl_qcom_ml_ops.h")
+      if(EXISTS ${CLML_VERSION_HEADER})
+        file(READ ${CLML_VERSION_HEADER} ver)
+        string(REGEX MATCH "CL_QCOM_ML_OPS_H_MAJOR_VERSION ([0-9]*)" _ ${ver})
+        set(CLML_VERSION_MAJOR ${CMAKE_MATCH_1})
+      else()
+        set(CLML_VERSION_MAJOR "2")
+      endif()
+    else()
+      set(CLML_VERSION_MAJOR "2")
+    endif()
+    add_definitions(-DTVM_CLML_VERSION=${CLML_VERSION_MAJOR})
+    message(STATUS "CLML SDK Version: " ${CLML_VERSION_MAJOR})
   endif()
 
   if(USE_CLML_GRAPH_EXECUTOR)
diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 6453b8a06c9f..02e4f62bed24 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -28,6 +28,12 @@
 from ..strategy.generic import is_depthwise_conv2d
 
 
+def clml_sdk_version():
+    """Utility function to get the CLML SDK version."""
+
+    return int(tvm.support.libinfo().get("TVM_CLML_VERSION", 2))
+
+
 def is_clml_runtime_enabled():
     """Check if the CLML graph runtime is present.
 
@@ -92,38 +98,35 @@ def preprocess_module(mod):
     preprocessed_mod : The processed module.
     """
 
-    def convert_layout_conv2d(conv2d_function):
-        def convert_conv(attrs, inputs, tinfos, desired_layouts):
-            new_attrs = dict(attrs)
-            data_info = tinfos[0]
-            weight_info = tinfos[1]
-            desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
-            new_attrs["data_layout"] = desired_data_layout
-            new_attrs["kernel_layout"] = desired_kernel_layout
-
-            if is_depthwise_conv2d(
-                data_info.shape,
-                attrs["data_layout"],
-                weight_info.shape,
-                attrs["kernel_layout"],
-                attrs["groups"],
-            ):
-                dkl = desired_kernel_layout
-                new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3]
-            return conv2d_function(*inputs, **new_attrs)
-
-        return convert_conv
-
-    with OpAttrContext(
-        "nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d)
-    ):
+    def alter_conv(attrs, inputs, tinfos, out_type):
+        new_attrs = dict(attrs)
+        data_info = tinfos[0]
+        weight_info = tinfos[1]
+        (desired_data_layout, desired_kernel_layout) = ("NCHW", "OIHW")
+        new_attrs["data_layout"] = desired_data_layout
+        new_attrs["kernel_layout"] = desired_kernel_layout
+
+        if is_depthwise_conv2d(
+            data_info.shape,
+            attrs["data_layout"],
+            weight_info.shape,
+            attrs["kernel_layout"],
+            attrs["groups"],
+        ):
+            dkl = desired_kernel_layout
+            new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3]
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    with OpAttrContext("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
         seq = tvm.transform.Sequential(
             [
                 transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}),
+                transform.AlterOpLayout(),
                 transform.FoldConstant(),
             ]
         )
-        preprocessed_mod = seq(mod)
+        with tvm.transform.PassContext(opt_level=3):
+            preprocessed_mod = seq(mod)
     return preprocessed_mod
 
@@ -275,6 +278,9 @@ def check_default_op(extract):
         ("clml.add", is_op("add")(wildcard(), wildcard()), check_binary_op),
         ("clml.subtract", is_op("subtract")(wildcard(), wildcard()), check_binary_op),
         ("clml.multiply", is_op("multiply")(wildcard(), wildcard()), check_binary_op),
+        ("clml.divide", is_op("divide")(wildcard(), wildcard()), check_binary_op),
+        ("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op),
+        ("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op),
         ("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op),
         ("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
         ("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op),
diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc
index 167c48e1baf5..d8ca791ad8c4 100644
--- a/src/relay/backend/contrib/clml/codegen.cc
+++ b/src/relay/backend/contrib/clml/codegen.cc
@@ -328,7 +328,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
     const auto* dense = fn->body.as<CallNode>();
     const CallNode* bias = nullptr;
 
-    if (backend::IsOp(dense, "add")) {
+    if (backend::IsOp(dense, "add") || backend::IsOp(dense, "nn.bias_add")) {
       bias = dense;
       dense = dense->args[0].as<CallNode>();
     }
diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc
index a667caaafcd8..6396fce4858b 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -153,13 +153,25 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result;
 
     for (cl_uint i = 0; i < numVersions; ++i) {
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2
       if (majorVersions[i] == 2) {
-        LOG(WARNING) << "CLML Version Selected:" << majorVersions[i] << " : " << majorVersions[i];
         h_ClmlIntf = clGetMLInterfaceV2QCOM(0);
-        ICHECK(h_ClmlIntf != NULL) << "clGetMLInterfaceV2QCOM:" << result;
+        LOG(WARNING) << "CLML Target version:" << majorVersions[i];
         break;
       }
+#endif
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3
+      if (majorVersions[i] == 3) {
+        h_ClmlIntf = clGetMLInterfaceV3QCOM(0);
+        LOG(WARNING) << "CLML Target version:" << majorVersions[i];
+        break;
+      }
+#endif
     }
+    ICHECK(h_ClmlIntf != NULL)
+        << "clGetMLInterfaceVxQCOM:" << result
+        << " Perhaps there is a mismatch between the CLML SDK version and the target supported version:"
+        << majorVersions[numVersions - 1];
     char* tune_flag;
     if ((tune_flag = getenv("CLML_IS_TUNNING_RUN"))) this->is_tuning_run = std::stoi(tune_flag);
 
@@ -400,7 +412,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
       this->layer_.func_outs.push_back(out);
     } else if ("add" == op_name || "subtract" == op_name || "multiply" == op_name ||
-               "minimum" == op_name || "maximum" == op_name) {
+               "minimum" == op_name || "maximum" == op_name || "divide" == op_name) {
       auto out = CreateBinaryLayer(&layer_, node);
       this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
       this->layer_.func_outs.push_back(out);
@@ -523,7 +535,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   }
 
   cl_ml_tensor_qcom DeviceMakeCLMLTensor(
-      void* pClmlIntf, cl_context context, tensor_dims_t dims,
+      cl_context context, tensor_dims_t dims,
       cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
       cl_channel_type dtype = CL_FLOAT) {
     cl_ml_tensor_qcom tensor;
@@ -531,8 +543,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_tensor_desc_qcom desc = {
         dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, { 0 }};
 
-    CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
-    result = clmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor);
+    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor);
     ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
     (void)result;
     return tensor;
@@ -544,9 +555,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_int result = CL_OUT_OF_HOST_MEMORY;
     cl_mem buffer = NULL;
 
-    CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
     result =
-        clmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size);
+        h_ClmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size);
     ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result;
 
     buffer = clCreateBuffer(workspace->context, CL_MEM_READ_WRITE, size, NULL, &result);
@@ -612,8 +622,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
 
     auto tensor_dsc = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
-    tensor_dsc->tensor =
-        DeviceMakeCLMLTensor(h_ClmlIntf, workspace->context, dims, layout, cl_dtype);
+    tensor_dsc->tensor = DeviceMakeCLMLTensor(workspace->context, dims, layout, cl_dtype);
     return tensor_dsc;
   }
 
@@ -901,7 +910,6 @@ class CLMLRuntime : public JSONRuntimeBase {
     auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {},
                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
-    auto in_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]);
 
     std::vector<std::string> windows = node.GetAttr<std::vector<std::string>>("pool_size");
     std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
@@ -1103,7 +1111,6 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
     int inputSize = input_.size();
-    int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
     auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_ml_tensor_qcom* concatInputs = new cl_ml_tensor_qcom[inputSize];
     for (int i = 0; i < inputSize; i++) {
@@ -1236,6 +1243,8 @@ class CLMLRuntime : public JSONRuntimeBase {
       binary_op = CL_TENSOR_OP_SUB_QCOM;
     else if (op_name == "multiply")
       binary_op = CL_TENSOR_OP_MUL_QCOM;
+    else if (op_name == "divide")
+      binary_op = CL_TENSOR_OP_DIV_QCOM;
     else if (op_name == "minimum")
       binary_op = CL_TENSOR_OP_MIN_QCOM;
     else if (op_name == "maximum")
@@ -1260,7 +1269,12 @@ class CLMLRuntime : public JSONRuntimeBase {
   CachedLayer layer_;
   // CLML Context
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2
   CLMLInterfaceV2QCOM* h_ClmlIntf = NULL;
+#endif
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3
+  CLMLInterfaceV3QCOM* h_ClmlIntf = NULL;
+#endif
   cl::OpenCLWorkspace* workspace = NULL;
   cl::OpenCLThreadEntry* tentry = NULL;
   cl_ml_tuningcache_qcom tuning_cache = NULL;
diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py
index 89c22255d77d..be2bbc7f8a71 100644
--- a/tests/python/contrib/test_clml/infrastructure.py
+++ b/tests/python/contrib/test_clml/infrastructure.py
@@ -39,9 +39,9 @@ class Device:
     Configuration for CLML tests.
 
     Check tests/python/contrib/clml/ for the presence of a test_config.json file.
-    This file can be used to override the default configuration here which will attempt to run the Arm
-    Compute Library runtime tests locally if the runtime is available. Changing the configuration
-    will allow these runtime tests to be offloaded to a remote Arm device via a tracker for example.
+    This file can be used to override the default configuration here which will attempt to run the
+    Open CLML runtime tests locally if the runtime is available. Changing the configuration
+    will allow these runtime tests to be offloaded to a remote Snapdragon device via a tracker, for example.
 
     Notes
     -----
@@ -101,6 +101,25 @@ def _get_remote(cls):
         return device
 
 
+def get_cpu_op_count(mod):
+    """Traverse graph counting ops not offloaded to CLML, i.e. left to TVM."""
+
+    class Counter(tvm.relay.ExprVisitor):
+        def __init__(self):
+            super().__init__()
+            self.count = 0
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                self.count += 1
+
+            super().visit_call(call)
+
+    c = Counter()
+    c.visit(mod["main"])
+    return c.count
+
+
 def skip_codegen_test():
     """Skip test if it requires the CLML codegen and it's not present."""
     if not tvm.get_global_func("relay.ext.clml", True):
@@ -130,7 +149,6 @@ def build_and_run(
 
     try:
         libm = build_module(mod, device.target, device.target_host, params, enable_clml, tune_log)
-
         clml_modules = extract_clml_modules(libm)
         for mod in clml_modules:
             source = mod.get_source("json")
@@ -155,9 +173,9 @@ def build_and_run(
     for _ in range(no_runs):
         gen_module.run()
         out.append([gen_module.get_output(i) for i in range(outputs)])
-    time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1)
-    cost = time_f().mean
-    print("%g secs/iteration\n" % cost)
+    # time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1)
+    # cost = time_f().mean
+    # print("%g secs/iteration\n" % cost)
     return out
 
@@ -181,16 +199,34 @@ def extract_clml_modules(module):
 
 
 def verify_codegen(
-    module,
+    mod,
     known_good_codegen,
+    device,
+    params,
     num_clml_modules=1,
     tvm_ops=0,
-    target="llvm -mtriple=aarch64-linux-gnu",
 ):
     """Check clml codegen against a known good output."""
-    module = build_module(module, target, tvm_ops=tvm_ops, clml_partitions=num_clml_modules)
-    clml_modules = extract_clml_modules(module)
+    if isinstance(mod, tvm.relay.expr.Call):
+        mod = tvm.IRModule.from_expr(mod)
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+        mod = clml.partition_for_clml(mod, params)
+        tvm_op_count = get_cpu_op_count(mod)
+        assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
+            tvm_op_count, tvm_ops
+        )
+        partition_count = 0
+        for global_var in mod.get_global_vars():
+            if "clml" in global_var.name_hint:
+                partition_count += 1
+
+        assert (
+            num_clml_modules == partition_count
+        ), "Got {} Open CLML partitions, expected {}".format(partition_count, num_clml_modules)
+
+    relay.backend.te_compiler.get().clear()
+    module = relay.build(mod, target=device.target, target_host=device.target_host, params=params)
+    clml_modules = extract_clml_modules(module)
     assert len(clml_modules) == num_clml_modules, (
         f"The number of CLML modules produced ({len(clml_modules)}) does not "
        f"match the expected value ({num_clml_modules})."
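
[Reviewer note] The reworked verify_codegen above now partitions with partition_for_clml, asserts the residual TVM operator count via the new get_cpu_op_count helper, and checks the number of CLML partitions before comparing against the known-good JSON. A minimal sketch of the offload-count check in isolation (the shapes and empty params dict are illustrative, and it assumes a TVM build where the CLML contrib Python bits and this test infrastructure are importable):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import clml
    from test_clml.infrastructure import get_cpu_op_count

    # Hypothetical two-input graph; CLML is expected to absorb the whole
    # body, leaving zero operators on the native TVM path.
    a = relay.var("a", shape=(1, 16), dtype="float32")
    b = relay.var("b", shape=(1, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.add(a, b))

    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
        mod = clml.partition_for_clml(mod, params={})

    # get_cpu_op_count walks main() and counts calls whose op is a
    # tvm.ir.Op, i.e. operators left outside the CLML partitions.
    assert get_cpu_op_count(mod) == 0
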
diff --git a/tests/python/contrib/test_clml/test_network.py b/tests/python/contrib/test_clml/test_network.py
index 8d740d6dce4d..177359d9b18a 100644
--- a/tests/python/contrib/test_clml/test_network.py
+++ b/tests/python/contrib/test_clml/test_network.py
@@ -91,13 +91,8 @@ def get_model():
         mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
     )
 
-    # test
-    print("OpenCL:", outputs[0].asnumpy().shape)
-    print("CLML:", outputs[1].asnumpy().shape)
-
     opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
     clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-
     tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)
 
@@ -134,7 +129,6 @@ def get_model():
 
     opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
     clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-
     tvm.testing.assert_allclose(opencl_sort[:5], clml_sort[:5], rtol=1e-5, atol=1e-5)
 
@@ -176,11 +170,10 @@ def get_model():
         mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
     )
 
-    # test
-    print("OpenCL:", outputs[0].asnumpy().shape)
-    print("CLML:", outputs[1].asnumpy().shape)
-
     opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
     clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-
     tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index da09715fbe4c..c4ec2603249b 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -14,15 +14,23 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""CLML integration conv2d tests."""
+"""CLML integration operator tests."""
 
 import tvm
 import numpy as np
 from tvm import relay
+from tvm.relay.op.contrib import clml
 from tvm.relay import testing
 from tvm.ir import IRModule
 from tvm.contrib import utils
-from test_clml.infrastructure import build_and_run, Device, skip_codegen_test
+from test_clml.infrastructure import (
+    build_and_run,
+    Device,
+    skip_codegen_test,
+    verify_codegen,
+    build_module,
+    get_cpu_op_count,
+)
 
 import pytest
 
@@ -54,11 +62,8 @@ def _get_conv_model(
         shape = (shape[0], shape[1], shape[2] + padding[0] * 2, shape[3] + padding[1] * 2)
     is_depthwise = shape[1] == channels == groups
 
-    weight_format = "OIHW" if is_depthwise else "OIHW"
-    if weight_format == "IOHW":
-        weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w)
-    else:
-        weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w)
+    weight_format = "OIHW"
+    weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w)
     w = tvm.nd.array(np.random.uniform(-1, 1, weight_shape).astype(dtype))
     weights = relay.const(w, dtype)
 
@@ -77,7 +82,7 @@ def _get_conv_model(
     )
     params = {"w": w}
     if has_bias:
-        bias_shape = weight_shape[2] if is_depthwise else weight_shape[0]
+        bias_shape = (weight_shape[0],)
         b = tvm.nd.array(np.random.uniform(-1, 1, bias_shape).astype(dtype))
        biasc = relay.const(b, dtype)
         out = relay.nn.bias_add(out, biasc, axis=1)
@@ -86,31 +91,121 @@ def _get_conv_model(
     if has_activation:
         out = relay.nn.relu(out)
-    print("Out:", out)
-
     return out, params
 
 
+def _get_conv_expected_codegen(
+    shape,
+    kernel_h,
+    kernel_w,
+    padding,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    channels,
+    has_bias=False,
+    has_activation=False,
+):
+    if len(padding) == 2:
+        padding = (padding[0], padding[1], padding[0], padding[1])
+    output_height = ((shape[2] - kernel_h + padding[0] + padding[2]) / strides[0]) + 1
+    output_width = ((shape[3] - kernel_w + padding[1] + padding[3]) / strides[1]) + 1
+    output_shape = (1, channels, int(output_height), int(output_width))
+    out_dtype = dtype
+    is_depthwise = shape[1] == channels == groups
+
+    weight_format = "IOHW" if is_depthwise else "OIHW"
+    if weight_format == "OIHW":
+        weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w)
+    else:
+        weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w)
+
+    if is_depthwise:
+        name = "nn.depthwise_conv2d"
+    else:
+        name = "nn.conv2d"
+
+    node = {
+        "op": "kernel",
+        "name": name,
+        "inputs": [],
+        "attrs": {
+            "groups": [[str(groups)]],
+            "num_outputs": "1",
+            "data_layout": [["NCHW"]],
+            "kernel_layout": [[weight_format]],
+            "channels": [[str(channels)]],
+            "dilation": [[str(dilation[0]), str(dilation[1])]],
+            "out_layout": [[""]],
+            "out_dtype": [[out_dtype]],
+            "kernel_size": [[str(kernel_h), str(kernel_w)]],
+            "shape": [[list(output_shape)]],
+            "dtype": [[dtype]],
+            "padding": [[str(p) for p in padding]],
+            "strides": [[str(s) for s in strides]],
+        },
+    }
+
+    if has_activation:
+        node["attrs"]["activation_type"] = [["relu"]]
+
+    inputs = [
+        {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}},
+        {
+            "op": "const",
+            "name": "",
+            "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]},
+        },
+    ]
+
+    if has_bias:
+        bias_dtype = dtype
+        inputs.append(
+            {
+                "op": "const",
+                "name": "",
+                "attrs": {
+                    "shape": [[[1, weight_shape[1] if is_depthwise else weight_shape[0], 1, 1]]],
+                    "dtype": [[bias_dtype]],
+                },
+            }
+        )
+
+    input_idx = 0
+    for _ in range(len(inputs)):
+        node["inputs"].append([input_idx, 0, 0])
+        input_idx += 1
+    node["attrs"]["num_inputs"] = str(len(inputs))
+    inputs.append(node)
+    return inputs
+
+
 @pytest.mark.parametrize("dtype", ["float32"])
 @tvm.testing.requires_openclml
 def test_conv2d(device, dtype):
     trials = [
         # Normal convolution
-        [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True)],
-        # Normal convolution
-        [2, 2, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (16, 12, 15), (False, False, True)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)],
-        [3, 3, (1, 1), (1, 1), (1, 1), 16, (16, 12, 15), (False, False, False)],
-        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True)],
-        [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False)],
-        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False)],
-        [3, 3, (1, 1), (2, 2), (1, 1), 16, (14, 10, 10), (False, True, True)],
+        [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True), False],
+        [2, 2, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [2, 1, (2, 2), (1, 1), (1, 1), 7, (16, 12, 15), (False, False, True), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), False],
+        [3, 3, (1, 1), (1, 1), (1, 1), 16, (16, 12, 15), (False, False, False), False],
+        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), False],
+        [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), False],
+        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), False],
+        [3, 3, (1, 1), (2, 2), (1, 1), 16, (14, 10, 10), (False, True, True), False],
+        # Depth-wise convolution
+        [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), True],
+        [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), True],
+        [3, 3, (2, 2), (2, 2), (1, 1), 14, (14, 10, 10), (False, False, False), True],
+        [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, False), True],
+        [3, 3, (1, 1), (2, 2), (1, 1), 14, (14, 10, 10), (False, True, True), True],
     ]
 
     for (
@@ -122,9 +217,13 @@ def test_conv2d(device, dtype):
         out_channels,
         shape,
         composite,
+        is_depthwise,
     ) in trials:
         shape = (1, *shape)
-        groups = 1
+        if is_depthwise:
+            groups = shape[1]
+        else:
+            groups = 1
         outputs = []
         inputs = {
             "a": tvm.nd.array(np.random.uniform(-1, 1, shape).astype(dtype)),
@@ -151,11 +250,19 @@ def test_conv2d(device, dtype):
         tvm.testing.assert_allclose(
             clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-5, atol=1e-5
         )
+        args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
+        exp_codegen = _get_conv_expected_codegen(
+            *args, has_bias=composite[1], has_activation=composite[2]
+        )
+        verify_codegen(func, exp_codegen, device, params)
 
 
 @pytest.mark.parametrize("dtype", ["float16"])
 @tvm.testing.requires_openclml
-def _test_batchnorm(device, dtype):
+def test_batchnorm(device, dtype):
+    if clml.clml_sdk_version() < 3:
+        print("Skip due to unsupported CLML version")
+        return
 
     in_shape = (1, 8, 64, 64)
     channels = 8
 
@@ -211,11 +318,80 @@ def test_concat(device, dtype):
         tvm.testing.assert_allclose(
             clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
         )
+        exp_codegen = [
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(in_shape_1)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(in_shape_2)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "axis": [["1"]],
+                    "dtype": [[dtype]],
+                    "num_inputs": "2",
+                    "num_outputs": "1",
+                    "shape": [[list(clml_out[0].shape)]],
+                },
+                "inputs": [[0, 0, 0], [1, 0, 0]],
+                "name": "concatenate",
+                "op": "kernel",
+            },
+        ]
+        verify_codegen(func, exp_codegen, device, params)
+
+
+def _get_pool_expected_codegen(input_shape, pool_size, stride, padding, pool_type, dtype):
+    import math
+
+    pool_height = math.floor(((input_shape[2] + padding[2] - pool_size[0]) / stride[0]) + 1)
+    pool_width = math.floor(((input_shape[3] + padding[3] - pool_size[1]) / stride[1]) + 1)
+    output_shape = [input_shape[0], input_shape[1], pool_height, pool_width]
+    attrs = {
+        "ceil_mode": [["0"]],
+        "dilation": [["1", "1"]],
+        "layout": [["NCHW"]],
+        "num_inputs": "1",
+        "num_outputs": "1",
+        "out_layout": [[""]],
+        "padding": [[str(p) for p in padding]],
+        "pool_size": [[str(p) for p in pool_size]],
+        "shape": [[list(output_shape)]],
+        "dtype": [[dtype]],
+        "strides": [[str(s) for s in stride]],
+    }
+    if sum(padding):
+        attrs["count_include_pad"] = [["0"]]
+
+    exp_codegen = [
+        {
+            "op": "input",
+            "name": "",
+            "attrs": {"shape": [[list(input_shape)]], "dtype": [[str(dtype)]]},
+        },
+        {
+            "op": "kernel",
+            "name": "nn.avg_pool2d" if pool_type == "avg" else "nn.max_pool2d",
+            "inputs": [[0, 0, 0]],
+            "attrs": attrs,
+        },
+    ]
+    return exp_codegen
 
 
 @pytest.mark.parametrize("dtype", ["float16"])
 @tvm.testing.requires_openclml
-def test_avgpool(device, dtype):
+def test_pool(device, dtype):
     trials = [
-        # input size pool_size stride paading
         [(1, 64, 147, 147), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
@@ -251,7 +427,152 @@ def test_pool(device, dtype):
 
         opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
         clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
-
         tvm.testing.assert_allclose(
             clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
         )
+
+        args = (input_shape, pool_size, stride, padding, pooling_type, dtype)
+        exp_codegen = _get_pool_expected_codegen(*args)
+        verify_codegen(func, exp_codegen, device, params)
+
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_dense(device, dtype):
+    def _get_model(x_shape, k_shape, has_bias=False):
+        x = relay.var("x", shape=(x_shape), dtype=dtype)
+        kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
+        out = relay.nn.dense(x, kernel, units=k_shape[0])
+        params = {"kernel": tvm.nd.array(np.random.uniform(-1, 1, k_shape).astype(dtype))}
+        inputs = {"x": tvm.nd.array(np.random.uniform(-1, 1, x_shape).astype(dtype))}
+        exp_codegen = [
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(x_shape)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(k_shape)]],
+                },
+                "name": "",
+                "op": "const",
+            },
+        ]
+        if has_bias:
+            bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+            out = relay.nn.bias_add(out, bias)
+            bias_node = {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list((1, k_shape[0]))]],
+                },
+                "name": "",
+                "op": "const",
+            }
+            exp_codegen.append(bias_node)
+            params["bias"] = tvm.nd.array(np.random.uniform(-1, 1, (k_shape[0],)).astype(dtype))
+
+        dense_node = {
+            "attrs": {
+                "num_inputs": "3" if has_bias else "2",
+                "num_outputs": "1",
+                "dtype": [[dtype]],
+                "out_dtype": [[""]],
+                "shape": [[[x_shape[0], k_shape[0]]]],
+                "units": [[str(k_shape[0])]],
+            },
+            "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]] if has_bias else [[0, 0, 0], [1, 0, 0]],
+            "name": "nn.dense",
+            "op": "kernel",
+        }
+        exp_codegen.append(dense_node)
+        return out, params, inputs, exp_codegen
+
+    def _verify(out, params, inputs, exp_codegen):
+        mod = IRModule.from_expr(out)
+        opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
+        clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
         tvm.testing.assert_allclose(
             clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
         )
+        verify_codegen(out, exp_codegen, device, params)
+
+    _verify(*(_get_model((1, 16), (32, 16))))
+    _verify(*(_get_model((1, 16), (32, 16), True)))
+
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_binary_ops(device, dtype):
+    def _get_model(a_shape, b_shape, op):
+        a = relay.var("a", shape=(a_shape), dtype=dtype)
+        b = relay.var("b", shape=(b_shape), dtype=dtype)
+        out = op(a, b)
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype)),
+            "b": tvm.nd.array(np.random.uniform(-1, 1, b_shape).astype(dtype)),
+        }
+        params = {}
+        return out, params, inputs
+
+    def _verify(out, params, inputs):
+        mod = IRModule.from_expr(out)
+        opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
+        clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+        # Check to make sure these ops are offloaded to CLML instead of TVM.
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            mod = clml.partition_for_clml(mod, params)
+            tvm_op_count = get_cpu_op_count(mod)
+        assert tvm_op_count == 0, "Got {} TVM native operators, expected 0".format(tvm_op_count)
+
+    _verify(*(_get_model((1, 16), (1, 16), relay.add)))
+    _verify(*(_get_model((1, 16), (1, 16), relay.subtract)))
+    _verify(*(_get_model((1, 16), (1, 16), relay.multiply)))
+    _verify(*(_get_model((1, 16), (1, 16), relay.divide)))
+    _verify(*(_get_model((1, 16), (1, 16), relay.minimum)))
+    _verify(*(_get_model((1, 16), (1, 16), relay.maximum)))
+
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_unary_ops(device, dtype):
+    def _get_model(a_shape, op):
+        a = relay.var("a", shape=(a_shape), dtype=dtype)
+        out = op(a)
+        inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))}
+        params = {}
+        return out, params, inputs
+
+    def _verify(out, params, inputs):
+        mod = IRModule.from_expr(out)
+        opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
+        clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+        # Check to make sure these ops are offloaded to CLML instead of TVM.
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            mod = clml.partition_for_clml(mod, params)
+            tvm_op_count = get_cpu_op_count(mod)
+        assert tvm_op_count == 0, "Got {} TVM native operators, expected 0".format(tvm_op_count)
+
+    _verify(*(_get_model((1, 16), relay.nn.softmax)))
+    _verify(*(_get_model((1, 16), relay.nn.relu)))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh
index 6b43d7cbc421..187ca7f815df 100755
--- a/tests/scripts/task_build_adreno_bins.sh
+++ b/tests/scripts/task_build_adreno_bins.sh
@@ -29,8 +29,12 @@ cd ${output_directory}
 cp ../cmake/config.cmake .
 
 echo set\(USE_MICRO OFF\) >> config.cmake
-echo set\(USE_CLML ON\) >> config.cmake
+if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then
+echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake
 echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake
+else
+echo set\(USE_OPENCL ON\) >> config.cmake
+fi
 echo set\(USE_RPC ON\) >> config.cmake
 echo set\(USE_CPP_RPC ON\) >> config.cmake
 echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh
index d45c5e8b7dcf..d378b5f842b5 100755
--- a/tests/scripts/task_config_build_adreno.sh
+++ b/tests/scripts/task_config_build_adreno.sh
@@ -24,7 +24,9 @@ cd "$BUILD_DIR"
 cp ../cmake/config.cmake .
 
 echo set\(USE_OPENCL ON\) >> config.cmake
-echo set\(USE_CLML ON\) >> config.cmake
+if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then
+echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake
+fi
 echo set\(USE_RPC ON\) >> config.cmake
 echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
 echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
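
[Integrator note] With TVM_CLML_VERSION now baked into the library by the CMake change above, version-dependent behavior can be gated from Python at run time. A minimal sketch (assumes a TVM build where USE_CLML pointed at an SDK whose CL/cl_qcom_ml_ops.h advertises the major version; otherwise the probe falls back to 2):

    from tvm.relay.op.contrib import clml

    # clml_sdk_version() reads TVM_CLML_VERSION from tvm.support.libinfo(),
    # defaulting to 2 when no SDK header was detected at build time.
    if clml.clml_sdk_version() >= 3:
        print("CLML v3 features (e.g. the batch_norm test above) are usable")
    else:
        print("CLML v2 build: v3-dependent tests should be skipped")
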