From 574277601f0e706395a47a0b49b693a5f9eb53ab Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Wed, 16 Oct 2019 00:03:53 +0800
Subject: [PATCH 01/62] [Relay][Frontend][ONNX] operator support: ConstantOfShape

---
 python/tvm/relay/frontend/onnx.py          | 23 ++++++++++++++-
 tests/python/frontend/onnx/test_forward.py | 33 ++++++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a7f787484b2c..c68c56194ac1 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -915,6 +915,7 @@ def _impl_v1(cls, inputs, attr, params):
         reps = attr.pop('repeats')  # The number of times repeating the tensor data.
         return _op.tile(inputs[0], reps)
 
+
 class Erf(OnnxOpConverter):
     """Operator converter for Erf
     """
@@ -923,6 +924,18 @@ def _impl_v1(cls, inputs, attr, params):
         return _op.erf(inputs[0])
 
 
+class ConstantOfShape(Elemwise):
+    """Operator converter for ConstantOfShape
+    """
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        if not isinstance(inputs, list) or len(inputs) < 2:
+            raise ValueError("Expect minimum 2 inputs")
+        # reps: The number of times repeating the tensor data.
+        shape = tuple(params[inputs[0].name_hint].asnumpy().astype('int').tolist())
+        return _op.tile(inputs[1], reps=shape)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1042,7 +1055,8 @@ def _get_convert_map(opset):
         'Not': Not.get_converter(opset),
         'And': And.get_converter(opset),
         'Tile': Tile.get_converter(opset),
-        'Erf': Erf.get_converter(opset)
+        'Erf': Erf.get_converter(opset),
+        'ConstantOfShape': ConstantOfShape.get_converter(opset)
     }
 
 
@@ -1162,6 +1176,13 @@ def from_onnx(self, graph, opset):
                 self._params[i_name] = fill_value
                 self._nodes[i_name] = new_var(node.output[0], shape=(), dtype=dtype)
                 inputs.append(self._nodes[i_name])
+            if op_name == "ConstantOfShape":
+                t_proto = self._parse_attr(node.attribute)["value"]
+                i_name = node.output[0]
+                self._params[i_name] = self._parse_array(t_proto)
+                self._nodes[i_name] = new_var(node.input[0],
+                                              shape=self._params[node.input[0]].shape,
+                                              dtype=self._params[node.input[0]].dtype)
 
             i_name = self._parse_value_proto(node)
             attr['tvm_custom'] = {}
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 16e717401174..b75b6f70fcab 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1300,6 +1300,38 @@ def test_erf():
     verify_erf(x, z)
 
 
+def verify_constantofshape(indata, outdata, out_value):
+    in_tensor = onnx.helper.make_tensor(name='in',
+                                        data_type=TensorProto.FLOAT,
+                                        dims=indata.shape,
+                                        vals=indata)
+
+    value_tensor = onnx.helper.make_tensor(name='value',
+                                           data_type=TensorProto.FLOAT,
+                                           dims=[1],
+                                           vals=[out_value])
+
+    node = helper.make_node('ConstantOfShape', inputs=['in'], outputs=['out'], value=value_tensor,)
+    graph = helper.make_graph([node],
+                              'ConstantOfShape_test',
+                              inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))],
+                              outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))],
+                              initializer=[in_tensor])
+    model = helper.make_model(graph, producer_name='ConstantOfShape_test')
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape)
+        tvm.testing.assert_allclose(outdata, tvm_out)
+
+
+def test_constantofshape():
+    # test cases: https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConstantOfShape
+    # float ones
+    x = np.array([4, 3, 2]).astype(np.int64)
+    y = np.ones(x, dtype=np.float32)
+    value = 1
+    verify_constantofshape(x, y, out_value=value)
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1347,3 +1379,4 @@ def test_erf():
     test_and()
     test_tile()
     test_erf()
+    test_constantofshape()

From 28aefac2fffd90e3e5ac63f28f0867229d811b57 Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Wed, 16 Oct 2019 00:21:15 +0800
Subject: [PATCH 02/62] [Relay][Frontend][ONNX] operator support: ConstantOfShape

---
 python/tvm/relay/frontend/onnx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c68c56194ac1..451101eeb6e0 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1183,7 +1183,7 @@ def from_onnx(self, graph, opset):
                 self._nodes[i_name] = new_var(node.input[0],
                                               shape=self._params[node.input[0]].shape,
                                               dtype=self._params[node.input[0]].dtype)
-
+                inputs.append(self._nodes[i_name])
             i_name = self._parse_value_proto(node)
             attr['tvm_custom'] = {}
             attr['tvm_custom']['name'] = i_name

From 1c0e7435cca6e1e128b4b9edb3017dc42a54ffba Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Tue, 15 Oct 2019 22:17:47 -0700
Subject: [PATCH 03/62] [QNN] Change default rouning to UPWARD. (#4131)

---
 include/tvm/relay/qnn/attrs.h  | 2 +-
 python/tvm/relay/qnn/op/qnn.py | 2 +-
 src/relay/qnn/util.h           | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
index 83b55b04222a..e5f4ba94e12e 100644
--- a/include/tvm/relay/qnn/attrs.h
+++ b/include/tvm/relay/qnn/attrs.h
@@ -49,7 +49,7 @@ struct RequantizeAttrs : public tvm::AttrsNode {
       .describe("The scale of the output tensor.");
   TVM_ATTR_FIELD(output_zero_point)
       .describe("The zero point of the output tensor.");
-  TVM_ATTR_FIELD(rounding).set_default("TONEAREST")
+  TVM_ATTR_FIELD(rounding).set_default("UPWARD")
       .describe("Defines the rounding direction when the value is midway between"
                 "two representable values. There are two supported modes - UPWARD"
                 "or TONEAREST. Both modes behave exactly same except at the"
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index ed443abb5293..c8ebfc00a21b 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -27,7 +27,7 @@ def requantize(data,
                input_zero_point,
                output_scale,
                output_zero_point,
-               rounding="TONEAREST",
+               rounding="UPWARD",
                out_dtype="int8"):
     r"""Requantized operator.
 
diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index c26183705b89..f94860d28cf9 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -76,7 +76,7 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, static inline Expr Requantize(const Expr& data, const Array& input_shape, double input_scale, int32_t input_zero_point, double output_scale, int32_t output_zero_point, const DataType& out_dtype, - const std::string& rounding = "TONEAREST") { + const std::string& rounding = "UPWARD") { auto attrs = make_node(); attrs->input_scale = std::move(input_scale); attrs->input_zero_point = std::move(input_zero_point); From 46fa6eebcd397ab5e8524ed5da5820b482886ba9 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 16 Oct 2019 07:32:29 -0700 Subject: [PATCH 04/62] [Relay][Training] Add and fix gradients (#4126) * add and fix gradients * fix linter issues --- python/tvm/relay/op/_tensor_grad.py | 82 ++++++++++++++++++++--- tests/python/relay/test_op_grad_level2.py | 29 +++++++- tests/python/relay/test_op_grad_level4.py | 19 +++--- 3 files changed, 111 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 3a82e46e6a7d..1c94162d87d9 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -48,6 +48,9 @@ tile, transpose, where, + repeat, + expand_dims, + full_like ) @@ -198,6 +201,7 @@ def clip_grad(orig, grad): @register_gradient("nn.max_pool2d") def max_pool2d_grad(orig, grad): + """Returns the gradient of max_pool2d.""" attrs = orig.attrs pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size, strides=attrs.strides, padding=attrs.padding, @@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad): @register_gradient("nn.avg_pool2d") def avg_pool2d_grad(orig, grad): + """Returns the gradient of avg_pool2d.""" attrs = orig.attrs pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size, strides=attrs.strides, padding=attrs.padding, @@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad): return [pool_grad] +@register_gradient("nn.global_avg_pool2d") +def global_avg_pool2d_grad(orig, grad): + """Returns the gradient of global_avg_pool2d.""" + data = orig.args[0] + shape = data.checked_type.shape + layout = orig.attrs.layout + + # we assume NCHW or NHWC layout for now, but easy to add more + assert layout in ["NCHW", "NHWC"] + if layout == "NCHW": + pool_size = shape[2], shape[3] + elif layout == "NHWC": + pool_size = shape[1], shape[2] + + pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size, + strides=(1, 1), padding=(0, 0), + layout=layout) + return [pool_grad] + + # not implemented, this is only for testing. 
@register_gradient("concatenate") def concatenate_grad(orig, grad): @@ -287,16 +312,53 @@ def conv2d_grad(orig, grad): return [backward_data, backward_weight] +def _get_reduce_axis(call): + """Helper function that returns the reduce axis of the call as plain python ints.""" + x, axis = call.args[0], call.attrs.axis + shape = x.checked_type.concrete_shape + + # should never exclude when axis is None + assert not (axis is None and call.attrs.exclude) + + if axis is None: + return None + + # convert to nonnegative integers and sort + axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)]) + if call.attrs.exclude: + axis = [ax for ax in range(len(shape)) if ax not in axis] + return axis + + +def _unreduce_expand(x, axis): + """Helper function that returns x expanded on the reduced dimensions in axis.""" + # assume axis is sorted nonnegative ints + for ax in axis: + x = expand_dims(x, ax) + return x + + @register_gradient("max") def max_grad(orig, grad): """Returns the gradient of max""" - # Only support axis=0, since broadcasting orig to x behaves incorrectly - x, axis = orig.args[0], orig.attrs.axis - assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0) - orig = broadcast_to_like(orig, x) - grad = broadcast_to_like(grad, x) - indicators = cast_like(equal(orig, x), grad) - return [indicators * grad] + x, axis = orig.args[0], _get_reduce_axis(orig) + shape = x.checked_type.concrete_shape + + repeated = orig + if axis is None: + repeated = full_like(x, repeated) + else: + # expand dims (if necessary) and repeat along each axis + if not orig.attrs.keepdims: + repeated = _unreduce_expand(repeated, axis) + grad = _unreduce_expand(grad, axis) + for ax in axis: + repeated = repeat(repeated, shape[ax], ax) + + indicators = cast_like(equal(repeated, x), grad) + num_selected = _sum(indicators, axis, keepdims=True) + # spread error across all max weights + return [indicators * grad / num_selected] @register_gradient("nn.softmax") @@ -372,7 +434,11 @@ def negative_grad(orig, grad): @register_gradient("sum") def sum_grad(orig, grad): """Returns grad broadcasted to data dims""" - data = orig.args[0] + data, axis = orig.args[0], _get_reduce_axis(orig) + if not orig.attrs.keepdims: + if axis is None: + axis = list(range(len(data.checked_type.concrete_shape))) + grad = _unreduce_expand(grad, axis) return [broadcast_to_like(grad, data)] diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 8e809250d1de..57b1e2c676ac 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode): def test_max_pool2d_grad(): - verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), - ceil_mode=False) + verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False) verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False) @@ -75,7 +74,6 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, coun op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) - def test_avg_pool2d_grad(): verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False, count_include_pad=True) @@ -83,6 +81,30 @@ def test_avg_pool2d_grad(): ceil_mode=False, count_include_pad=False) +def 
verify_global_avg_pool2d_grad(x_shape): + x = relay.var("x", relay.TensorType(x_shape, "float32")) + y = tvm.relay.nn.global_avg_pool2d(x) + + fwd_func = relay.Function([x], y) + fwd_func = run_infer_type(fwd_func) + bwd_func = run_infer_type(gradient(fwd_func)) + + data = np.random.rand(*x_shape).astype("float32") + y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape) + out_grad = np.ones(shape=y_shape) + ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]), + strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg', + ceil_mode=False) + + for target, ctx in ctx_list(): + intrp = relay.create_executor(ctx=ctx, target=target) + op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) + np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) + +def test_global_avg_pool2d_grad(): + verify_global_avg_pool2d_grad((1, 4, 16, 16)) + verify_global_avg_pool2d_grad((1, 8, 8, 24)) + def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'): try: import torch @@ -155,6 +177,7 @@ def test_batch_flatten_grad(): if __name__ == "__main__": test_max_pool2d_grad() test_avg_pool2d_grad() + test_global_avg_pool2d_grad() test_conv2d_grad() test_dense_grad() test_batch_flatten_grad() diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py index f8d6c3a56c93..f690a186ea41 100644 --- a/tests/python/relay/test_op_grad_level4.py +++ b/tests/python/relay/test_op_grad_level4.py @@ -29,18 +29,21 @@ def test_sum_grad(): verify_sum_grad((4, 2)) verify_sum_grad((4, 2), axis=-1, keepdims=True) verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True) + verify_sum_grad((4, 2, 1), axis=1) -def test_max_grad(): - s = (10, 10) - t = relay.TensorType(s) - x = relay.var("x", t) - axis = 0 - z = relay.max(x, axis) - - fwd_func = relay.Function([x], z) +def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False): + data = relay.var("data", relay.TensorType(d_shape, "float32")) + fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude)) check_grad(fwd_func, scale=1e-3) +def test_max_grad(): + verify_max_grad((10, 10), axis=None) + verify_max_grad((10, 10), axis=-1) + verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True) + verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True) + + if __name__ == "__main__": pytest.main() From c1069108e38fef7a979a0315eed06cbbe93be9dd Mon Sep 17 00:00:00 2001 From: shoubhik Date: Wed, 16 Oct 2019 09:44:59 -0700 Subject: [PATCH 05/62] Adding support for dequantizing from int32 to float32. (#4130) --- src/relay/qnn/op/dequantize.cc | 5 +++-- tests/python/relay/test_op_qnn_dequantize.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index ff37e2dd09e3..784572fcab69 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -43,8 +43,9 @@ bool DequantizeRel(const Array& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); const auto input_dtype = data->dtype; - CHECK(input_dtype == Int(8) || input_dtype == UInt(8)) - << "Input type should be one of the quantized types [unit8, int8] but was " << input_dtype; + CHECK(input_dtype == Int(8) || input_dtype == UInt(8) || input_dtype == Int(32)) + << "Input type should be one of the quantized types [unit8, int8, int32] but was " + << input_dtype; const Array oshape = data->shape; // assign output type, output will always be float 32. 
reporter->Assign(types[1], TensorTypeNode::make(oshape, Float(32))); diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index 51258651ab36..a99e78d3a1db 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -44,10 +44,10 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): def test_uint8_to_float32(): data = np.array([0, 1, 2, 3, 4, 251, 252, 253, 254, 255]) \ .astype('uint8') \ - .reshape((2,5)) + .reshape((2, 5)) output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \ .astype('float32') \ - .reshape((2,5)) + .reshape((2, 5)) quant_args = {"in_zero_point":127, "in_scale":0.5} quantize_test_driver(in_dtype='uint8', quant_args=quant_args, in_data=data, verify_output_data=output) @@ -55,16 +55,24 @@ def test_uint8_to_float32(): def test_int8_to_float32(): data = np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127]) \ .astype('int8') \ - .reshape((2,5)) + .reshape((2, 5)) output = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64]) \ .astype('float32') \ - .reshape((2,5)) - quant_args = {"in_zero_point":-1, "in_scale":0.5} + .reshape((2, 5)) + quant_args = {"in_zero_point": -1, "in_scale": 0.5} quantize_test_driver(in_dtype='int8', quant_args=quant_args, in_data=data, verify_output_data=output) + def test_int32_to_float32(): + data = np.array([113, 29, -1052]).astype('int32') + output = np.array([0.6550452, 0.16810896, -6.098297]).astype('float32') + quant_args = {"in_zero_point": 0, "in_scale": 0.0057968604} + quantize_test_driver(in_dtype='int32', quant_args=quant_args, in_data=data, + verify_output_data=output) + test_uint8_to_float32() test_int8_to_float32() + test_int32_to_float32() if __name__ == "__main__": From e3fbdc8b257eddf8fcbff289172727595598229a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 16 Oct 2019 10:44:17 -0700 Subject: [PATCH 06/62] Update PULL_REQUEST_TEMPLATE.md --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 849e4606834e..59825d69d0d4 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1 +1 @@ -Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/dmlc/tvm/blob/master/CONTRIBUTORS.md#reviewers). +Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/dmlc/tvm/blob/master/CONTRIBUTORS.md#reviewers) by @ them in the pull request thread. From 02c1e11716a5afdfc0159a1b21e45a64f7875473 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 16 Oct 2019 15:24:23 -0700 Subject: [PATCH 07/62] [RUNTIME] Refactor object python FFI to new protocol. (#4128) * [RUNTIME] Refactor object python FFI to new protocol. This is a pre-req to bring the Node system under object protocol. Most of the code reflects the current code in the Node system. - Use new instead of init so subclass can define their own constructors - Allow register via name, besides type idnex - Introduce necessary runtime C API functions - Refactored Tensor and Datatype to directly use constructor. 
* address review comments --- include/tvm/runtime/c_runtime_api.h | 26 +++- include/tvm/runtime/object.h | 1 + include/tvm/runtime/packed_func.h | 12 +- python/tvm/_ffi/_ctypes/function.py | 6 +- python/tvm/_ffi/_ctypes/object.py | 85 +++++++++++ python/tvm/_ffi/_ctypes/vmobj.py | 52 ------- python/tvm/_ffi/_cython/base.pxi | 6 +- python/tvm/_ffi/_cython/core.pyx | 3 +- python/tvm/_ffi/_cython/function.pxi | 12 +- .../_ffi/_cython/{vmobj.pxi => object.pxi} | 55 +++++-- python/tvm/_ffi/function.py | 1 - python/tvm/_ffi/object.py | 130 ++++++++++++++++ python/tvm/_ffi/runtime_ctypes.py | 2 +- python/tvm/_ffi/vmobj.py | 61 -------- python/tvm/api.py | 1 + python/tvm/relay/backend/vm.py | 5 +- python/tvm/relay/backend/vmobj.py | 141 ++++++------------ src/runtime/c_dsl_api.cc | 4 +- src/runtime/c_runtime_api.cc | 4 +- src/runtime/object.cc | 31 ++++ .../frontend/tensorflow/test_forward.py | 4 +- tests/python/relay/test_vm.py | 4 +- tests/python/relay/test_vm_object.py | 46 ++++++ 23 files changed, 440 insertions(+), 252 deletions(-) create mode 100644 python/tvm/_ffi/_ctypes/object.py delete mode 100644 python/tvm/_ffi/_ctypes/vmobj.py rename python/tvm/_ffi/_cython/{vmobj.pxi => object.pxi} (53%) create mode 100644 python/tvm/_ffi/object.py delete mode 100644 python/tvm/_ffi/vmobj.py create mode 100644 tests/python/relay/test_vm_object.py diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 54e6f98e8ee5..b058fd63a2f5 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -104,7 +104,7 @@ typedef enum { kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, - kObjectCell = 14U, + kObjectHandle = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. @@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type, TVMStreamHandle dst); /*! - * \brief Get the tag from an object. + * \brief Get the type_index from an object. * * \param obj The object handle. - * \param tag The tag of object. + * \param out_tindex the output type index. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag); +TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + +/*! + * \brief Free the object. + * + * \param obj The object handle. + * \note Internally we decrease the reference counter of the object. + * The object will be freed when every reference to the object are removed. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectFree(TVMObjectHandle obj); #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 7b0653ae5485..0693b1f47b3c 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -253,6 +253,7 @@ class Object { template friend class ObjectPtr; friend class TVMRetValue; + friend class TVMObjectCAPI; }; /*! 
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5b71bbc66142..2bfa3323e4f1 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -491,7 +491,7 @@ class TVMPODValue_ { } operator ObjectRef() const { if (type_code_ == kNull) return ObjectRef(ObjectPtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kObjectCell); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); return ObjectRef(ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { @@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ { } TVMRetValue& operator=(ObjectRef other) { this->Clear(); - type_code_ = kObjectCell; + type_code_ = kObjectHandle; // move the handle out value_.v_handle = other.data_.data_; other.data_.data_ = nullptr; @@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ { kNodeHandle, *other.template ptr >()); break; } - case kObjectCell: { + case kObjectHandle: { *this = other.operator ObjectRef(); break; } @@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ { static_cast(value_.v_handle)->DecRef(); break; } - case kObjectCell: { + case kObjectHandle: { static_cast(value_.v_handle)->DecRef(); break; } @@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) { case kFuncHandle: return "FunctionHandle"; case kModuleHandle: return "ModuleHandle"; case kNDArrayContainer: return "NDArrayContainer"; - case kObjectCell: return "ObjectCell"; + case kObjectHandle: return "ObjectCell"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; } @@ -1164,7 +1164,7 @@ class TVMArgsSetter { } void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) values_[i].v_handle = value.data_.data_; - type_codes_[i] = kObjectCell; + type_codes_[i] = kObjectHandle; } void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) if (value.type_code() == kStr) { diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 895c72d28d01..22fb6c335dcc 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -33,6 +33,7 @@ from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 from .node import NodeBase +from . import object as _object from . 
import node as _node FunctionHandle = ctypes.c_void_p @@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args): temp_args.append(arg) elif isinstance(arg, _CLASS_OBJECT): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_CELL + type_codes[i] = TypeCode.OBJECT_HANDLE else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value == TypeCode.NODE_HANDLE + assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE) handle = ret_val.v_handle return handle @@ -247,6 +248,7 @@ def _handle_return_func(x): # setup return handle for function type _node.__init_by_constructor__ = __init_handle_by_constructor__ +_object.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py new file mode 100644 index 000000000000..5ddceb166677 --- /dev/null +++ b/python/tvm/_ffi/_ctypes/object.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Runtime Object api""" +from __future__ import absolute_import + +import ctypes +from ..base import _LIB, check_call +from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func + + +ObjectHandle = ctypes.c_void_p +__init_by_constructor__ = None + +"""Maps object type to its constructor""" +OBJECT_TYPE = {} + +def _register_object(index, cls): + """register object class""" + OBJECT_TYPE[index] = cls + + +def _return_object(x): + handle = x.v_handle + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) + tindex = ctypes.c_uint() + check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) + cls = OBJECT_TYPE.get(tindex.value, ObjectBase) + # Avoid calling __init__ of cls, instead directly call __new__ + # This allows child class to implement their own __init__ + obj = cls.__new__(cls) + obj.handle = handle + return obj + +RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object +C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( + _return_object, TypeCode.OBJECT_HANDLE) + + +class ObjectBase(object): + """Base object for all object types""" + __slots__ = ["handle"] + + def __del__(self): + if _LIB is not None: + check_call(_LIB.TVMObjectFree(self.handle)) + + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. 
+ + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + # assign handle first to avoid error raising + self.handle = None + handle = __init_by_constructor__(fconstructor, args) + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) + self.handle = handle diff --git a/python/tvm/_ffi/_ctypes/vmobj.py b/python/tvm/_ffi/_ctypes/vmobj.py deleted file mode 100644 index 59930e55c382..000000000000 --- a/python/tvm/_ffi/_ctypes/vmobj.py +++ /dev/null @@ -1,52 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -"""Runtime Object api""" -from __future__ import absolute_import - -import ctypes -from ..base import _LIB, check_call -from .types import TypeCode, RETURN_SWITCH - -ObjectHandle = ctypes.c_void_p - -"""Maps object type to its constructor""" -OBJECT_TYPE = {} - -def _register_object(index, cls): - """register object class""" - OBJECT_TYPE[index] = cls - - -def _return_object(x): - handle = x.v_handle - if not isinstance(handle, ObjectHandle): - handle = ObjectHandle(handle) - tag = ctypes.c_int() - check_call(_LIB.TVMGetObjectTag(handle, ctypes.byref(tag))) - cls = OBJECT_TYPE.get(tag.value, ObjectBase) - obj = cls(handle) - return obj - -RETURN_SWITCH[TypeCode.OBJECT_CELL] = _return_object - - -class ObjectBase(object): - __slots__ = ["handle"] - - def __init__(self, handle): - self.handle = handle diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 63130ef67d38..76fa96376b47 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -37,7 +37,7 @@ cdef enum TVMTypeCode: kStr = 11 kBytes = 12 kNDArrayContainer = 13 - kObjectCell = 14 + kObjectHandle = 14 kExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h": int TVMArrayToDLPack(DLTensorHandle arr_from, DLManagedTensor** out) void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) - int TVMGetObjectTag(ObjectHandle obj, int* tag) + int TVMObjectFree(ObjectHandle obj) + int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) + cdef extern from "tvm/c_dsl_api.h": int TVMNodeFree(NodeHandle handle) diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index 4b8536c726aa..a9349338fc6a 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -16,7 +16,8 @@ # under the License. 
include "./base.pxi" +include "./object.pxi" include "./node.pxi" include "./function.pxi" include "./ndarray.pxi" -include "./vmobj.pxi" + diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index cf1884c32486..ceacf7407170 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args, if (tcode == kNodeHandle or tcode == kFuncHandle or tcode == kModuleHandle or - tcode == kObjectCell or + tcode == kObjectHandle or tcode > kExtBegin): CALL(TVMCbArgToReturn(&value, tcode)) @@ -155,12 +155,12 @@ cdef inline int make_arg(object arg, value[0].v_handle = (arg).chandle tcode[0] = kNodeHandle temp_args.append(arg) + elif isinstance(arg, _CLASS_OBJECT): + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle elif isinstance(arg, _CLASS_MODULE): value[0].v_handle = c_handle(arg.handle) tcode[0] = kModuleHandle - elif isinstance(arg, _CLASS_OBJECT): - value[0].v_handle = c_handle(arg.handle) - tcode[0] = kObjectCell elif isinstance(arg, FunctionBase): value[0].v_handle = (arg).chandle tcode[0] = kFuncHandle @@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode): """convert result to return value.""" if tcode == kNodeHandle: return make_ret_node(value.v_handle) + elif tcode == kObjectHandle: + return make_ret_object(value.v_handle) elif tcode == kNull: return None elif tcode == kInt: @@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode): fobj = _CLASS_FUNCTION(None, False) (fobj).chandle = value.v_handle return fobj - elif tcode == kObjectCell: - return make_ret_object(value.v_handle) elif tcode in _TVM_EXT_RET: return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) diff --git a/python/tvm/_ffi/_cython/vmobj.pxi b/python/tvm/_ffi/_cython/object.pxi similarity index 53% rename from python/tvm/_ffi/_cython/vmobj.pxi rename to python/tvm/_ffi/_cython/object.pxi index 9b487566a6a6..90be6a9c5b74 100644 --- a/python/tvm/_ffi/_cython/vmobj.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -19,7 +19,7 @@ OBJECT_TYPE = [] def _register_object(int index, object cls): - """register node class""" + """register object class""" while len(OBJECT_TYPE) <= index: OBJECT_TYPE.append(None) OBJECT_TYPE[index] = cls @@ -27,41 +27,70 @@ def _register_object(int index, object cls): cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE - cdef int tag + cdef unsigned tindex cdef list object_type cdef object cls cdef object handle object_type = OBJECT_TYPE handle = ctypes_handle(chandle) - CALL(TVMGetObjectTag(chandle, &tag)) - if tag < len(object_type): - cls = object_type[tag] + CALL(TVMObjectGetTypeIndex(chandle, &tindex)) + if tindex < len(object_type): + cls = object_type[tindex] if cls is not None: - obj = cls(handle) + obj = cls.__new__(cls) else: - obj = ObjectBase(handle) + obj = ObjectBase.__new__(ObjectBase) else: - obj = ObjectBase(handle) + obj = ObjectBase.__new__(ObjectBase) + (obj).chandle = chandle return obj cdef class ObjectBase: - cdef ObjectHandle chandle + cdef void* chandle cdef inline _set_handle(self, handle): + cdef unsigned long long ptr if handle is None: self.chandle = NULL else: - self.chandle = c_handle(handle) + ptr = handle.value + self.chandle = (ptr) property handle: def __get__(self): if self.chandle == NULL: return None else: - return ctypes.cast(self.chandle, ctypes.c_void_p) + return ctypes_handle(self.chandle) + def __set__(self, value): self._set_handle(value) - def __init__(self, handle): - 
self._set_handle(handle) + def __dealloc__(self): + CALL(TVMObjectFree(self.chandle)) + + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + # avoid error raised during construction. + self.chandle = NULL + cdef void* chandle + ConstructorCall( + (fconstructor).chandle, + kObjectHandle, args, &chandle) + self.chandle = chandle diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 4bb31820548f..60e7aeb9aec5 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -22,7 +22,6 @@ import sys import ctypes from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE -from . import vmobj as _vmobj IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py new file mode 100644 index 000000000000..be8b086a50f9 --- /dev/null +++ b/python/tvm/_ffi/object.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Runtime Object API""" +from __future__ import absolute_import + +import sys +import ctypes +from .base import _FFI_MODE, check_call, _LIB, c_str + +IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError + +try: + # pylint: disable=wrong-import-position + if _FFI_MODE == "ctypes": + raise ImportError() + if sys.version_info >= (3, 0): + from ._cy3.core import _set_class_object + from ._cy3.core import ObjectBase as _ObjectBase + from ._cy3.core import _register_object + else: + from ._cy2.core import _set_class_object + from ._cy2.core import ObjectBase as _ObjectBase + from ._cy2.core import _register_object +except IMPORT_EXCEPT: + # pylint: disable=wrong-import-position + from ._ctypes.function import _set_class_object + from ._ctypes.object import ObjectBase as _ObjectBase + from ._ctypes.object import _register_object + + +class Object(_ObjectBase): + """Base class for all tvm's runtime objects.""" + pass + + +def register_object(type_key=None): + """register object type. + + Parameters + ---------- + type_key : str or cls + The type key of the node + + Examples + -------- + The following code registers MyObject + using type key "test.MyObject" + + .. 
code-block:: python + + @tvm.register_object("test.MyObject") + class MyObject(Object): + pass + """ + object_name = type_key if isinstance(type_key, str) else type_key.__name__ + + def register(cls): + """internal register function""" + if hasattr(cls, "_type_index"): + tindex = cls._type_index + else: + tidx = ctypes.c_uint() + check_call(_LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + tindex = tidx.value + _register_object(tindex, cls) + return cls + + if isinstance(type_key, str): + return register + + return register(type_key) + + +def getitem_helper(obj, elem_getter, length, idx): + """Helper function to implement a pythonic getitem function. + + Parameters + ---------- + obj: object + The original object + + elem_getter : function + A simple function that takes index and return a single element. + + length : int + The size of the array + + idx : int or slice + The argument passed to getitem + + Returns + ------- + result : object + The result of getitem + """ + if isinstance(idx, slice): + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else length + step = idx.step if idx.step is not None else 1 + if start < 0: + start += length + if stop < 0: + stop += length + return [elem_getter(obj, i) for i in range(start, stop, step)] + + if idx < -length or idx >= length: + raise IndexError("Index out of range. size: {}, got index {}" + .format(length, idx)) + if idx < 0: + idx += length + return elem_getter(obj, idx) + + +_set_class_object(Object) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 0d28abd46cb2..00e19459df76 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -42,7 +42,7 @@ class TypeCode(object): STR = 11 BYTES = 12 NDARRAY_CONTAINER = 13 - OBJECT_CELL = 14 + OBJECT_HANDLE = 14 EXT_BEGIN = 15 diff --git a/python/tvm/_ffi/vmobj.py b/python/tvm/_ffi/vmobj.py deleted file mode 100644 index ea3431aa973c..000000000000 --- a/python/tvm/_ffi/vmobj.py +++ /dev/null @@ -1,61 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name -"""Runtime Object api""" -from __future__ import absolute_import - -import sys -from .base import _FFI_MODE - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError - -try: - # pylint: disable=wrong-import-position - if _FFI_MODE == "ctypes": - raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object - from ._cy3.core import ObjectBase as _ObjectBase - from ._cy3.core import _register_object - else: - from ._cy2.core import _set_class_object - from ._cy2.core import ObjectBase as _ObjectBase - from ._cy2.core import _register_object -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_object - from ._ctypes.vmobj import ObjectBase as _ObjectBase - from ._ctypes.vmobj import _register_object - - -class ObjectTag(object): - """Type code used in API calls""" - TENSOR = 1 - CLOSURE = 2 - DATATYPE = 3 - - -class Object(_ObjectBase): - """The VM Object used in Relay virtual machine.""" - - -def register_object(cls): - _register_object(cls.tag, cls) - return cls - - -_set_class_object(Object) diff --git a/python/tvm/api.py b/python/tvm/api.py index e7523bd733f9..f0261be37e41 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -21,6 +21,7 @@ from numbers import Integral as _Integral from ._ffi.base import string_types +from ._ffi.object import register_object, Object from ._ffi.node import register_node, NodeBase from ._ffi.node import convert_to_node as _convert_to_node from ._ffi.node_generic import _scalar_type_inference diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index e54629dd1344..c24b16ca6437 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -30,9 +30,12 @@ from . import vmobj as _obj from .interpreter import Executor +Tensor = _obj.Tensor +Datatype = _obj.Datatype + def _convert(arg, cargs): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): - cargs.append(_obj.tensor_object(arg)) + cargs.append(_obj.Tensor(arg)) elif isinstance(arg, (tuple, list)): field_args = [] for field in arg: diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py index 4c92e9bf38a6..939b122bf510 100644 --- a/python/tvm/relay/backend/vmobj.py +++ b/python/tvm/relay/backend/vmobj.py @@ -18,32 +18,37 @@ from __future__ import absolute_import as _abs import numpy as _np -from tvm._ffi.vmobj import Object, ObjectTag, register_object +from tvm._ffi.object import Object, register_object, getitem_helper from tvm import ndarray as _nd from . import _vmobj -# TODO(@icemelon9): Add ClosureObject -@register_object -class TensorObject(Object): - """Tensor object.""" - tag = ObjectTag.TENSOR +@register_object("vm.Tensor") +class Tensor(Object): + """Tensor object. - def __init__(self, handle): - """Constructs a Tensor object - - Parameters - ---------- - handle : object - Object handle + Parameters + ---------- + arr : numpy.ndarray or tvm.nd.NDArray + The source array. - Returns - ------- - obj : TensorObject - A tensor object. 
- """ - super(TensorObject, self).__init__(handle) - self.data = _vmobj.GetTensorData(self) + ctx : TVMContext, optional + The device context to create the array + """ + def __init__(self, arr, ctx=None): + if isinstance(arr, _np.ndarray): + ctx = ctx if ctx else _nd.cpu(0) + self.__init_handle_by_constructor__( + _vmobj.Tensor, _nd.array(arr, ctx=ctx)) + elif isinstance(arr, _nd.NDArray): + self.__init_handle_by_constructor__( + _vmobj.Tensor, arr) + else: + raise RuntimeError("Unsupported type for tensor object.") + + @property + def data(self): + return _vmobj.GetTensorData(self) def asnumpy(self): """Convert data to numpy array @@ -56,65 +61,34 @@ def asnumpy(self): return self.data.asnumpy() -@register_object -class DatatypeObject(Object): - """Datatype object.""" - tag = ObjectTag.DATATYPE +@register_object("vm.Datatype") +class Datatype(Object): + """Datatype object. - def __init__(self, handle): - """Constructs a Datatype object + Parameters + ---------- + tag : int + The tag of datatype. - Parameters - ---------- - handle : object - Object handle + fields : list[Object] or tuple[Object] + The source tuple. + """ + def __init__(self, tag, fields): + for f in fields: + assert isinstance(f, Object) + self.__init_handle_by_constructor__( + _vmobj.Datatype, tag, *fields) - Returns - ------- - obj : DatatypeObject - A Datatype object. - """ - super(DatatypeObject, self).__init__(handle) - self.tag = _vmobj.GetDatatypeTag(self) - num_fields = _vmobj.GetDatatypeNumberOfFields(self) - self.fields = [] - for i in range(num_fields): - self.fields.append(_vmobj.GetDatatypeFields(self, i)) + @property + def tag(self): + return _vmobj.GetDatatypeTag(self) def __getitem__(self, idx): - return self.fields[idx] + return getitem_helper( + self, _vmobj.GetDatatypeFields, len(self), idx) def __len__(self): - return len(self.fields) - - def __iter__(self): - return iter(self.fields) - -# TODO(icemelon9): Add closure object - -def tensor_object(arr, ctx=_nd.cpu(0)): - """Create a tensor object from source arr. - - Parameters - ---------- - arr : numpy.ndarray or tvm.nd.NDArray - The source array. - - ctx : TVMContext, optional - The device context to create the array - - Returns - ------- - ret : TensorObject - The created object. - """ - if isinstance(arr, _np.ndarray): - tensor = _vmobj.Tensor(_nd.array(arr, ctx)) - elif isinstance(arr, _nd.NDArray): - tensor = _vmobj.Tensor(arr) - else: - raise RuntimeError("Unsupported type for tensor object.") - return tensor + return _vmobj.GetDatatypeNumberOfFields(self) def tuple_object(fields): @@ -127,30 +101,9 @@ def tuple_object(fields): Returns ------- - ret : DatatypeObject + ret : Datatype The created object. """ for f in fields: assert isinstance(f, Object) return _vmobj.Tuple(*fields) - - -def datatype_object(tag, fields): - """Create a datatype object from tag and source fields. - - Parameters - ---------- - tag : int - The tag of datatype. - - fields : list[Object] or tuple[Object] - The source tuple. - - Returns - ------- - ret : DatatypeObject - The created object. - """ - for f in fields: - assert isinstance(f, Object) - return _vmobj.Datatype(tag, *fields) diff --git a/src/runtime/c_dsl_api.cc b/src/runtime/c_dsl_api.cc index e45c89a0e9b3..bf9092637420 100644 --- a/src/runtime/c_dsl_api.cc +++ b/src/runtime/c_dsl_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 20793b4618b3..74f0f3e82f27 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 5248da00245a..a52a9b3b4457 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -26,6 +26,7 @@ #include #include #include +#include "runtime_base.h" namespace tvm { namespace runtime { @@ -184,5 +185,35 @@ std::string Object::TypeIndex2Key(uint32_t tindex) { uint32_t Object::TypeKey2Index(const char* key) { return TypeContext::Global()->TypeKey2Index(key); } + +class TVMObjectCAPI { + public: + static void Free(TVMObjectHandle obj) { + static_cast(obj)->DecRef(); + } + + static uint32_t TypeKey2Index(const char* type_key) { + return Object::TypeKey2Index(type_key); + } +}; } // namespace runtime } // namespace tvm + +int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { + API_BEGIN(); + out_tindex[0] = static_cast(obj)->type_index(); + API_END(); +} + +int TVMObjectFree(TVMObjectHandle obj) { + API_BEGIN(); + tvm::runtime::TVMObjectCAPI::Free(obj); + API_END(); +} + +int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { + API_BEGIN(); + out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index( + type_key); + API_END(); +} diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 6432bbde98c6..c2cbbff24173 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -47,9 +47,9 @@ def convert_to_list(x): return x def vmobj_to_list(o): - if isinstance(o, tvm.relay.backend.vmobj.TensorObject): + if isinstance(o, tvm.relay.backend.vmobj.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): + elif isinstance(o, tvm.relay.backend.vmobj.Datatype): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 5289fe9f5411..cedbc4f71859 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -59,9 +59,9 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): return ret def vmobj_to_list(o): - if isinstance(o, tvm.relay.backend.vmobj.TensorObject): + if isinstance(o, tvm.relay.backend.vm.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject): + elif isinstance(o, tvm.relay.backend.vm.Datatype): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py new file mode 100644 index 000000000000..ad21fff8e185 --- /dev/null +++ b/tests/python/relay/test_vm_object.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +from tvm.relay import vm + +def test_tensor(): + arr = tvm.nd.array([1,2,3]) + x = vm.Tensor(arr) + assert isinstance(x, vm.Tensor) + assert x.asnumpy()[0] == 1 + assert x.asnumpy()[-1] == 3 + assert isinstance(x.data, tvm.nd.NDArray) + + +def test_datatype(): + arr = tvm.nd.array([1,2,3]) + x = vm.Tensor(arr) + y = vm.Datatype(0, [x, x]) + + assert len(y) == 2 + assert isinstance(y, vm.Datatype) + y[0:1][-1].data == x.data + assert y.tag == 0 + assert isinstance(x.data, tvm.nd.NDArray) + + + +if __name__ == "__main__": + test_tensor() + test_datatype() From 3185e4ad63b0d7a5ee75ca08f1ee18a44d4f9818 Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Thu, 17 Oct 2019 09:27:57 -0700 Subject: [PATCH 08/62] [Relay] Improve build error when no lowered funcs are produced (#4132) * Improve build error when no lowered funcs * Switch from fatal to warning --- src/relay/backend/build_module.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 4cf13a3d21a2..dfe85fc10908 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -460,7 +460,9 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.params = graph_codegen_->GetParams(); auto lowered_funcs = graph_codegen_->GetLoweredFunc(); - if (lowered_funcs.size() != 0) { + if (lowered_funcs.size() == 0) { + LOG(WARNING) << "no lowered funcs exist in the compiled module"; + } else { ret_.mod = tvm::build( lowered_funcs, target_host_, From 972f019c73cb0118e6d5b9ecf1f8209e5660e20d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 17 Oct 2019 09:31:58 -0700 Subject: [PATCH 09/62] [TOPI][x86] Cascade lake support. (#4123) * [TOPI][x86] Cascade lake support. * Jenkins test debug 1. * Testing cascade lake alone. --- python/tvm/relay/qnn/op/legalizations.py | 2 +- python/tvm/target.py | 10 ++ tests/python/contrib/test_gemm_acc16.py | 4 +- tests/python/contrib/test_gemm_acc32_vnni.py | 6 +- tests/python/relay/test_op_level2.py | 110 +++++++++++-------- topi/python/topi/x86/conv2d_avx_1x1.py | 6 +- topi/python/topi/x86/conv2d_avx_common.py | 4 +- topi/python/topi/x86/conv2d_int8.py | 12 +- topi/python/topi/x86/tensor_intrin.py | 30 +++-- topi/python/topi/x86/util.py | 8 +- 10 files changed, 112 insertions(+), 80 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 0fdc0f3a3231..6b2e073822f1 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -100,7 +100,7 @@ def _is_int8_hw_support(target): Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake and above. 
""" - supported_arches = {'-mcpu=skylake-avx512',} + supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'} return supported_arches.intersection(set(target.options)) # Collect the dtypes. diff --git a/python/tvm/target.py b/python/tvm/target.py index 4548ffac4c88..42045c0fb733 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -128,6 +128,16 @@ def model(self): return opt.value[7:] return 'unknown' + @property + def mcpu(self): + """Returns the mcpu from the target if it exists.""" + mcpu = '' + if self.options is not None: + for opt in self.options: + if 'mcpu' in opt: + mcpu = opt.split('=')[1] + return mcpu + def __enter__(self): _api_internal._EnterTargetScope(self) return self diff --git a/tests/python/contrib/test_gemm_acc16.py b/tests/python/contrib/test_gemm_acc16.py index 555187838723..17f920efeb8a 100644 --- a/tests/python/contrib/test_gemm_acc16.py +++ b/tests/python/contrib/test_gemm_acc16.py @@ -17,7 +17,7 @@ # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition import tvm import numpy as np -from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int16 +from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int16 def benchmark_fc_int8_acc16(): @@ -40,7 +40,7 @@ def verify(target="llvm -mcpu=skylake-avx512"): ctx = tvm.context(target, 0) X = tvm.placeholder((m, k), name='X', dtype="uint8") W = tvm.placeholder((n, k), name='W', dtype="int8") - pc = dot_16x1x16_int8_int8_int16() + pc = dot_16x1x16_uint8_int8_int16() ak = tvm.reduce_axis((0, k), name='k') packedW = tvm.placeholder((n//128, 128*(k//2), 2), name='packedW', dtype="int8") diff --git a/tests/python/contrib/test_gemm_acc32_vnni.py b/tests/python/contrib/test_gemm_acc32_vnni.py index 34518f4ed9d6..4f535918ba15 100644 --- a/tests/python/contrib/test_gemm_acc32_vnni.py +++ b/tests/python/contrib/test_gemm_acc32_vnni.py @@ -18,8 +18,8 @@ import tvm import numpy as np -from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32_vnni -from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32 +from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32_cascadelake +from topi.x86.tensor_intrin import dot_16x1x16_uint8_int8_int32 import pytest @@ -46,7 +46,7 @@ def verify(target="llvm -mcpu=cascadelake"): return ctx = tvm.context(target, 0) - pc = dot_16x1x16_int8_int8_int32_vnni() + pc = dot_16x1x16_uint8_int8_int32_cascadelake() ak = tvm.reduce_axis((0, k), name='k') packedW = tvm.placeholder( (n // 16, 16 * (k // 4), 4), name='packedW', dtype="int8") diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 015582468289..e097980b060c 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -576,57 +576,71 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): assembly = lib.get_source("asm") return assembly - # compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions - target = "llvm -mcpu=skylake-avx512" - name = "llvm.x86.avx512.pmaddubs.w.512" - llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(name) - if llvm_id != 0: - fast_int8_dtypes = ('uint8', 'int8', 'int32') - # Sweep the input channels to check int8 robustness - for ic in range(1, 24): - asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW", kernel_layout='OIHW', - dtypes=fast_int8_dtypes) - assert "pmaddubs" in asm - - for ic in range(1, 24): - asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', - dtypes=fast_int8_dtypes) - assert 
"pmaddubs" in asm - - - # Sweep the output channels to check int8 robustness - for oc in range(2, 24): - asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW", kernel_layout='OIHW', + def _has_fast_int8_instructions(asm, target): + if 'skylake-avx512' in target: + return "pmaddubs" in asm + elif 'cascadelake' in target: + return "vpdpbusd" in asm + else: + assert False, "Target should be Skylake or Cascadelake" + + # compile conv2d for x86 (skylake, cascadelake) and test assembly contains *pmadd* instructions + targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] + llvm_version = tvm.codegen.llvm_version_major() + for target in targets: + if llvm_version >= 8: + fast_int8_dtypes = ('uint8', 'int8', 'int32') + # Sweep the input channels to check int8 robustness + # Input channels should be a multiple of 4 internally. + for ic in [1, 4, 6]: + asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW", + kernel_layout='OIHW', + dtypes=fast_int8_dtypes) + assert _has_fast_int8_instructions(asm, target) + + for ic in [1, 4, 6]: + asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC", + kernel_layout='HWIO', + dtypes=fast_int8_dtypes) + assert _has_fast_int8_instructions(asm, target) + + + # Sweep the output channels to check int8 robustness + # Output channels should be a multiple of 16 internally. + for oc in [4, 16, 20]: + asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW", + kernel_layout='OIHW', + dtypes=fast_int8_dtypes) + assert _has_fast_int8_instructions(asm, target) + + for oc in [4, 16, 20]: + asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC", + kernel_layout='HWIO', + dtypes=fast_int8_dtypes) + assert _has_fast_int8_instructions(asm, target) + + # Check that both non-divisible oc and ic work + asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW', dtypes=fast_int8_dtypes) - assert "pmaddubs" in asm + assert _has_fast_int8_instructions(asm, target) - for oc in range(2, 24): - asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC", kernel_layout='HWIO', + asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO', dtypes=fast_int8_dtypes) - assert "pmaddubs" in asm - - # Check that both non-divisible oc and ic work - asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW', - dtypes=fast_int8_dtypes) - assert "pmaddubs" in asm - - asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO', - dtypes=fast_int8_dtypes) - assert "pmaddubs" in asm - - # Ensure that code is generated when datatypes are not HW supported. - dtypes = ('int8', 'int8', 'int32') - asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', - dtypes=dtypes) - # Check that intrinisic is not present in the assembly. - assert "pmaddubs" not in asm - - # Ensure that code is generated when datatypes are not HW supported. - dtypes = ('uint8', 'uint8', 'int32') - asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', - dtypes=dtypes) - # Check that intrinisic is not present in the assembly. - assert "pmaddubs" not in asm + assert _has_fast_int8_instructions(asm, target) + + # Ensure that code is generated when datatypes are not HW supported. + dtypes = ('int8', 'int8', 'int32') + asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', + dtypes=dtypes) + # Check that intrinisic is not present in the assembly. 
+ assert not _has_fast_int8_instructions(asm, target) + + # Ensure that code is generated when datatypes are not HW supported. + dtypes = ('uint8', 'uint8', 'int32') + asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', + dtypes=dtypes) + # Check that intrinisic is not present in the assembly. + assert not _has_fast_int8_instructions(asm, target) # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 96b6e47789f7..2a81dcc495d3 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -24,7 +24,7 @@ from ..nn.util import infer_pad, get_pad_tuple from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple, simplify -from .tensor_intrin import dot_16x1x16_int8_int8_int32 +from .tensor_intrin import dot_16x1x16_uint8_int8_int32 from .util import get_fp32_len def _fallback_schedule(cfg, wkl): @@ -183,7 +183,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data, conv_out, last, int32_lanes=16, - intrin=dot_16x1x16_int8_int8_int32()) + intrin=dot_16x1x16_uint8_int8_int32()) def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype): @@ -282,7 +282,7 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor) s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner) - pc = dot_16x1x16_int8_int8_int32() + pc = dot_16x1x16_uint8_int8_int32() s[C].tensorize(oc_inner, pc) if C != O: diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index 53b79bdbeec9..7c5096dc2c1a 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -23,7 +23,7 @@ from ..nn.util import infer_pad from ..generic import conv2d as conv2d_generic from ..util import get_const_tuple -from .tensor_intrin import dot_16x1x16_int8_int8_int32 +from .tensor_intrin import dot_16x1x16_uint8_int8_int32 from .util import get_fp32_len def _fallback_schedule(cfg, wkl): @@ -209,4 +209,4 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): return conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(s, cfg, data, conv_out, last, int32_lanes=16, - intrin=dot_16x1x16_int8_int8_int32()) + intrin=dot_16x1x16_uint8_int8_int32()) diff --git a/topi/python/topi/x86/conv2d_int8.py b/topi/python/topi/x86/conv2d_int8.py index f701108071e5..df53850ec603 100644 --- a/topi/python/topi/x86/conv2d_int8.py +++ b/topi/python/topi/x86/conv2d_int8.py @@ -57,16 +57,14 @@ def _is_int8_hw_support(data_dtype, kernel_dtype): is_dtype_support = data_dtype == 'uint8' and kernel_dtype == 'int8' # 2) Check LLVM support - llvm_intrin_fast_int8 = "llvm.x86.avx512.pmaddubs.w.512" - llvm_id = tvm.codegen.llvm_lookup_intrinsic_id(llvm_intrin_fast_int8) - is_llvm_support = llvm_id != 0 + llvm_version = tvm.codegen.llvm_version_major() + is_llvm_support = llvm_version >= 8 # 3) Check target - target = tvm.target.current_target() + mcpu = tvm.target.current_target().mcpu is_target_support = False - for opt in target.options: - if opt == '-mcpu=skylake-avx512': - is_target_support = True + if mcpu == 
'skylake-avx512' or mcpu == 'cascadelake': + is_target_support = True return is_dtype_support and is_llvm_support and is_target_support diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py index cba00c023f89..a8ad251115d7 100644 --- a/topi/python/topi/x86/tensor_intrin.py +++ b/topi/python/topi/x86/tensor_intrin.py @@ -19,15 +19,27 @@ import tvm -def dot_16x1x16_int8_int8_int32(): +def dot_16x1x16_uint8_int8_int32(): + """Dispatch the most optimized intrin depending on the target""" + mcpu = tvm.target.current_target().mcpu + + assert mcpu in ("skylake-avx512", "cascadelake"), \ + "An old Intel machine that does not have fast Int8 support." + if mcpu == "skylake-avx512": + return dot_16x1x16_uint8_int8_int32_skylake() + # cascadelake + return dot_16x1x16_uint8_int8_int32_cascadelake() + + +def dot_16x1x16_uint8_int8_int32_skylake(): """ Int8 dot product by every 4 elements using AVX512 Skylake instructions. - This function takes two arrays of int8 datatype -- data[4] and + This function takes two arrays of uint8 and int8 datatype -- data[4] and kernel[16][4] -- and computes a dot product of data[4] with every 4 elements of kernels, resulting in output[16] of int32 datatype. The pseudo code is as follows. .. code-block:: c - void dot_16x1x16_int8_int8_int32(int8 data[4], int8 kernel[16][4], + void dot_16x1x16_uint8_int8_int32(uint8 data[4], int8 kernel[16][4], int32 output[16]){ for (int i = 0; i < 16; i++){ output[i] = 0; @@ -100,15 +112,15 @@ def _instr(index): return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) -def dot_16x1x16_int8_int8_int16(): +def dot_16x1x16_uint8_int8_int16(): """ Int8 dot product by every 2 elements using AVX512 Skylake instructions. - This function takes two arrays of int8 datatype -- data[2] and + This function takes two arrays of uint8 and int8 datatype -- data[2] and kernel[4][32][2] -- and computes a dot product of data[2] with every 2 elements of kernels, resulting in output[4][32] of int16 datatype. The pseudo code is as follows. .. code-block:: c - void dot_16x1x16_int8_int8_int16(int8 data[2], int8 kernel[32*4][2], + void dot_16x1x16_uint8_int8_int16(uint8 data[2], int8 kernel[32*4][2], int16 output[32*4]){ for (int i = 0; i< 4; i++){ for (int j = 0; j < 32; j++){ @@ -182,15 +194,15 @@ def _instr(index): return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) -def dot_16x1x16_int8_int8_int32_vnni(): +def dot_16x1x16_uint8_int8_int32_cascadelake(): """ Int8 dot product by every 4 elements using AVX512VNNI Cascade Lake instructions. - This function takes two arrays of int8 datatype -- data[4] and + This function takes two arrays of uint8 and int8 datatype -- data[4] and kernel[16][4] -- and computes a dot product of data[4] with every 4 elements of kernels, resulting in output[16] of int32 datatype. The pseudo code is as follows. .. 
code-block:: c - void dot_16x1x16_int8_int8_int32_vnni(int8 data[4], int8 kernel[16][4], + void dot_16x1x16_uint8_int8_int32_cascadelake(uint8 data[4], int8 kernel[16][4], int32 output[16]){ for (int i = 0; i < 16; i++){ output[i] = 0; diff --git a/topi/python/topi/x86/util.py b/topi/python/topi/x86/util.py index f0b3c755e1e2..00f297e4307f 100644 --- a/topi/python/topi/x86/util.py +++ b/topi/python/topi/x86/util.py @@ -19,10 +19,8 @@ import tvm def get_fp32_len(): + mcpu = tvm.target.current_target().mcpu fp32_vec_len = 8 - target = tvm.target.current_target() - if target is not None: - for opt in target.options: - if opt == '-mcpu=skylake-avx512': - fp32_vec_len = 16 + if mcpu == 'skylake-avx512' or mcpu == 'cascadelake': + fp32_vec_len = 16 return fp32_vec_len From a8a983176dfe41506458b628e7666dd0a6347807 Mon Sep 17 00:00:00 2001 From: Marcus Shawcroft Date: Thu, 17 Oct 2019 18:04:46 +0100 Subject: [PATCH 10/62] [DOCKER] Pin torchvision==0.4.1 (#4140) The existing sequence of pip install commands fetches and installs torch==1.0.1.post2 then fetches an unpinned version of torchvision, recent torchvision packages hardwire the specific torch version they depend on, the overall effect is that we install a pinned torch version then replace it with whatever version the torchvision package depends on. The most recent torchvision==0.4.1 package results in some test case failures. This patch pins torchvision back to 0.4.0, the most recent version that the test suite worked. Removing the explicit torch install because it is implied and pinned as dependency of torchvision. Change-Id: Ib30bf6aed79ff130ea15ef5134fefb0508790574 --- docker/install/ubuntu_install_onnx.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index f3e8d8e8f540..54210b83f4d6 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -27,5 +27,4 @@ pip3 install onnx==1.5.0 # not expose that in the wheel!!! pip3 install future -pip3 install https://download.pytorch.org/whl/cu80/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl -pip3 install torchvision +pip3 install torch==1.2.0 torchvision==0.4.0 From cf046972eb5602c2d1b67edea230f6ca07c966b1 Mon Sep 17 00:00:00 2001 From: lhutton1 <35535092+lhutton1@users.noreply.github.com> Date: Thu, 17 Oct 2019 18:05:10 +0100 Subject: [PATCH 11/62] [PATCH] Fix undefined __floatdihf in libtvmruntime.so on aarch64. (#4119) Arm architecture provides optional FP16 floating point support in two alternative formats, IEEE and an an alternative Arm format. The ACLE (Arm C Language Extension) defined preprocessor symbol __ARM_FP16_FORMAT_IEEE can be used to distinguish between implementations providing IEEE and the Arm alternative format, but cannot, on its own, be used to determined if FP16 HW support is actually present. Testing this preprocessor symbol can lead to undefined __floatdihf at runtime on an aarch64 target where no FP16 HW is present. The relevant preprocessor symbol to determine whether FP16 HW support is present in the target is __ARM_FEATURE_FP16_SCALAR_ARITHMETIC, this symbol implies __ARM_FP16_FORMAT_IEEE. 
The relevant preprocessor symbols are defined by the ACLE standard, section 5.5.21 16-bit floating-point data processing operations, https://static.docs.arm.com/101028/0008/Q2-ACLE_2019Q2_release-0008.pdf --- src/contrib/sort/sort.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc index a87ce07cb602..0ccaee515acb 100644 --- a/src/contrib/sort/sort.cc +++ b/src/contrib/sort/sort.cc @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") // Currently only supports input dtype to be float32. CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " "to be float."; -#if (__ARM_FP16_FORMAT_IEEE != 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " "to be float32."; #endif @@ -100,23 +100,23 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); } if (is_ascend) { -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) if (dtype.bits == 16) { std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); } else { #endif std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif } else { -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) if (dtype.bits == 16) { std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); } else { #endif std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif } @@ -210,7 +210,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } -#if (__ARM_FP16_FORMAT_IEEE == 1) +#if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } else if (data_dtype == "float16") { if (out_dtype == "float16") { argsort<__fp16, __fp16>(input, output, axis, is_ascend); From 4052de6d1ce446d124363c3530bc2ad2fb7bfa80 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Thu, 17 Oct 2019 13:25:08 -0700 Subject: [PATCH 12/62] [relay][vm] Separate VM runtime with executable (#4100) * [relay][vm] Separate VM runtime with executable * Address comments * move ctx back to vm * make only vm related fields and methods protected * integrate seriliaztion/deserialization to executable * create stream --- include/tvm/runtime/vm.h | 210 ++++- python/tvm/relay/__init__.py | 2 - python/tvm/relay/backend/deserializer.py | 81 -- python/tvm/relay/backend/profiler_vm.py | 12 +- python/tvm/relay/backend/serializer.py | 191 ----- python/tvm/relay/backend/vm.py | 232 +++++- src/relay/backend/vm/compiler.cc | 20 +- src/relay/backend/vm/compiler.h | 12 +- src/relay/backend/vm/deserializer.cc | 324 -------- src/relay/backend/vm/deserializer.h | 102 --- src/relay/backend/vm/profiler/compiler.cc | 1 - src/relay/backend/vm/serializer.cc | 439 ----------- src/relay/backend/vm/serializer.h | 202 ----- src/runtime/vm/executable.cc | 734 ++++++++++++++++++ src/runtime/vm/profiler/vm.cc | 29 +- src/runtime/vm/profiler/vm.h | 2 + .../backend => runtime}/vm/serialize_util.h | 12 +- src/runtime/vm/vm.cc | 92 +-- tests/python/relay/test_vm.py | 30 +- tests/python/relay/test_vm_serialization.py | 119 ++- .../unittest/test_runtime_vm_profiler.py | 4 +- 21 files changed, 1285 insertions(+), 1565 deletions(-) delete mode 100644 
python/tvm/relay/backend/deserializer.py delete mode 100644 python/tvm/relay/backend/serializer.py delete mode 100644 src/relay/backend/vm/deserializer.cc delete mode 100644 src/relay/backend/vm/deserializer.h delete mode 100644 src/relay/backend/vm/serializer.cc delete mode 100644 src/relay/backend/vm/serializer.h create mode 100644 src/runtime/vm/executable.cc rename src/{relay/backend => runtime}/vm/serialize_util.h (95%) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index aa8543d569af..a276c658c496 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -430,15 +431,184 @@ struct VMFrame { caller_return_register(0) {} }; +/*! \brief The executable emitted by the VM compiler. + * + * The executable contains information (e.g. data in different memory regions) + * to run in a virtual machine. + * + * - Global section, containing all globals. + * - Constant section, storing the constant pool. + * - Primitive name section, containing the function name of the primitive ops + * used by the virtual machine. + * - Code section, handling the VM functions and bytecode. + */ +class Executable : public ModuleNode { + public: + /*! + * \brief Get a PackedFunc from an executable module. + * + * \param name the name of the function. + * \param sptr_to_self The shared_ptr that points to this module node. + * + * \return PackedFunc or nullptr when it is not available. + */ + PackedFunc GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) final; + + /*! + * \brief Serialize the executable into global section, constant section, and + * code section. + * + * \return The binary representation of the VM. + */ + TVMByteArray Save(); + + /*! + * \brief Load the saved VM executable. + * + * \param code The bytecode in string. + * \param lib The compiled runtime library. + * + * \return exe The constructed executable. + */ + static runtime::Module Load(const std::string& code, const runtime::Module lib); + + /*! + * \brief Get the serialized form of the `functions`. This is + * essentially bytecode serialization. + * + * \return The serialized vm bytecode. + * + * \note The bytecode is in the following format: + * func_name reg_file_size num_instructions + * param1 param2 ... paramM + * instruction1 + * instruction2 + * ... + * instructionN + * + * Each instruction is printed in the following format: + * opcode num_fields field1 ... fieldX # The text format. + * + * Serializing an `Instruction` requires us to deal with the bytecode. Each line + * of the instructions could be serialized as the following format: + * hash, opcode, f1, f2, ..., fX, field with variable length + * 1. hash: the hash of the instruction. This number will be used to help us + * validate if an instruction is well-formed during deserialization. + * 2. opcode: the opcode code of the instruction. + * 3. f1, f2, ..., fX. These fields together represent the fixed fields in + * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For + * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). + * 4. The rest of the line indicates the field with variable length, e.g., + * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. + + * The field starting from # is only used for debugging. The serialized code + * doesn't contain it, therefore the deserializer doens't need to handle it. + */ + std::string GetBytecode() const; + +/*! 
+ * \brief Print the detailed statistics of the given code, i.e. number of + * globls and constants, etc. + */ + std::string Stats() const; + + /*! \brief Get the `lib` module in an executable. Users have the flexibility to call + * `export_library` from the frontend to save the library to disk. + * + * \return The runtime module that contains the hardwre dependent code. + */ + runtime::Module GetLib() const { return lib; } + + virtual ~Executable() {} + + const char* type_key() const final { + return "VMExecutable"; + } + + /*! \brief The runtime module/library that contains both the host and also the device + * code when executing on non-CPU devices. */ + runtime::Module lib; + /*! \brief The global constant pool. */ + std::vector constants; + /*! \brief A map from globals (as strings) to their index in the function map. */ + std::unordered_map global_map; + /*! \brief A mapping from the packed function (as string) to the index that + * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object. + */ + std::unordered_map primitive_map; + /*! \brief The virtual machine's function table. */ + std::vector functions; + + private: + /*! + * \brief Save the globals. + * + * \param strm The input stream. + */ + void SaveGlobalSection(dmlc::Stream* strm); + + /*! + * \brief Save the constant pool. + * + * \param strm The input stream. + */ + void SaveConstantSection(dmlc::Stream* strm); + + /*! + * \brief Save primitive op names. + * + * \param strm The input stream. + */ + void SavePrimitiveOpNames(dmlc::Stream* strm); + + /*! + * \brief Save the vm functions. + * + * \param strm The input stream. + */ + void SaveCodeSection(dmlc::Stream* strm); + + /*! + * \brief Load the globals. + * + * \param strm The input stream. + */ + void LoadGlobalSection(dmlc::Stream* strm); + + /*! + * \brief Load the constant pool. + * + * \param strm The input stream. + */ + void LoadConstantSection(dmlc::Stream* strm); + + /*! + * \brief Load primitive op names. + * + * \param strm The input stream. + */ + void LoadPrimitiveOpNames(dmlc::Stream* strm); + + /*! + * \brief Load the vm functions. + * + * \param strm The input stream. + */ + void LoadCodeSection(dmlc::Stream* strm); + + /*! \brief The serialized bytecode. */ + std::string code_; +}; + /*! \brief The virtual machine. * * The virtual machine contains all the current execution state, - * as well as the global view of functions, the global constant - * table, the compiled operators. + * as well as the executable. * * The goal is to have a single self-contained object, * enabling one to easily pass around VMs, execute them on - * multiple threads, or serialized them to disk or over the + * multiple threads, or serialize them to disk or over the * wire. */ class VirtualMachine : public runtime::ModuleNode { @@ -486,16 +656,18 @@ class VirtualMachine : public runtime::ModuleNode { return "VirtualMachine"; } - /*! \brief The runtime module/library that contains generated code. */ - runtime::Module lib; + VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} + + /*! \brief load the executable for the virtual machine. + * \param exec The executable. + */ + void LoadExecutable(const Executable* exec); + + protected: /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs; - /*! \brief The virtual machine's function table. */ - std::vector functions; /*! \brief The current stack of call frames. */ std::vector frames; - /*! \brief The global constant pool. 
*/ - std::vector constants; /*! \brief The fuction table index of the current function. */ Index func_index; /*! \brief The current pointer to the code section. */ @@ -506,6 +678,9 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The special return register. */ ObjectRef return_register; + /*! \brief The executable the VM will operate on. */ + const Executable* exec; + /*! \brief The set of TVM contexts the VM is currently executing on. */ std::vector ctxs; @@ -550,8 +725,6 @@ class VirtualMachine : public runtime::ModuleNode { */ ObjectRef Invoke(const std::string& name, const std::vector& args); - VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {} - /*! \brief Initialize the virtual machine for a set of contexts. * \param contexts The set of TVM contexts. */ @@ -565,21 +738,6 @@ class VirtualMachine : public runtime::ModuleNode { */ TVMContext GetParamsContext() const; - /*! - * \brief Load parameters from the parameter bytearray. - * \param params The binary file that contains parameters. - */ - void LoadParams(const std::string& params); - - /*! \brief A map from globals (as strings) to their index in the function map. - */ - std::unordered_map global_map; - - /*! \brief A mapping from the packed function (as string) to the index that - * corresponds to the position of the `packed_funcs` list. - */ - std::unordered_map primitive_map; - private: /*! \brief Invoke a global setting up the VM state to execute. * diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index ceb98c4d251e..fff9c99e5007 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -37,8 +37,6 @@ from . import feature from .backend import vm from .backend import profiler_vm -from .backend import serializer -from .backend import deserializer from .backend import vmobj # Root operators diff --git a/python/tvm/relay/backend/deserializer.py b/python/tvm/relay/backend/deserializer.py deleted file mode 100644 index fde702b1cd04..000000000000 --- a/python/tvm/relay/backend/deserializer.py +++ /dev/null @@ -1,81 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine deserializer. - -Python interface for deserializing a Relay VM. -""" -from tvm import module -from tvm._ffi.runtime_ctypes import TVMByteArray -from . import _vm -from . import vm as rly_vm - -def _create_deserializer(code, lib): - """Create a deserializer object. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - - Returns - ------- - ret : Deserializer - The created virtual machine deserializer. 
- """ - if isinstance(code, (bytes, str)): - code = bytearray(code) - elif not isinstance(code, (bytearray, TVMByteArray)): - raise TypeError("vm is expected to be the type of bytearray or " + - "TVMByteArray, but received {}".format(type(code))) - - if not isinstance(lib, module.Module): - raise TypeError("lib is expected to be the type of tvm.module.Module" + - ", but received {}".format(type(lib))) - return _vm._Deserializer(code, lib) - - -class Deserializer: - """Relay VM deserializer. - - Parameters - ---------- - code : bytearray - The serialized virtual machine code. - - lib : :py:class:`~tvm.module.Module` - The serialized runtime module/library that contains the hardware - dependent binary code. - """ - def __init__(self, code, lib): - self.mod = _create_deserializer(code, lib) - self._deserialize = self.mod["deserialize"] - - def deserialize(self): - """Deserialize the serialized bytecode into a Relay VM. - - Returns - ------- - ret : VirtualMachine - The deserialized Relay VM. - """ - return rly_vm.VirtualMachine(self._deserialize()) diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 8ae3161e0b83..b36715249f0a 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -49,8 +49,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachineProfiler - The profile VM runtime. + exec : Executable + The executable with profiling code. """ compiler = VMCompilerProfiler() target = compiler.update_target(target) @@ -60,7 +60,7 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachineProfiler(compiler._get_vm()) + return vm.Executable(compiler._get_exec()) class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" @@ -68,13 +68,17 @@ def __init__(self): super().__init__() self.mod = _vm._VMCompilerProfiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] class VirtualMachineProfiler(vm.VirtualMachine): """Relay profile VM runtime.""" def __init__(self, mod): super().__init__(mod) + m = mod.module if isinstance(mod, vm.Executable) else mod + self.mod = _vm._VirtualMachineDebug(m) + self._init = self.mod["init"] + self._invoke = self.mod["invoke"] self._get_stat = self.mod["get_stat"] def get_stat(self): diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py deleted file mode 100644 index b45ba9116a15..000000000000 --- a/python/tvm/relay/backend/serializer.py +++ /dev/null @@ -1,191 +0,0 @@ -# License .to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name -""" -The Relay Virtual Machine serializer. - -Python interface for serializing a Relay VM. -""" -import tvm -from . import _vm -from . import vm as rly_vm - -def _create_serializer(vm): - """Create a VM serializer. - - Parameters - ---------- - vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`] - The virtual machine to be serialized. - - Returns - ------- - ret : Serializer - The created virtual machine serializer. - """ - if isinstance(vm, rly_vm.VirtualMachine): - vm = vm.module - elif not isinstance(vm, tvm.module.Module): - raise TypeError("vm is expected to be the type of VirtualMachine or " + - "tvm.Module, but received {}".format(type(vm))) - - return _vm._Serializer(vm) - - -class Serializer: - """Relay VM serializer.""" - def __init__(self, vm): - self.mod = _create_serializer(vm) - self._get_lib = self.mod["get_lib"] - self._get_bytecode = self.mod["get_bytecode"] - self._get_globals = self.mod["get_globals"] - self._get_stats = self.mod["get_stats"] - self._get_primitive_ops = self.mod["get_primitive_ops"] - self._serialize = self.mod["serialize"] - - @property - def stats(self): - """Get the statistics of the Relay VM. - - Returns - ------- - ret : String - The serialized statistic information. - """ - return self._get_stats() - - @property - def primitive_ops(self): - """Get the name of the primitive ops that are executed in the VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The list of primitive ops. - """ - return [prim_op.value for prim_op in self._get_primitive_ops()] - - @property - def bytecode(self): - """Get the bytecode of the Relay VM. - - Returns - ------- - ret : String - The serialized bytecode. - - Notes - ----- - The bytecode is in the following format: - func_name reg_file_size num_instructions - param1 param2 ... paramM - instruction1 - instruction2 - ... - instructionN - - Each instruction is printed in the following format: - hash opcode field1 ... fieldX # The text format. - - The part starting from # is only used for visualization and debugging. - The real serialized code doesn't contain it, therefore the deserializer - doesn't need to deal with it as well. - """ - return self._get_bytecode() - - @property - def globals(self): - """Get the globals used by the Relay VM. - - Returns - ------- - ret : List[:py:class:`~tvm.expr.StringImm`] - The serialized globals. - """ - return [glb.value for glb in self._get_globals()] - - def serialize(self): - """Serialize the Relay VM. - - Returns - ------- - code : bytearray - The binary blob representing a serialized Relay VM. It can then be - saved to disk and later deserialized into a new VM. - - lib : :py:class:`~tvm.module.Module` - The runtime module that contains the generated code. It is - basically a library that is composed of hardware dependent code. - - Notes - ----- - The returned code is organized with the following sections in order. - - Global section. This section contains the globals used by the - virtual machine. - - Constant section. This section is used to store the constant pool of - a virtual machine. - - Primitive name section. This section is introduced to accommodate - the list of primitive operator names that will be invoked by the - virtual machine. - - Code section. The VM functions, including bytecode, are sitting in - this section. - - Examples - -------- - .. 
code-block:: python - - import numpy as np - import tvm - from tvm import relay - - # define a simple network. - x = relay.var('x', shape=(10, 10)) - f = relay.Function([x], x + x) - mod = relay.Module({"main": f}) - - # create a Relay VM. - ctx = tvm.cpu() - target = "llvm" - compiler = relay.vm.VMCompiler() - vm = compiler.compile(mod, target) - vm.init(ctx) - - # serialize. - ser = relay.serializer.Serializer(vm) - code, lib = ser.serialize() - - # save and load the code and lib file. - tmp = tvm.contrib.util.tempdir() - path_lib = tmp.relpath("lib.so") - lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: - fo.write(code) - - loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) - - # deserialize. - deser = relay.deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() - - # execute the deserialized vm. - des_vm.init(ctx) - x_data = np.random.rand(10, 10).astype('float32') - res = des_vm.run(x_data) - print(res.asnumpy()) - """ - return self._serialize(), self._get_lib() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index c24b16ca6437..942c93b866f4 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -24,8 +24,8 @@ import tvm from tvm import autotvm -from tvm._ffi.runtime_ctypes import TVMByteArray from tvm.relay import expr as _expr +from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . import vmobj as _obj from .interpreter import Executor @@ -44,6 +44,7 @@ def _convert(arg, cargs): else: raise "unsupported type" + def convert(args): cargs = [] for arg in args: @@ -52,12 +53,202 @@ def convert(args): return cargs +class Executable(object): + """Relay VM executable""" + def __init__(self, mod): + self.mod = mod + self._save = self.mod["save"] + self._get_lib = self.mod["get_lib"] + self._get_bytecode = self.mod["get_bytecode"] + self._get_stats = self.mod["get_stats"] + + def save(self): + """Save the Relay VM Executable. + + Returns + ------- + code : bytearray + The binary blob representing a serialized Relay VM executable. It + can then be saved to disk and later deserialized into a new + Executable. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. It is + basically a library that is composed of hardware dependent code. + + Notes + ----- + The returned code is organized with the following sections in order. + - Global section. This section contains the globals used by the + virtual machine. + - Constant section. This section is used to store the constant pool of + a virtual machine. + - Primitive name section. This section is introduced to accommodate + the list of primitive operator names that will be invoked by the + virtual machine. + - Code section. The VM functions, including bytecode, are sitting in + this section. + + Examples + -------- + + .. code-block:: python + + import numpy as np + import tvm + from tvm import relay + # define a simple network. + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x + x) + mod = relay.Module({"main": f}) + # create a Relay VM. + ctx = tvm.cpu() + target = "llvm" + executable = relay.vm.compile(mod, target) + code, lib = executable.save() + # save and load the code and lib file. 
+ tmp = tvm.contrib.util.tempdir() + path_lib = tmp.relpath("lib.so") + lib.export_library(path_lib) + with open(tmp.relpath("code.ro"), "wb") as fo: + fo.write(code) + loaded_lib = tvm.module.load(path_lib) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) + # deserialize. + des_exec = relay.vm.Executable.load_exec(loaded_code, loaded_code) + # execute the deserialized executable. + x_data = np.random.rand(10, 10).astype('float32') + des_vm = relay.vm.VirtualMachine(des_exec) + des_vm.init(ctx) + res = des_vm.run(x_data) + print(res.asnumpy()) + """ + return self._save(), self._get_lib() + + @staticmethod + def load_exec(bytecode, lib): + """Construct an executable from saved artifacts. + + Parameters + ---------- + bytecode : bytearray + The binary blob representing a the Relay VM bytecode. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. + + Returns + ------- + exec: Executable + An executable constructed using the provided artifacts. + """ + if isinstance(bytecode, (bytes, str)): + code = bytearray(bytecode) + elif not isinstance(bytecode, (bytearray, TVMByteArray)): + raise TypeError("bytecode is expected to be the type of bytearray " + + "or TVMByteArray, but received {}".format(type(code))) + + if not isinstance(lib, tvm.module.Module): + raise TypeError("lib is expected to be the type of tvm.module.Module" + + ", but received {}".format(type(lib))) + + return Executable(_vm.Load_Executable(bytecode, lib)) + + @property + def lib(self): + """Get the library that contains hardware dependent code. + + Returns + ------- + ret : :py:class:`~tvm.Module` + The runtime module that contains hardware dependent code. + """ + return self._get_lib() + + @property + def stats(self): + """Get the statistics of the Relay VM executable. + + Returns + ------- + ret : String + The statistic information of the VM executable. + """ + return self._get_stats() + + @property + def primitive_ops(self): + """Get the name of the primitive ops contained in the executable. + + Returns + ------- + ret : List[String] + The list of primitive ops. + """ + ret = [] + num_primitives = _vm.GetNumOfPrimitives(self.module) + for i in range(num_primitives): + ret.append(_vm.GetPrimitiveFields(self.module, i)) + return ret + + @property + def bytecode(self): + """Get the bytecode of the Relay VM executable. + + Returns + ------- + ret : String + The bytecode of the executable. + + Notes + ----- + The bytecode is in the following format: + func_name reg_file_size num_instructions + param1 param2 ... paramM + instruction1 + instruction2 + ... + instructionN + + Each instruction is printed in the following format: + hash opcode field1 ... fieldX # The text format. + + The part starting from # is only used for visualization and debugging. + The real serialized code doesn't contain it, therefore the deserializer + doesn't need to deal with it as well. + """ + return self._get_bytecode() + + @property + def globals(self): + """Get the globals used by the Relay VM executable. + + Returns + ------- + ret : List[String] + The globals contained in the executable. 
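Together these properties give a quick way to inspect what the compiler produced. A short usage sketch, assuming the `relay.vm.compile` entry point introduced in this patch:

.. code-block:: python

    from tvm import relay

    x = relay.var('x', shape=(10, 10))
    f = relay.Function([x], x + x)
    mod = relay.Module({"main": f})

    exe = relay.vm.compile(mod, target="llvm")
    # Inspect the compiled executable.
    print(exe.stats)          # constant shapes, globals, primitive op counts
    print(exe.globals)        # e.g. ['main']
    print(exe.primitive_ops)  # names of the lowered packed functions
    print(exe.bytecode)       # per-function register file size and instructions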
+ """ + ret = [] + num_globals = _vm.GetNumOfGlobals(self.module) + for i in range(num_globals): + ret.append(_vm.GetGlobalFields(self.module, i)) + return ret + + @property + def module(self): + """Return the runtime module contained in a virtual machine executable.""" + return self.mod + + class VirtualMachine(object): """Relay VM runtime.""" def __init__(self, mod): - self.mod = mod + if not isinstance(mod, (Executable, tvm.module.Module)): + raise TypeError("mod is expected to be the type of Executable or " + + "tvm.Module, but received {}".format(type(mod))) + m = mod.module if isinstance(mod, Executable) else mod + self.mod = _vm._VirtualMachine(m) self._init = self.mod["init"] - self._load_params = self.mod["load_params"] self._invoke = self.mod["invoke"] def init(self, ctx): @@ -71,23 +262,6 @@ def init(self, ctx): args = [ctx.device_type, ctx.device_id] self._init(*args) - def load_params(self, params): - """Load parameters for the VM. - - Parameters - ---------- - params : Union[bytearray, Dict] - The dictionary that contains serialized parameters. - """ - if isinstance(params, dict): - params = tvm.relay.save_param_dict(params) - elif isinstance(params, (bytes, str)): - params = bytearray(params) - if not isinstance(params, (bytearray, TVMByteArray)): - raise TypeError("params must be a bytearray") - - self._load_params(bytearray(params)) - def invoke(self, func_name, *args): """Invoke a function. @@ -122,11 +296,6 @@ def run(self, *args): """ return self.invoke("main", *args) - @property - def module(self): - """Return the runtime module contained in a virtual machine.""" - return self.mod - def compile(mod, target=None, target_host=None, params=None): """ @@ -155,8 +324,8 @@ def compile(mod, target=None, target_host=None, params=None): Returns ------- - vm : VirtualMachine - The VM runtime. + exec : Executable + The VM executable that contains both library code and bytecode. """ compiler = VMCompiler() @@ -167,14 +336,14 @@ def compile(mod, target=None, target_host=None, params=None): tophub_context = compiler.tophub_context(target) with tophub_context: compiler._compile(mod, target, target_host) - return VirtualMachine(compiler._get_vm()) + return Executable(compiler._get_exec()) class VMCompiler(object): """Build Relay module to run on VM runtime.""" def __init__(self): self.mod = _vm._VMCompiler() self._compile = self.mod["compile"] - self._get_vm = self.mod["get_vm"] + self._get_exec = self.mod["get_executable"] self._set_params_func = self.mod["set_params"] def set_params(self, params): @@ -240,7 +409,7 @@ class VMExecutor(Executor): mod : :py:class:`~tvm.relay.module.Module` The module to support the execution. - ctx : :py:class:`TVMContext` + ctx : :py:class:`~tvm.TVMContext` The runtime context to run the code on. 
target : :py:class:`Target` @@ -252,7 +421,8 @@ def __init__(self, mod, ctx, target): self.mod = mod self.ctx = ctx self.target = target - self.vm = compile(mod, target) + self.executable = compile(mod, target) + self.vm = VirtualMachine(self.executable) self.vm.init(ctx) def _make_executor(self, expr=None): diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 0cfae374ab2c..f295ccd7a555 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -783,9 +783,9 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, Module mod = args[0]; this->Compile(mod, args[1], args[2]); }); - } else if (name == "get_vm") { + } else if (name == "get_executable") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(vm_); + *rv = runtime::Module(exec_); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -864,7 +864,7 @@ void VMCompiler::Compile(Module mod, // Next we get ready by allocating space for // the global state. - vm_->functions.resize(context_.module->functions.size()); + exec_->functions.resize(context_.module->functions.size()); for (auto named_func : context_.module->functions) { auto gvar = named_func.first; @@ -873,25 +873,25 @@ void VMCompiler::Compile(Module mod, auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); - CHECK(func_index < vm_->functions.size()); - vm_->functions[func_index] = vm_func; + CHECK(func_index < exec_->functions.size()); + exec_->functions[func_index] = vm_func; } #if USE_RELAY_DEBUG - for (auto vm_func : vm_->functions) { + for (auto vm_func : exec_->functions) { DLOG(INFO) << vm_func << "-------------"; } #endif // USE_RELAY_DEBUG // populate constants for (auto data : context_.constants) { - vm_->constants.push_back(runtime::vm::Tensor(data)); + exec_->constants.push_back(runtime::vm::Tensor(data)); } LibraryCodegen(); for (auto gv : context_.global_map) { - vm_->global_map.insert({gv.first->name_hint, gv.second}); + exec_->global_map.insert({gv.first->name_hint, gv.second}); } } @@ -987,13 +987,13 @@ void VMCompiler::LibraryCodegen() { // therefore target won't be used in the build function runtime::Module mod = (*f)(funcs, Target(), target_host_); CHECK(mod.operator->()); - vm_->lib = mod; + exec_->lib = mod; } else { LOG(FATAL) << "relay.backend.build is not registered"; } size_t primitive_index = 0; for (auto cfunc : cached_funcs) { - vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); + exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); } } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index dff1ef7f4569..215cc12c4cdb 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -92,12 +92,8 @@ class VMCompiler : public runtime::ModuleNode { return "VMCompiler"; } - std::shared_ptr GetVirtualMachine() const { - return vm_; - } - - virtual void InitVM() { - vm_ = std::make_shared(); + void InitVM() { + exec_ = std::make_shared(); } /*! @@ -144,8 +140,8 @@ class VMCompiler : public runtime::ModuleNode { tvm::Target target_host_; /*! \brief Global shared meta data */ VMCompilerContext context_; - /*! \brief Compiled virtual machine. */ - std::shared_ptr vm_; + /*! \brief Compiled executable. */ + std::shared_ptr exec_; /*! 
\brief parameters */ std::unordered_map params_; }; diff --git a/src/relay/backend/vm/deserializer.cc b/src/relay/backend/vm/deserializer.cc deleted file mode 100644 index 777282782e99..000000000000 --- a/src/relay/backend/vm/deserializer.cc +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.cc - * \brief Implementation of APIs to deserialize the serialized VM bytecode. - */ - -#include "deserializer.h" - -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace relay { -namespace vm { - -#define STREAM_CHECK(val, section) \ - CHECK(val) << "Invalid VM file format in the " << section << " section." \ - << "\n"; - -void Deserializer::Init(const std::string& code, const runtime::Module& lib) { - code_ = code; - vm_ = std::make_shared(); - vm_->lib = lib; - strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Deserializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "deserialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Deserialize(); - *rv = runtime::Module(vm_); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -void Deserializer::Deserialize() { - // Check header. - uint64_t header; - STREAM_CHECK(strm_->Read(&header), "header"); - STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); - - // Check version. - std::string version; - STREAM_CHECK(strm_->Read(&version), "version"); - STREAM_CHECK(version == TVM_VERSION, "version"); - - // Global section. - DeserializeGlobalSection(); - - // Constant section. - DeserializeConstantSection(); - - // Primitive names that will be invoked by `InvokePacked` instructions. - DeserializePrimitiveOpNames(); - - // Code section. - DeserializeCodeSection(); -} - -void Deserializer::DeserializeGlobalSection() { - std::vector globals; - STREAM_CHECK(strm_->Read(&globals), "global"); - for (size_t i = 0; i < globals.size(); i++) { - vm_->global_map.insert({globals[i], i}); - } -} - -void Deserializer::DeserializeConstantSection() { - uint64_t sz; - // Load the number of constants. - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); - - size_t size = static_cast(sz); - // Load each of the constants. 
- for (size_t i = 0; i < size; i++) { - runtime::NDArray constant; - STREAM_CHECK(constant.Load(strm_), "constant"); - runtime::ObjectRef obj = runtime::vm::Tensor(constant); - vm_->constants.push_back(obj); - } -} - -void Deserializer::DeserializePrimitiveOpNames() { - std::vector primitive_names; - STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); - for (size_t i = 0; i < primitive_names.size(); i++) { - vm_->primitive_map.insert({primitive_names[i], i}); - } -} - -// Extract the `cnt` number of fields started at `start` from the list -// `instr_fields`. -inline std::vector ExtractFields(const std::vector& instr_fields, - Index start, - Index cnt) { - CHECK_LE(static_cast(start + cnt), instr_fields.size()); - std::vector ret; - for (auto i = start; i < start + cnt; i++) { - ret.push_back(instr_fields[i]); - } - return ret; -} - -Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { - Opcode opcode = static_cast(instr.opcode); - switch (opcode) { - case Opcode::Move: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::Move(instr.fields[0], instr.fields[1]); - } - case Opcode::Ret: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Ret(instr.fields[0]); - } - case Opcode::Fatal: { - // Number of fields = 0 - DCHECK(instr.fields.empty()); - return Instruction::Fatal(); - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index packed_index = instr.fields[0]; - Index arity = instr.fields[1]; - Index output_size = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, arity); - return Instruction::InvokePacked(packed_index, arity, output_size, args); - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - DCHECK_GE(instr.fields.size(), 5U); - DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); - - DLDataType dtype; - dtype.code = instr.fields[0]; - dtype.bits = instr.fields[1]; - dtype.lanes = instr.fields[2]; - - Index ndim = instr.fields[3]; - RegName dst = instr.fields[4]; - - std::vector shape = ExtractFields(instr.fields, 5, ndim); - - return Instruction::AllocTensor(shape, dtype, dst); - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - DCHECK_EQ(instr.fields.size(), 5U); - Index shape_register = instr.fields[0]; - - DLDataType dtype; - dtype.code = instr.fields[1]; - dtype.bits = instr.fields[2]; - dtype.lanes = instr.fields[3]; - - RegName dst = instr.fields[4]; - - return Instruction::AllocTensorReg(shape_register, dtype, dst); - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index constructor_tag = instr.fields[0]; - Index num_fields = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector fields = ExtractFields(instr.fields, 3, num_fields); - - return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index clo_index = instr.fields[0]; - Index num_freevar = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); - - return 
Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); - } - case Opcode::If: { - // Number of fields = 4 - DCHECK_EQ(instr.fields.size(), 4U); - Index test = instr.fields[0]; - Index target = instr.fields[1]; - Index true_offset = instr.fields[2]; - Index false_offset = instr.fields[3]; - - return Instruction::If(test, target, true_offset, false_offset); - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index func_index = instr.fields[0]; - Index num_args = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_args); - - return Instruction::Invoke(func_index, args, dst); - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - DCHECK_GE(instr.fields.size(), 3U); - DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); - - Index closure = instr.fields[0]; - Index num_closure_args = instr.fields[1]; - RegName dst = instr.fields[2]; - std::vector args = ExtractFields(instr.fields, 3, num_closure_args); - - return Instruction::InvokeClosure(closure, args, dst); - } - case Opcode::LoadConst: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConst(instr.fields[0], instr.fields[1]); - } - case Opcode::LoadConsti: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); - } - case Opcode::GetField: { - // Number of fields = 3 - DCHECK_EQ(instr.fields.size(), 3U); - return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); - } - case Opcode::GetTag: { - // Number of fields = 2 - DCHECK_EQ(instr.fields.size(), 2U); - return Instruction::GetTag(instr.fields[0], instr.fields[1]); - } - case Opcode::Goto: { - // Number of fields = 1 - DCHECK_EQ(instr.fields.size(), 1U); - return Instruction::Goto(instr.fields[0]); - } - default: - LOG(FATAL) << "Invalid opcode" << instr.opcode; - return Instruction(); - } -} - -void Deserializer::DeserializeCodeSection() { - // Load the number of functions. - uint64_t sz; - STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); - - size_t num_funcs = static_cast(sz); - vm_->functions.resize(num_funcs); - for (size_t i = 0; i < num_funcs; i++) { - // Load the function info. - VMFunctionSerializer loaded_func; - STREAM_CHECK(loaded_func.Load(strm_), "code/function"); - - // Load the instructions. - std::vector instructions; - for (size_t j = 0; j < loaded_func.num_instructions; j++) { - VMInstructionSerializer instr; - std::vector instr_fields; - STREAM_CHECK(instr.Load(strm_), "code/instruction"); - instructions.push_back(DeserializeInstruction(instr)); - } - - // Create the VM function. 
- VMFunction vm_func = VMFunction(loaded_func.name, - loaded_func.params, - instructions, - loaded_func.register_file_size); - auto it = vm_->global_map.find(loaded_func.name); - CHECK(it != vm_->global_map.end()); - CHECK_LE(it->second, vm_->global_map.size()); - vm_->functions[it->second] = vm_func; - } -} - -runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) { - std::shared_ptr exec = std::make_shared(); - exec->Init(code, lib); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Deserializer") -.set_body_typed(CreateDeserializer); - -} // namespace vm -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/vm/deserializer.h b/src/relay/backend/vm/deserializer.h deleted file mode 100644 index 0caf72bee92c..000000000000 --- a/src/relay/backend/vm/deserializer.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/deserializer.h - * \brief Define a deserializer for the serialized Relay VM. - */ - -#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace vm { - -using namespace tvm::runtime::vm; -namespace runtime = tvm::runtime; - -class Deserializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the deserializer for creating a virtual machine object. - * - * \param code The serialized code. - * \param lib The serialized runtime module/library that contains the - * hardware dependent code. - */ - inline void Init(const std::string& code, const runtime::Module& lib); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Deserializer"; } - - /*! \brief Deserialize the serialized VM. */ - void Deserialize(); - - virtual ~Deserializer() { delete strm_; } - - private: - /*! \brief Deserialize the globals in `vm_`. */ - void DeserializeGlobalSection(); - - /*! \brief Deserialize the constant pool in `vm_`. */ - void DeserializeConstantSection(); - - /*! \brief Deserialize primitive op names in `vm_`. */ - void DeserializePrimitiveOpNames(); - - /*! \brief Deserialize the vm functions in `vm_`. */ - void DeserializeCodeSection(); - - /*! \brief The code to be serialized. */ - std::string code_; - - /*! \brief The stream used for serialization. 
*/ - dmlc::Stream* strm_; - - /*! \brief The VM to be created. */ - std::shared_ptr vm_; -}; - -} // namespace vm -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ diff --git a/src/relay/backend/vm/profiler/compiler.cc b/src/relay/backend/vm/profiler/compiler.cc index 9fd28e8c7f46..60c441a60cf0 100644 --- a/src/relay/backend/vm/profiler/compiler.cc +++ b/src/relay/backend/vm/profiler/compiler.cc @@ -33,7 +33,6 @@ namespace vm { class VMCompilerDebug : public VMCompiler { public: VMCompilerDebug() {} - void InitVM() override { vm_ = std::make_shared(); } virtual ~VMCompilerDebug() {} }; diff --git a/src/relay/backend/vm/serializer.cc b/src/relay/backend/vm/serializer.cc deleted file mode 100644 index 0040ef9db470..000000000000 --- a/src/relay/backend/vm/serializer.cc +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.cc - * \brief Implementation of serializing APIs for the Relay VM. - */ -#include "serializer.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "serialize_util.h" - -namespace tvm { -namespace relay { -namespace vm { - -void Serializer::Init(const VirtualMachine* vm) { - vm_ = vm; - // Initialize the stream object. 
- strm_ = new dmlc::MemoryStringStream(&code_); -} - -runtime::PackedFunc Serializer::GetFunction( - const std::string& name, - const std::shared_ptr& sptr_to_self) { - if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); - } else if (name == "get_primitive_ops") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetPrimitiveOps(); - }); - } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); - } else if (name == "get_globals") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGlobals(); - }); - } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); - }); - } else if (name == "serialize") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Serialize(); - }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -tvm::Array Serializer::GetPrimitiveOps() const { - std::vector ret; - for (const auto& it : vm_->primitive_map) { - auto packed_name = tvm::ir::StringImm::make(it.first); - auto packed_index = static_cast(it.second); - if (ret.size() <= packed_index) { - ret.resize(packed_index + 1); - } - ret[packed_index] = packed_name; - } - return ret; -} - -std::string Serializer::Stats() const { - std::ostringstream oss; - oss << "Relay VM statistics:" << std::endl; - - // Get the number of constants and the shape of each of them. - oss << " Constant shapes (# " << vm_->constants.size() << "): ["; - for (const auto& it : vm_->constants) { - auto* cell = it.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - const auto& shape = data.Shape(); - - // Scalar - if (shape.empty()) { - oss << "scalar, "; - continue; - } - - oss << "["; - for (auto s : shape) { - oss << s << ", "; - } - oss.seekp(-2, oss.cur); - oss << "], " << std::endl; - } - if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of globals and the name of each of them. - oss << " Globals (#" << vm_->global_map.size() << "): ["; - for (const auto& it : vm_->global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; - } - if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - // Get the number of primitive ops and the name of each of them. - oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; - const auto& prim_ops = GetPrimitiveOps(); - for (const auto& it : prim_ops) { - oss << it << ", "; - } - if (!prim_ops.empty()) oss.seekp(-2, oss.cur); - oss << "]" << std::endl; - - return oss.str(); -} - -TVMByteArray Serializer::Serialize() { - uint64_t header = kTVMVMBytecodeMagic; - strm_->Write(header); - std::string version = TVM_VERSION; - strm_->Write(version); - - // Global section. - SerializeGlobalSection(); - - // Constant section. - SerializeConstantSection(); - - // Primitive names. - SerializePrimitiveOpNames(); - - // Code section. 
- SerializeCodeSection(); - - TVMByteArray arr; - arr.data = code_.c_str(); - arr.size = code_.length(); - return arr; -} - -void Serializer::SerializeGlobalSection() { - auto globals = GetGlobals(); - std::vector glbs; - for (const auto& it : globals) { - glbs.push_back(it.as()->value); - } - strm_->Write(glbs); -} - -void Serializer::SerializeConstantSection() { - std::vector arrays; - for (const auto& obj : vm_->constants) { - const auto* cell = obj.as(); - CHECK(cell != nullptr); - runtime::NDArray data = cell->data; - arrays.push_back(const_cast(data.operator->())); - } - strm_->Write(static_cast(vm_->constants.size())); - for (const auto& it : arrays) { - runtime::SaveDLTensor(strm_, it); - } -} - -void Serializer::SerializePrimitiveOpNames() { - auto names = GetPrimitiveOps(); - std::vector primitive_names; - for (const auto& it : names) { - primitive_names.push_back(it.as()->value); - } - strm_->Write(primitive_names); -} - -// Serialize a virtual machine instruction. It creates a list that contains the -// hash, opcode, and all fields of an instruction. -// -// For example, the function signature used to create an `AllocTensor` -// instruction is: -// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) -// -// The serialized form will be: -// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` -// -// where hash is the hash of serialized instruction that is computed internally -// by the `VMInstructionSerializer`. It is used for sanity check before decoding. -// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` -// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` -// is the destination register, and the rest of it together indicates the shape -// of the tensor to be allocated. -VMInstructionSerializer SerializeInstruction(const Instruction& instr) { - std::vector fields; - // Save the opcode. - DLOG(INFO) << "Serializing: " << instr << std::endl; - switch (instr.op) { - case Opcode::Move: { - // Number of fields = 2 - fields.assign({instr.from, instr.dst}); - break; - } - case Opcode::Ret: { - // Number of fields = 1 - fields.push_back(instr.result); - break; - } - case Opcode::Fatal: { - // Number of fields = 0 - break; - } - case Opcode::InvokePacked: { - // Number of fields = 3 + instr.arity - // Note that arity includes both input arguments and outputs. We will - // put all the `arity` number of fields in the end for serialization. - fields.assign({instr.packed_index, instr.arity, instr.output_size}); - // Save the args. - fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); - break; - } - case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - - // The number of dimensions is not needed for constructing an - // `AllocTensor` instruction as it equals to the length of the `shape` - // vector. However, we save it to conveniently deserialize the instruction - // because we will know how many fields are needed by the `shape` argument. - fields.push_back(instr.alloc_tensor.ndim); - fields.push_back(instr.dst); - - // Save the shape of the tensor. - // Note that this field is rotated to the end of the list. 
- fields.insert(fields.end(), instr.alloc_tensor.shape, - instr.alloc_tensor.shape + instr.alloc_tensor.ndim); - break; - } - case Opcode::AllocTensorReg: { - // Number of fields = 5 - fields.push_back(instr.alloc_tensor_reg.shape_register); - // Save `DLDataType` and the dst register. - const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); - fields.push_back(instr.dst); - break; - } - case Opcode::AllocDatatype: { - // Number of fields = 3 + instr.num_fields - fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); - - // Save the fields. - fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); - break; - } - case Opcode::AllocClosure: { - // Number of fields = 3 + instr.num_freevar - fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); - - // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); - break; - } - case Opcode::If: { - // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, - instr.if_op.false_offset}); - break; - } - case Opcode::Invoke: { - // Number of fields = 3 + instr.num_args - fields.assign({instr.func_index, instr.num_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.invoke_args_registers, - instr.invoke_args_registers + instr.num_args); - break; - } - case Opcode::InvokeClosure: { - // Number of fields = 3 + instr.num_closure_args - fields.assign({instr.closure, instr.num_closure_args, instr.dst}); - - // Save the args. - fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); - break; - } - case Opcode::LoadConst: { - // Number of fields = 2 - fields.assign({instr.const_index, instr.dst}); - break; - } - case Opcode::LoadConsti: { - // Number of fields = 2 - fields.assign({instr.load_consti.val, instr.dst}); - break; - } - case Opcode::GetField: { - // Number of fields = 3 - fields.assign({instr.object, instr.field_index, instr.dst}); - break; - } - case Opcode::GetTag: { - // Number of fields = 2 - fields.assign({instr.get_tag.object, instr.dst}); - break; - } - case Opcode::Goto: { - // Number of fields = 1 - fields.push_back(instr.pc_offset); - break; - } - default: - LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); - break; - } - - return VMInstructionSerializer(static_cast(instr.op), fields); -} - -void Serializer::SerializeCodeSection() { - // Save the number of functions. - strm_->Write(static_cast(vm_->functions.size())); - for (const auto& func : vm_->functions) { - // Serialize the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), - func.params); - func_format.Save(strm_); - - // Serialize each instruction. 
- for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - serialized_instr.Save(strm_); - } - } -} - -tvm::Array Serializer::GetGlobals() const { - tvm::Array ret; - std::vector > globals(vm_->global_map.begin(), - vm_->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { - return a.second < b.second; - }; - std::sort(globals.begin(), globals.end(), comp); - for (const auto& it : globals) { - ret.push_back(tvm::ir::StringImm::make(it.first)); - } - return ret; -} - -std::string Serializer::GetBytecode() const { - std::ostringstream oss; - - for (const auto& func : vm_->functions) { - // Print the header of the function format. - oss << "# func name, reg file size, param count, inst count:" - << std::endl; - oss << func.name << " " - << func.register_file_size << " " - << func.params.size() << " " - << func.instructions.size() << std::endl; - - // Print pramams of a `VMFunction`. - oss << "# Parameters:"<< std::endl; - for (const auto& param : func.params) { - oss << param << " "; - } - oss << std::endl; - - // Print the instructions of a `VMFunction`. - // The part after ";" is the instruction in text format. - oss << "hash, opcode, fields # inst(text):"<< std::endl; - for (const auto& instr : func.instructions) { - const auto& serialized_instr = SerializeInstruction(instr); - oss << std::hex << "0x" << serialized_instr.Hash() << " " - << std::dec << serialized_instr.opcode << " "; - for (auto it : serialized_instr.fields) { - oss << it << " "; - } - oss << " # " << instr; - if (oss.str().back() != '\n') oss << std::endl; - } - } - - return oss.str(); -} - -runtime::Module Serializer::GetLib() const { - return vm_->lib; -} - -runtime::Module CreateSerializer(const VirtualMachine* vm) { - std::shared_ptr exec = std::make_shared(); - exec->Init(vm); - return runtime::Module(exec); -} - -TVM_REGISTER_GLOBAL("relay._vm._Serializer") -.set_body([](TVMArgs args, TVMRetValue* rv) { - runtime::Module mod = args[0]; - const auto* vm = dynamic_cast(mod.operator->()); - CHECK(vm) << "Virtual machine has not been defined yet." - << "\n"; - *rv = CreateSerializer(vm); -}); - -} // namespace vm -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/vm/serializer.h b/src/relay/backend/vm/serializer.h deleted file mode 100644 index 2371bb4c94f5..000000000000 --- a/src/relay/backend/vm/serializer.h +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serializer.h - * \brief Define a serializer for the Relay VM. 
- * - * The following components of a Relay VM will be serialized: - * - The `constants`, e.g., the constant pool, that contains the - * constants used in a Relay program. - * - The `packed_funcs` that essentially contains the generated code for - * a specific target. We return it as a runtime module that can be exported as - * a library file (e.g., .so, .o, or .tar). - * - The `global_map` that contains the globals. - * - The `primitive_map` that contains the name of individual primitive operators. - * - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of - * a list of instructions/bytecode. - * - * Note that only the library is returned as a separate module. All othere parts - * are stored in a single serialized code that is organized with the following - * sections in order. - * - Global section, containing all globals. - * - Constant section, storing the constant pool. - * - Primitive name section, containing the function name of the primitive ops - * used by the virtual machine. - * - Code section, handling the VM functions and bytecode. - * - * The code section is again organized as follows for each VM function: - * func_name, register_file_size, num_instructions (N) - * param1, param2, ..., paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Serializing an `Instruction` requires us to deal with the bytecode. Each line - * of the instructions could be serialized as the following format: - * hash, opcode, f1, f2, ..., fX, field with variable length - * 1. hash: the hash of the instruction. This number will be used to help us - * validate if an instruction is well-formed during deserialization. - * 2. opcode: the opcode code of the instruction. - * 3. f1, f2, ..., fX. These fields together represent the fixed fields in - * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For - * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). - * 4. The rest of the line indicates the field with variable length, e.g., - * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. - */ - -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace relay { -namespace vm { - -using namespace tvm::runtime; -using namespace tvm::runtime::vm; - -/*! - * \brief The Relay VM serializer. - */ -class Serializer : public runtime::ModuleNode { - public: - /*! - * \brief Initialize the serializer for a virtual machine. - * - * \param vm The Relay virtual machine. - */ - inline void Init(const VirtualMachine* vm); - - /*! - * \brief Return the member function to the frontend. - * - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * - * \return The corresponding member function. - */ - PackedFunc GetFunction(const std::string& name, - const std::shared_ptr& sptr_to_self) final; - - const char* type_key() const final { return "Serializer"; } - - /*! - * \brief Print the detailed statistics of the given code, i.e. number of - * globls and constants, etc. - */ - std::string Stats() const; - - /*! - * \brief Serialize the `vm_` into global section, constant section, and code - * section. - * - * \return The binary representation of the VM. - */ - TVMByteArray Serialize(); - - /*! - * \brief Get a list of the globals used by the `_vm`. - * - * \return The global map in the form a list. 
- */ - tvm::Array GetGlobals() const; - - /*! - * \brief Get the primitive operators that are contained in the Relay VM. - * - * \return The list of primitve operators. - */ - tvm::Array GetPrimitiveOps() const; - - /*! - * \brief Get the serialized form of the `functions` in `vm_`. This is - * essentially bytecode serialization. - * - * \return The serialized vm bytecode. - * - * \note The bytecode is in the following format: - * func_name reg_file_size num_instructions - * param1 param2 ... paramM - * instruction1 - * instruction2 - * ... - * instructionN - * - * Each instruction is printed in the following format: - * opcode num_fields field1 ... fieldX # The text format. - * - * The field starting from # is only used for debugging. The serialized code - * doesn't contain it, therefore the deserializer doens't need to handle it. - */ - std::string GetBytecode() const; - - /*! \brief Get the `lib` module in vm_. Serialization of `runtime::module` - * has already been supported by TVM. Therefore, we only return the runtime - * module and let users have the flexibility to call `export_library` from - * the frontend to save the library to disk. - * - * \return The runtime module that contains the hardwre dependent code. - */ - inline runtime::Module GetLib() const; - - virtual ~Serializer() { delete strm_; } - - private: - /*! \brief Serialize the globals in vm_. */ - void SerializeGlobalSection(); - - /*! \brief Serialize the constant pool in vm_. */ - void SerializeConstantSection(); - - /*! \brief Serialize primitive op names in vm_. */ - void SerializePrimitiveOpNames(); - - /*! \brief Serialize the vm functions in vm_. */ - void SerializeCodeSection(); - - /*! \brief The Relay virtual machine for to be serialized. */ - const VirtualMachine* vm_; - - /*! \brief The stream used for serialization. */ - dmlc::Stream* strm_; - - /*! \brief The serialized code. */ - std::string code_; -}; - -} // namespace vm -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_ diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc new file mode 100644 index 000000000000..21f71af4eb8c --- /dev/null +++ b/src/runtime/vm/executable.cc @@ -0,0 +1,734 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/runtime/vm/executable.cc + * \brief The implementation of a virtual machine executable APIs. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "serialize_util.h" + +namespace tvm { +namespace runtime { +namespace vm { + +#define STREAM_CHECK(val, section) \ + CHECK(val) << "Invalid VM file format in the " << section << " section." 
\ + << "\n"; + +// Helper to serialize a vm instruction. +VMInstructionSerializer SerializeInstruction(const Instruction& instr); +// Helper to deserialize a serialized vm instruction. +Instruction DeserializeInstruction(const VMInstructionSerializer& instr); + +PackedFunc Executable::GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) { + if (name == "get_lib") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetLib(); + }); + } else if (name == "get_bytecode") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetBytecode(); + }); + } else if (name == "get_stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Stats(); + }); + } else if (name == "save") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Save(); + }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(nullptr); + } +} + +std::string Executable::GetBytecode() const { + std::ostringstream oss; + + for (const auto& func : functions) { + // Print the header of the function format. + oss << "# func name, reg file size, param count, inst count:" + << std::endl; + oss << func.name << " " + << func.register_file_size << " " + << func.params.size() << " " + << func.instructions.size() << std::endl; + + // Print pramams of a `VMFunction`. + oss << "# Parameters: "<< std::endl; + for (const auto& param : func.params) { + oss << param << " "; + } + oss << std::endl; + + // Print the instructions of a `VMFunction`. + // The part after ";" is the instruction in text format. + oss << "hash, opcode, fields # inst(text):"<< std::endl; + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + oss << std::hex << "0x" << serialized_instr.Hash() << " " + << std::dec << serialized_instr.opcode << " "; + for (auto it : serialized_instr.fields) { + oss << it << " "; + } + oss << " # " << instr; + if (oss.str().back() != '\n') oss << std::endl; + } + } + + return oss.str(); +} + +std::string Executable::Stats() const { + std::ostringstream oss; + oss << "Relay VM executable statistics:" << std::endl; + + // Get the number of constants and the shape of each of them. + oss << " Constant shapes (# " << constants.size() << "): ["; + for (const auto& it : constants) { + const auto* cell = it.as(); + CHECK(cell); + runtime::NDArray data = cell->data; + const auto& shape = data.Shape(); + + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], " << std::endl; + } + if (!constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << global_map.size() << "): ["; + for (const auto& it : global_map) { + oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + } + if (!global_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of primitive ops and the name of each of them. 
+ oss << " Primitive ops (#" << primitive_map.size() << "): ["; + std::vector prim_ops; + for (const auto& it : primitive_map) { + auto packed_index = static_cast(it.second); + if (prim_ops.size() <= packed_index) { + prim_ops.resize(packed_index + 1); + } + prim_ops[packed_index] = it.first; + } + for (const auto& it : prim_ops) { + oss << it << ", "; + } + if (!prim_ops.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +void SaveHeader(dmlc::Stream* strm) { + uint64_t header = kTVMVMBytecodeMagic; + strm->Write(header); + std::string version = TVM_VERSION; + strm->Write(version); +} + +TVMByteArray Executable::Save() { + // Initialize the stream object. + code_.clear(); + dmlc::MemoryStringStream strm(&code_); + + // Save header + SaveHeader(&strm); + + // Global section. + SaveGlobalSection(&strm); + + // Constant section. + SaveConstantSection(&strm); + + // Primitive names. + SavePrimitiveOpNames(&strm); + + // Code section. + SaveCodeSection(&strm); + + TVMByteArray arr; + arr.data = code_.c_str(); + arr.size = code_.length(); + return arr; +} + +void Executable::SaveGlobalSection(dmlc::Stream* strm) { + std::vector > globals(this->global_map.begin(), + this->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + + std::vector glbs; + for (const auto& it : globals) { + glbs.push_back(it.first); + } + strm->Write(glbs); +} + +void Executable::SaveConstantSection(dmlc::Stream* strm) { + std::vector arrays; + for (const auto& obj : this->constants) { + const auto* cell = obj.as(); + CHECK(cell != nullptr); + runtime::NDArray data = cell->data; + arrays.push_back(const_cast(data.operator->())); + } + strm->Write(static_cast(this->constants.size())); + for (const auto& it : arrays) { + runtime::SaveDLTensor(strm, it); + } +} + +void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { + std::vector primitive_names; + for (const auto& it : this->primitive_map) { + auto packed_index = static_cast(it.second); + if (primitive_names.size() <= packed_index) { + primitive_names.resize(packed_index + 1); + } + primitive_names[packed_index] = it.first; + } + strm->Write(primitive_names); +} + +// Serialize a virtual machine instruction. It creates a list that contains the +// hash, opcode, and all fields of an instruction. +// +// For example, the function signature used to create an `AllocTensor` +// instruction is: +// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) +// +// The serialized form will be: +// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` +// +// where hash is the hash of serialized instruction that is computed internally +// by the `VMInstructionExecutable`. It is used for sanity check before decoding. +// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` +// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` +// is the destination register, and the rest of it together indicates the shape +// of the tensor to be allocated. +VMInstructionSerializer SerializeInstruction(const Instruction& instr) { + std::vector fields; + // Save the opcode. 
+ DLOG(INFO) << "Serializing: " << instr << std::endl; + switch (instr.op) { + case Opcode::Move: { + // Number of fields = 2 + fields.assign({instr.from, instr.dst}); + break; + } + case Opcode::Ret: { + // Number of fields = 1 + fields.push_back(instr.result); + break; + } + case Opcode::Fatal: { + // Number of fields = 0 + break; + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + // Note that arity includes both input arguments and outputs. We will + // put all the `arity` number of fields in the end for serialization. + fields.assign({instr.packed_index, instr.arity, instr.output_size}); + // Save the args. + fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); + break; + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + + // The number of dimensions is not needed for constructing an + // `AllocTensor` instruction as it equals to the length of the `shape` + // vector. However, we save it to conveniently deserialize the instruction + // because we will know how many fields are needed by the `shape` argument. + fields.push_back(instr.alloc_tensor.ndim); + fields.push_back(instr.dst); + + // Save the shape of the tensor. + // Note that this field is rotated to the end of the list. + fields.insert(fields.end(), instr.alloc_tensor.shape, + instr.alloc_tensor.shape + instr.alloc_tensor.ndim); + break; + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + fields.push_back(instr.alloc_tensor_reg.shape_register); + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(instr.dst); + break; + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); + + // Save the fields. + fields.insert(fields.end(), instr.datatype_fields, + instr.datatype_fields + instr.num_fields); + break; + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); + + // Save the free vars. + fields.insert(fields.end(), instr.free_vars, + instr.free_vars + instr.num_freevar); + break; + } + case Opcode::If: { + // Number of fields = 4 + fields.assign({instr.if_op.test, + instr.if_op.target, + instr.if_op.true_offset, + instr.if_op.false_offset}); + break; + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + fields.assign({instr.func_index, instr.num_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.invoke_args_registers, + instr.invoke_args_registers + instr.num_args); + break; + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + fields.assign({instr.closure, instr.num_closure_args, instr.dst}); + + // Save the args. 
+ fields.insert(fields.end(), instr.closure_args, + instr.closure_args + instr.num_closure_args); + break; + } + case Opcode::LoadConst: { + // Number of fields = 2 + fields.assign({instr.const_index, instr.dst}); + break; + } + case Opcode::LoadConsti: { + // Number of fields = 2 + fields.assign({instr.load_consti.val, instr.dst}); + break; + } + case Opcode::GetField: { + // Number of fields = 3 + fields.assign({instr.object, instr.field_index, instr.dst}); + break; + } + case Opcode::GetTag: { + // Number of fields = 2 + fields.assign({instr.get_tag.object, instr.dst}); + break; + } + case Opcode::Goto: { + // Number of fields = 1 + fields.push_back(instr.pc_offset); + break; + } + default: + LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); + break; + } + + return VMInstructionSerializer(static_cast(instr.op), fields); +} + +void Executable::SaveCodeSection(dmlc::Stream* strm) { + // Save the number of functions. + strm->Write(static_cast(this->functions.size())); + for (const auto& func : this->functions) { + // Save the function info. + VMFunctionSerializer func_format(func.name, + func.register_file_size, + func.instructions.size(), + func.params); + func_format.Save(strm); + + // Serialize each instruction. + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + serialized_instr.Save(strm); + } + } +} + +void LoadHeader(dmlc::Stream* strm) { + // Check header. + uint64_t header; + STREAM_CHECK(strm->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); +} + +runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { + std::shared_ptr exec = std::make_shared(); + exec->lib = lib; + exec->code_ = code; + dmlc::MemoryStringStream strm(&exec->code_); + + // Load header. + LoadHeader(&strm); + + // Global section. + exec->LoadGlobalSection(&strm); + + // Constant section. + exec->LoadConstantSection(&strm); + + // Primitive names that will be invoked by `InvokePacked` instructions. + exec->LoadPrimitiveOpNames(&strm); + + // Code section. + exec->LoadCodeSection(&strm); + + return runtime::Module(exec); +} + +void Executable::LoadGlobalSection(dmlc::Stream* strm) { + std::vector globals; + STREAM_CHECK(strm->Read(&globals), "global"); + for (size_t i = 0; i < globals.size(); i++) { + this->global_map.insert({globals[i], i}); + } +} + +void Executable::LoadConstantSection(dmlc::Stream* strm) { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + // Load each of the constants. + for (size_t i = 0; i < size; i++) { + runtime::NDArray constant; + STREAM_CHECK(constant.Load(strm), "constant"); + runtime::ObjectRef obj = runtime::vm::Tensor(constant); + this->constants.push_back(obj); + } +} + +void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { + std::vector primitive_names; + STREAM_CHECK(strm->Read(&primitive_names), "primitive name"); + for (size_t i = 0; i < primitive_names.size(); i++) { + this->primitive_map.insert({primitive_names[i], i}); + } +} + +// Extract the `cnt` number of fields started at `start` from the list +// `instr_fields`. 
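+// For example (illustrative values): an `Invoke` instruction is serialized as
+// {func_index, num_args, dst, arg1, ..., argN}, so ExtractFields(fields, 3, num_args)
+// recovers just the argument registers.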
+inline std::vector ExtractFields(const std::vector& instr_fields, + Index start, + Index cnt) { + CHECK_LE(static_cast(start + cnt), instr_fields.size()); + std::vector ret; + for (auto i = start; i < start + cnt; i++) { + ret.push_back(instr_fields[i]); + } + return ret; +} + +Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { + Opcode opcode = static_cast(instr.opcode); + switch (opcode) { + case Opcode::Move: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::Move(instr.fields[0], instr.fields[1]); + } + case Opcode::Ret: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Ret(instr.fields[0]); + } + case Opcode::Fatal: { + // Number of fields = 0 + DCHECK(instr.fields.empty()); + return Instruction::Fatal(); + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index packed_index = instr.fields[0]; + Index arity = instr.fields[1]; + Index output_size = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, arity); + return Instruction::InvokePacked(packed_index, arity, output_size, args); + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + DCHECK_GE(instr.fields.size(), 5U); + DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); + + DLDataType dtype; + dtype.code = instr.fields[0]; + dtype.bits = instr.fields[1]; + dtype.lanes = instr.fields[2]; + + Index ndim = instr.fields[3]; + RegName dst = instr.fields[4]; + + std::vector shape = ExtractFields(instr.fields, 5, ndim); + + return Instruction::AllocTensor(shape, dtype, dst); + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + DCHECK_EQ(instr.fields.size(), 5U); + Index shape_register = instr.fields[0]; + + DLDataType dtype; + dtype.code = instr.fields[1]; + dtype.bits = instr.fields[2]; + dtype.lanes = instr.fields[3]; + + RegName dst = instr.fields[4]; + + return Instruction::AllocTensorReg(shape_register, dtype, dst); + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index constructor_tag = instr.fields[0]; + Index num_fields = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector fields = ExtractFields(instr.fields, 3, num_fields); + + return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index clo_index = instr.fields[0]; + Index num_freevar = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); + + return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); + } + case Opcode::If: { + // Number of fields = 4 + DCHECK_EQ(instr.fields.size(), 4U); + Index test = instr.fields[0]; + Index target = instr.fields[1]; + Index true_offset = instr.fields[2]; + Index false_offset = instr.fields[3]; + + return Instruction::If(test, target, true_offset, false_offset); + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index func_index = instr.fields[0]; + Index num_args = 
instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_args); + + return Instruction::Invoke(func_index, args, dst); + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index closure = instr.fields[0]; + Index num_closure_args = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_closure_args); + + return Instruction::InvokeClosure(closure, args, dst); + } + case Opcode::LoadConst: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConst(instr.fields[0], instr.fields[1]); + } + case Opcode::LoadConsti: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); + } + case Opcode::GetField: { + // Number of fields = 3 + DCHECK_EQ(instr.fields.size(), 3U); + return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); + } + case Opcode::GetTag: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::GetTag(instr.fields[0], instr.fields[1]); + } + case Opcode::Goto: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Goto(instr.fields[0]); + } + default: + LOG(FATAL) << "Invalid opcode" << instr.opcode; + return Instruction(); + } +} + +void Executable::LoadCodeSection(dmlc::Stream* strm) { + // Load the number of functions. + uint64_t sz; + STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "code"); + + size_t num_funcs = static_cast(sz); + this->functions.resize(num_funcs); + for (size_t i = 0; i < num_funcs; i++) { + // Load the function info. + VMFunctionSerializer loaded_func; + STREAM_CHECK(loaded_func.Load(strm), "code/function"); + + // Load the instructions. + std::vector instructions; + for (size_t j = 0; j < loaded_func.num_instructions; j++) { + VMInstructionSerializer instr; + std::vector instr_fields; + STREAM_CHECK(instr.Load(strm), "code/instruction"); + instructions.push_back(DeserializeInstruction(instr)); + } + + // Create the VM function. 
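+    // The loaded function is stored at the index recorded in `global_map`, so
+    // that `Invoke` instructions which refer to functions by index resolve to
+    // the same slot they occupied before serialization.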
+ VMFunction vm_func = VMFunction(loaded_func.name, + loaded_func.params, + instructions, + loaded_func.register_file_size); + auto it = this->global_map.find(loaded_func.name); + CHECK(it != this->global_map.end()); + CHECK_LE(it->second, this->global_map.size()); + this->functions[it->second] = vm_func; + } +} + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = static_cast(exec->global_map.size()); +}); + +TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + std::vector > globals(exec->global_map.begin(), + exec->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + CHECK_LT(idx, globals.size()); + *rv = globals[idx].first; +}); + +TVM_REGISTER_GLOBAL("relay._vm.GetNumOfPrimitives") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + *rv = static_cast(exec->primitive_map.size()); +}); + + +TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec); + int idx = args[1]; + CHECK_GE(idx, 0); + CHECK_LT(idx, exec->primitive_map.size()); + + for (const auto& it : exec->primitive_map) { + if (idx == static_cast(it.second)) { + *rv = it.first; + break; + } + } +}); + +TVM_REGISTER_GLOBAL("relay._vm.Load_Executable") +.set_body_typed([]( + std::string code, + runtime::Module lib) { + return Executable::Load(code, lib); +}); + +} // namespace vm +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 80e0ce57a8ae..821de0bda245 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -85,19 +85,25 @@ PackedFunc VirtualMachineDebug::GetFunction( } } -void VirtualMachineDebug::Init(const std::vector& ctxs) { - VirtualMachine::Init(ctxs); - for (auto kv : primitive_map) { +void VirtualMachineDebug::LoadExecutable(const Executable* exec) { + VirtualMachine::LoadExecutable(exec); + CHECK(this->exec); + for (auto kv : this->exec->primitive_map) { packed_index_map[kv.second] = kv.first; op_invokes[kv.second] = 0; } } +void VirtualMachineDebug::Init(const std::vector& ctxs) { + VirtualMachine::Init(ctxs); +} + void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { - auto ctx = VirtualMachine::GetParamsContext(); + CHECK(this->exec); + auto ctx = this->GetParamsContext(); // warmup VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); @@ -117,6 +123,21 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, op_invokes[packed_index] += 1; } +runtime::Module CreateVirtualMachineDebug(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->LoadExecutable(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachineDebug") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "Virtual machine 
has not been defined yet." + << "\n"; + *rv = CreateVirtualMachineDebug(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 447967cafeb0..ff3296cb6c16 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -47,6 +47,8 @@ class VirtualMachineDebug : public VirtualMachine { void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) final; + void LoadExecutable(const Executable* exec); + ~VirtualMachineDebug() {} private: diff --git a/src/relay/backend/vm/serialize_util.h b/src/runtime/vm/serialize_util.h similarity index 95% rename from src/relay/backend/vm/serialize_util.h rename to src/runtime/vm/serialize_util.h index 3e7508ebee9b..3931f2f0e023 100644 --- a/src/relay/backend/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -19,11 +19,11 @@ /*! * Copyright (c) 2019 by Contributors - * \file src/relay/backend/vm/serialize_util.h + * \file src/runtime/vm/serialize_util.h * \brief Definitions of helpers for serializing and deserializing a Relay VM. */ -#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ -#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#ifndef TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ +#define TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ #include #include @@ -34,7 +34,7 @@ #include namespace tvm { -namespace relay { +namespace runtime { namespace vm { /*! \brief The magic number for the serialized VM bytecode file */ @@ -158,7 +158,7 @@ struct VMInstructionSerializer { }; } // namespace vm -} // namespace relay +} // namespace runtime } // namespace tvm -#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#endif // TVM_RUNTIME_VM_SERIALIZE_UTIL_H_ diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 7dea9bdb95ea..78b74768b930 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -575,11 +575,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(exec) << "The executable is not created yet."; std::string func_name = args[0]; - auto gvit = this->global_map.find(func_name); - CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name; + auto gvit = exec->global_map.find(func_name); + CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name; auto func_index = gvit->second; - const auto& vm_func = this->functions[func_index]; + const auto& vm_func = exec->functions[func_index]; const auto& param_names = vm_func.params; auto ctx = this->GetParamsContext(); @@ -617,10 +618,6 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } this->Init(contexts); }); - } else if (name == "load_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->LoadParams(args[0].operator std::string()); - }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -628,43 +625,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } TVMContext VirtualMachine::GetParamsContext() const { + CHECK(!ctxs.empty()) << "Context has not been initialized yet." + << "\n"; + // Use the fallback device if no device index is available. 
int fallback_device_type = static_cast(ctxs[0].device_type); // TODO(wweic): For heterogeneous execution, get device information from byte const auto& cit = - std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { - return fallback_device_type == static_cast(c.device_type); - }); + std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { + return fallback_device_type == static_cast(c.device_type); + }); return (cit == ctxs.end() ? ctxs[0] : *cit); } -void VirtualMachine::LoadParams(const std::string& params) { - dmlc::MemoryStringStream mss(const_cast(¶ms)); - dmlc::Stream* strm = &mss; - uint64_t header, reserved; - CHECK(strm->Read(&header)) << "Invalid parameter file"; - CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file"; - CHECK(strm->Read(&reserved)) << "Invalid parameter file"; - - std::vector names; - CHECK(strm->Read(&names)) << "Invalid parameter file"; - - uint64_t sz; - strm->Read(&sz); - size_t size = static_cast(sz); - CHECK(size == names.size()) << "Invalid parameter file"; - - auto ctx = GetParamsContext(); - for (size_t i = 0; i < size; i++) { - NDArray arr; - CHECK(arr.Load(strm)) << "Invalid parameter file"; - ObjectRef obj = Tensor(arr); - auto copy = CopyTo(obj, ctx); - params_.emplace(std::make_pair(names[i], copy)); - } -} - void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); frames.push_back(frame); @@ -699,15 +673,17 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vectorGetAllocator(ctxs[0]); DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; return return_register; } ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector& args) { - auto func_index = this->global_map[name]; + CHECK(exec) << "The executable has not been created yet."; + auto func_index = exec->global_map.at(name); DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; - return Invoke(this->functions[func_index], args); + return Invoke(exec->functions[func_index], args); } void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, @@ -744,14 +720,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } -void VirtualMachine::Init(const std::vector& ctxs) { - this->ctxs = ctxs; +void VirtualMachine::LoadExecutable(const Executable* exec) { + CHECK(exec) << "The executable is not created yet."; + this->exec = exec; + runtime::Module lib = this->exec->lib; // Get the list of packed functions. 
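   // For illustration: with a primitive_map such as {"fused_add": 0, "fused_multiply": 1},
   // the corresponding packed functions are resolved from `lib` and placed at
   // packed_funcs[0] and packed_funcs[1], so `InvokePacked` can index them directly
   // at run time (names are illustrative).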
- CHECK(primitive_map.empty() || lib.operator->()) + CHECK(exec->primitive_map.empty() || lib.operator->()) << "runtime module should have been built for primitive functions" << "\n"; - for (const auto& it : primitive_map) { + for (const auto& it : this->exec->primitive_map) { const auto& packed_name = it.first; auto packed_index = static_cast(it.second); if (packed_funcs.size() <= packed_index) { @@ -761,6 +739,11 @@ void VirtualMachine::Init(const std::vector& ctxs) { } } + +void VirtualMachine::Init(const std::vector& ctxs) { + this->ctxs = ctxs; +} + inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames.back().register_file[r] = val; } @@ -788,6 +771,7 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const { void VirtualMachine::RunLoop() { CHECK(this->code); + CHECK(this->exec); this->pc = 0; Index frame_start = frames.size(); while (true) { @@ -810,7 +794,8 @@ void VirtualMachine::RunLoop() { throw std::runtime_error("VM encountered fatal error"); } case Opcode::LoadConst: { - auto constant_obj = this->constants[instr.const_index]; + auto constant_obj = exec->constants[instr.const_index]; + // TODO(wweic) ctx could be obtained from the ctxs list. auto device_obj = CopyTo(constant_obj, ctxs[0]); WriteRegister(instr.dst, device_obj); pc++; @@ -828,7 +813,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_args; ++i) { args.push_back(ReadRegister(instr.invoke_args_registers[i])); } - InvokeGlobal(this->functions[instr.func_index], args); + InvokeGlobal(exec->functions[instr.func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -858,7 +843,7 @@ void VirtualMachine::RunLoop() { for (Index i = 0; i < instr.num_closure_args; ++i) { args.push_back(ReadRegister(instr.closure_args[i])); } - InvokeGlobal(this->functions[closure->func_index], args); + InvokeGlobal(exec->functions[closure->func_index], args); frames.back().caller_return_register = instr.dst; goto main_loop; } @@ -910,6 +895,7 @@ void VirtualMachine::RunLoop() { for (uint32_t i = 0; i < instr.alloc_tensor.ndim; ++i) { shape[i] = instr.alloc_tensor.shape[i]; } + // TODO(wweic) ctx could be obtained from the ctxs list. auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto obj = Tensor(data); @@ -931,6 +917,7 @@ void VirtualMachine::RunLoop() { auto num_dims = shape_tensor->shape[0]; auto shape = std::vector(shape_tensor->shape[0]); shape.assign(dims, dims + num_dims); + // TODO(wweic) ctx could be obtained from the ctxs list. auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto obj = Tensor(data); @@ -976,6 +963,21 @@ void VirtualMachine::RunLoop() { } } +runtime::Module CreateVirtualMachine(const Executable* exec) { + std::shared_ptr vm = std::make_shared(); + vm->LoadExecutable(exec); + return runtime::Module(vm); +} + +TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* exec = dynamic_cast(mod.operator->()); + CHECK(exec) << "The virtual machine executable has not been defined yet." 
+ << "\n"; + *rv = CreateVirtualMachine(exec); +}); + } // namespace vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index cedbc4f71859..1b40f894db08 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -47,14 +47,16 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) return vm.invoke("main", *args) else: assert isinstance(f, relay.Module), "expected expression or module" mod = f - vm = relay.vm.compile(mod, target) - vm.init(tvm.cpu()) + exe = relay.vm.compile(mod, target) + vm = relay.vm.VirtualMachine(exe) + vm.init(ctx) ret = vm.invoke("main", *args) return ret @@ -573,25 +575,6 @@ def test_add_op_broadcast(): mod["main"] = func check_result([x_data, y_data], x_data + y_data, mod=mod) -def test_set_params(): - mod = relay.Module() - x = relay.var('x', shape=(10, 5)) - w = relay.var('w', shape=(6, 5)) - b = relay.var('b', shape=(6,)) - y = relay.nn.bias_add(relay.nn.dense(x, w), b) - mod["main"] = relay.Function([x, w, b], y) - vm = relay.vm.compile(mod, 'llvm') - vm.init(tvm.cpu()) - - x_np = np.random.uniform(size=(10, 5)).astype('float32') - w_np = np.random.uniform(size=(6, 5)).astype('float32') - b_np = np.random.uniform(size=(6,)).astype('float32') - ref_np = np.dot(x_np, w_np.T) + b_np - params = {'w': w_np} - vm.load_params(params) - out = vm.run(x_np, b_np) - tvm.testing.assert_allclose(out.asnumpy(), ref_np) - if __name__ == "__main__": test_id() @@ -626,4 +609,3 @@ def test_set_params(): test_add_op_scalar() test_add_op_tensor() test_add_op_broadcast() - test_set_params() diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 3a317fc2d111..014648099aeb 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -22,29 +22,25 @@ from tvm import relay from tvm.relay.module import Module as rly_module from tvm.relay import vm as _vm -from tvm.relay import serializer, deserializer from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.prelude import Prelude from tvm.contrib import util from tvm.relay import testing -def create_vm(f, ctx=tvm.cpu(), target="llvm", params=None): +def create_exec(f, target="llvm", params=None): if isinstance(f, relay.Expr): mod = relay.Module() mod["main"] = f - vm = _vm.compile(mod, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(mod, target=target, params=params) + return executable else: assert isinstance(f, relay.Module), "expected mod as relay.Module" - vm = _vm.compile(f, target=target, params=params) - vm.init(ctx) - return vm + executable = _vm.compile(f, target=target, params=params) + return executable def veval(vm, *args, ctx=tvm.cpu()): assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine" - vm.init(ctx) ret = vm.run(*args) return ret @@ -59,13 +55,11 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32'): return result.asnumpy().astype(dtype) def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): - vm = create_vm(mod, ctx, target, params=params) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod, target, 
params=params) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) des_vm.init(ctx) - des_vm.load_params(params) result = des_vm.run(data) return result.asnumpy().astype(dtype) @@ -99,26 +93,25 @@ def test_serializer(): main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) mod["main"] = main - vm = create_vm(mod) - ser = serializer.Serializer(vm) + exe = create_exec(mod) - glbs = ser.globals + glbs = exe.globals assert len(glbs) == 3 assert "f1" in glbs assert "f2" in glbs assert "main" in glbs - prim_ops = ser.primitive_ops + prim_ops = exe.primitive_ops assert any(item.startswith('fused_add') for item in prim_ops) assert any(item.startswith('fused_subtract') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops) - code = ser.bytecode + code = exe.bytecode assert "main 5 2 5" in code assert "f1 2 1 3" in code assert "f2 2 1 3" in code - code, lib = ser.serialize() + code, lib = exe.save() assert isinstance(code, bytearray) assert isinstance(lib, tvm.module.Module) @@ -129,24 +122,24 @@ def test_save_load(): x_data = np.random.rand(10, 10).astype('float32') # serialize. - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() + vm = create_exec(f) + code, lib = vm.save() assert isinstance(code, bytearray) # save and load the code and lib file. tmp = util.tempdir() path_lib = tmp.relpath("lib.so") lib.export_library(path_lib) - with open(tmp.relpath("code.bc"), "wb") as fo: + with open(tmp.relpath("code.ro"), "wb") as fo: fo.write(code) loaded_lib = tvm.module.load(path_lib) - loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) + loaded_code = bytearray(open(tmp.relpath("code.ro"), "rb").read()) # deserialize. - deser = deserializer.Deserializer(loaded_code, loaded_lib) - des_vm = deser.deserialize() + des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) @@ -156,12 +149,12 @@ def test_const(): c = relay.const(1.0, "float32") x = relay.var('x', shape=(10, 10), dtype='float32') f = relay.Function([x], x + c) - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() + exe = create_exec(f) + code, lib = exe.save() assert isinstance(code, bytearray) - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.random.rand(10, 10).astype('float32') res = veval(des_vm, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) @@ -177,11 +170,11 @@ def test_if(): x_data = np.random.rand(10, 10).astype('float32') y_data = np.random.rand(10, 10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(f) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) # same res = veval(des_vm, x_data, x_data) @@ -213,11 +206,11 @@ def test_loop(): aarg = relay.var('accum', shape=[], dtype='int32') mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() 
+ exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, i_data, accum_data) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) @@ -230,11 +223,11 @@ def test_tuple(): i_data = np.random.rand(41).astype('float32') j_data = np.random.rand(10).astype('float32') - vm = create_vm(f) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(f) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm, (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), j_data) @@ -251,11 +244,11 @@ def test_adt_list(): f = relay.Function([], l321) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) result = veval(des_vm) assert len(result) == 2 @@ -297,11 +290,11 @@ def test_adt_compose(): f = relay.Function([y], add_two_body) mod["main"] = f - vm = create_vm(mod) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(mod) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) x_data = np.array(np.random.rand()).astype('float32') result = veval(des_vm, x_data) @@ -317,11 +310,11 @@ def test_closure(): clo = ff(relay.const(1.0)) main = clo(relay.const(2.0)) - vm = create_vm(main) - ser = serializer.Serializer(vm) - code, lib = ser.serialize() - deser = deserializer.Deserializer(code, lib) - des_vm = deser.deserialize() + exe = create_exec(main) + code, lib = exe.save() + des_exec = _vm.Executable.load_exec(code, lib) + des_vm = _vm.VirtualMachine(des_exec) + des_vm.init(tvm.cpu()) res = veval(des_vm) tvm.testing.assert_allclose(res.asnumpy(), 3.0) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index b5ce0ec70e51..53f573730576 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -26,9 +26,9 @@ def test_basic(): mod, params = resnet.get_workload() target = 'llvm' ctx = tvm.cpu() - vm = relay.profiler_vm.compile(mod, target) + exe = relay.profiler_vm.compile(mod, target, params=params) + vm = relay.profiler_vm.VirtualMachineProfiler(exe) vm.init(ctx) - vm.load_params(params) data = np.random.rand(1, 3, 224, 224).astype('float32') res = vm.invoke("main", [data]) From 36a96773bc24f65a52404056d9f1c170ebea206b Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 17 Oct 2019 22:41:34 -0700 Subject: [PATCH 13/62] [Relay][Frontend][TF] Add tensor array ops (#3798) * [Relay][Frontend][TF] Add tensor array ops * rename * delete test * Move utility function * Refactor * fix tensor array ops * fix test * fix rebase * Fix serializer bug * Improve tf convert name lookup to use prelude api * Fix lint * Fix test --- python/tvm/relay/frontend/tensorflow.py | 82 ++- python/tvm/relay/op/_tensor.py | 26 + python/tvm/relay/prelude.py | 520 ++++++++++++++++++ 
python/tvm/relay/testing/py_converter.py | 8 +- src/runtime/vm/executable.cc | 4 +- .../frontend/tensorflow/test_forward.py | 118 +++- tests/python/relay/test_adt.py | 148 +++++ tests/python/relay/test_feature.py | 3 +- 8 files changed, 899 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 38f9c523e0b1..eb67cf24b81e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -22,10 +22,14 @@ import warnings from collections import defaultdict + # Numpy support import numpy as np import tvm + +from tvm.relay.prelude import Prelude + from .. import analysis from .. import expr as _expr from .. import op as _op @@ -508,6 +512,69 @@ def _impl(inputs, attr, params): return _op.concatenate(inputs_reshaped, axis) return _impl +def _tensor_array(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('dtype').name + tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) + return tensor_array_constructor(_op.take(inputs[0], tvm.relay.const(0))) + return _impl + +def _tensor_array_scatter(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get('T').name + values_rank = len(inputs[2].type_annotation.shape) + unstack_name = "tensor_array_unstack_tensor{}".format(values_rank) + unstack_function = prelude.get_var(unstack_name, dtype_str) + values = unstack_function(inputs[2]) + tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) + return tensor_array_scatter_func(inputs[0], inputs[1], values) + return _impl + +def _tensor_array_gather(): + def _impl(inputs, attr, params, prelude): + return prelude.tensor_array_gather(inputs[2], inputs[1]) + return _impl + +def _tensor_array_size(): + def _impl(inputs, attr, params, prelude): + return prelude.length(inputs[0]) + return _impl + +def _tensor_array_write(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[2].type_annotation.shape) + dtype = attr.get('T').name + + tensor_name = 'tensor{}'.format(input_rank) + tensor_func = prelude.get_var(tensor_name, dtype) + v = tensor_func(inputs[2]) + write_func = prelude.get_var('tensor_array_write', dtype) + + return write_func(inputs[3], _op.take(inputs[1], tvm.relay.const(0)), v) + return _impl + +def _tensor_array_read(): + def _impl(inputs, attr, params, prelude): + read_func = prelude.get_var('tensor_array_read', attr.get('dtype').name) + return read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) + return _impl + +def _tensor_array_split(): + def _impl(inputs, attr, params, prelude): + input_rank = len(inputs[1].type_annotation.shape) + dtype_str = attr.get('T').name + v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) + lengths = _op.cast(inputs[2], 'int32') + split_var = prelude.get_var('tensor_array_split', dtype_str) + return split_var(inputs[0], v, lengths) + return _impl + +def _tensor_array_concat(): + def _impl(inputs, attr, params, prelude): + concat_func = prelude.get_var('tensor_array_concat', attr['dtype'].name) + return concat_func(inputs[1]) + return _impl + def _tile(): def _impl(inputs, attr, params): reps = _get_list_param(params, inputs.pop()) @@ -1313,6 +1380,14 @@ def _impl(inputs, attr, params): 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), 'Pack' : _pack(), + 'TensorArrayV3' : _tensor_array(), + 'TensorArrayScatterV3' : _tensor_array_scatter(), + 'TensorArrayGatherV3' : _tensor_array_gather(), + 'TensorArraySizeV3' : 
_tensor_array_size(), + 'TensorArrayWriteV3' : _tensor_array_write(), + 'TensorArrayReadV3' : _tensor_array_read(), + 'TensorArraySplitV3' : _tensor_array_split(), + 'TensorArrayConcatV3' : _tensor_array_concat(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), 'Pow' : _elemwise('power'), @@ -1860,6 +1935,7 @@ def __init__(self): self._loops = {} self._branches = {} self._mod = _module.Module({}) + self._prelude = Prelude(self._mod) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. @@ -2335,7 +2411,11 @@ def _convert_operator(self, op_name, inputs, attrs, if op_name in identity_list: sym = get_relay_op(op_name)(*inputs, **attrs) elif op_name in convert_map: - sym = convert_map[op_name](inputs, attrs, self._params) + if 'TensorArray' in op_name: + sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = convert_map[op_name](inputs, attrs, self._params) + elif op_name in convert_map_rnn: sym = self._convert_rnn_operator(op_name, inputs, attrs, self._params, graph, diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index da5804906269..188b3bb15956 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -108,6 +108,29 @@ def clip_compute(attrs, inputs, output_type, target): register_schedule("clip", schedule_elemwise) +@script +def _cast_shape_function(x): + out_ndim = len(x) + out = output_tensor((out_ndim,), "int64") + for i in const_range(out_ndim): + out[i] = x[i] + return out + +def cast_shape_func(attrs, inputs, out_ndims): + return [_cast_shape_function(*inputs)] + +@script +def _expand_dims_shape_func(x): + ndim = len(x.shape) + out = output_tensor((ndim+1,), "int64") + out[0] = int64(1) + for i in const_range(0, ndim): + out[i+1] = int64(x.shape[i]) + return out + +def expand_dims_shape_func(attrs, inputs, out_ndims): + return [_expand_dims_shape_func(*inputs)] + # shape func @script def _broadcast_shape_func(x, y, ndim): @@ -140,6 +163,9 @@ def _broadcast_shape_func(x, y, ndim): def broadcast_shape_func(attrs, inputs, out_ndims): return [_broadcast_shape_func(*inputs, out_ndims[0])] +register_shape_func("expand_dims", False, expand_dims_shape_func) +register_shape_func("cast", False, cast_shape_func) + register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) register_shape_func("multiply", False, broadcast_shape_func) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 803d8ef50db5..d27ffe512617 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -16,8 +16,513 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" +from .ty import GlobalTypeVar, TensorType, Any, scalar_type +from .expr import Var, Function, GlobalVar, If, const +from .op.tensor import add, subtract, equal +from .adt import Constructor, TypeData, Clause, Match +from .adt import PatternConstructor, PatternVar, PatternWildcard +from . 
import op from .module import Module +class TensorArrayOps(object): + """Contains tensor array related ops""" + + def __init__(self, prelude, dtype): + """Create tensor array ops registry""" + self.prelude = prelude + self.dtype = dtype + + def get_name(self, canonical): + """Get name corresponding to the caninical name""" + return self.prelude.get_name(canonical, self.dtype) + + def get_var(self, canonical): + """Get var corresponding to the caninical name""" + return self.prelude.get_var(canonical, self.dtype) + + def define_tensor_adt(self): + """Defines the dynamic tensor ADT, which is the container for tensors + with variable shapes.""" + tensor_type_name = self.get_name('tensor_t') + tensor_type_var = GlobalTypeVar(tensor_type_name) + setattr(self.prelude, tensor_type_name, tensor_type_var) + tensor0_type = TensorType([], self.dtype) + tensor1_type = TensorType([Any()], self.dtype) + tensor2_type = TensorType([Any(), Any()], self.dtype) + tensor3_type = TensorType([Any(), Any(), Any()], self.dtype) + tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype) + tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype) + tensor_nil_name = self.get_name('tensor_nil') + tensor0_name = self.get_name('tensor0') + tensor1_name = self.get_name('tensor1') + tensor2_name = self.get_name('tensor2') + tensor3_name = self.get_name('tensor3') + tensor4_name = self.get_name('tensor4') + tensor5_name = self.get_name('tensor5') + tensor6_name = self.get_name('tensor6') + tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) + tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var) + tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var) + tensor2_case = Constructor(tensor2_name, [tensor2_type], tensor_type_var) + tensor3_case = Constructor(tensor3_name, [tensor3_type], tensor_type_var) + tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var) + tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var) + tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var) + setattr(self.prelude, tensor_nil_name, tensor_nil_case) + setattr(self.prelude, tensor0_name, tensor0_case) + setattr(self.prelude, tensor1_name, tensor1_case) + setattr(self.prelude, tensor2_name, tensor2_case) + setattr(self.prelude, tensor3_name, tensor3_case) + setattr(self.prelude, tensor4_name, tensor4_case) + setattr(self.prelude, tensor5_name, tensor5_case) + setattr(self.prelude, tensor6_name, tensor6_case) + self.prelude.mod[tensor_type_var] = TypeData(tensor_type_var, [], [tensor_nil_case, + tensor0_case, + tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case]) + + def define_tensor_take(self): + """Defines a function to return a range of tensor_t on axis 0. 
+ tensor_take(t, lower, upper) : + tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t + """ + take_name = self.get_name("tensor_take") + take_var = GlobalVar(take_name) + setattr(self.prelude, take_name, take_var) + tensor_t = self.get_var('tensor_t') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + t = Var('tensor', tensor_t()) + lower = Var('lower', scalar_type('int32')) + upper = Var('upper', scalar_type('int32')) + t1 = Var('t1') + t2 = Var('t2') + t3 = Var('t3') + t4 = Var('t4') + t5 = Var('t5') + t6 = Var('t6') + tensor1_case =\ + Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor1_var(op.take(t1, op.arange(lower, upper, dtype='int32')))) + tensor2_case =\ + Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor2_var(op.take(t2, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor3_case =\ + Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor3_var(op.take(t3, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor4_case =\ + Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor4_var(op.take(t4, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor5_case =\ + Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor5_var(op.take(t5, op.arange(lower, upper, dtype='int32'), axis=0))) + tensor6_case =\ + Clause(PatternConstructor(tensor6_var, [PatternVar(t6)]), + tensor6_var(op.take(t6, op.arange(lower, upper, dtype='int32'), axis=0))) + self.prelude.mod[take_var] =\ + Function([t, lower, upper], + Match(t, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case, + tensor6_case], False), + tensor_t(), []) + + def define_tensor_expand_dims(self): + """Defines a function to grow a tensor_t's rank by adding one dimension in front + of the original tensor_t. 
+ tensor_expand_dims(t) : tensor_t -> tensor_t + """ + expand_dims_name = self.get_name("tensor_expand_dims") + expand_dims_var = GlobalVar(expand_dims_name) + setattr(self.prelude, expand_dims_name, expand_dims_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + t0 = Var("t0") + t1 = Var("t1") + t2 = Var("t2") + t3 = Var("t3") + t4 = Var("t4") + t5 = Var("t5") + tensor0_var = self.get_var('tensor0') + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + tensor5_var = self.get_var('tensor5') + tensor6_var = self.get_var('tensor6') + tensor0_case = Clause(PatternConstructor(tensor0_var, [PatternVar(t0)]), + tensor1_var(op.expand_dims(t0, 0, 1))) + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t1)]), + tensor2_var(op.expand_dims(t1, 0, 1))) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t2)]), + tensor3_var(op.expand_dims(t2, 0, 1))) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t3)]), + tensor4_var(op.expand_dims(t3, 0, 1))) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t4)]), + tensor5_var(op.expand_dims(t4, 0, 1))) + tensor5_case = Clause(PatternConstructor(tensor5_var, [PatternVar(t5)]), + tensor6_var(op.expand_dims(t5, 0, 1))) + self.prelude.mod[expand_dims_var] =\ + Function([x], + Match(x, [tensor0_case, + tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case, + tensor5_case], False)) + + def define_tensor_concat(self): + """Defines a function to concatenate two tensor_t on the first axis + + tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t + """ + concat_name = self.get_name("tensor_concatenate") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_type_var = self.get_var('tensor_t') + x = Var("x", tensor_type_var()) + y = Var("y", tensor_type_var()) + + tensor1_var = self.get_var('tensor1') + tensor2_var = self.get_var('tensor2') + tensor3_var = self.get_var('tensor3') + tensor4_var = self.get_var('tensor4') + t11 = Var("t11") + t12 = Var("t12") + t21 = Var("t21") + t22 = Var("t22") + t31 = Var("t31") + t32 = Var("t32") + t41 = Var("t41") + t42 = Var("t42") + tensor1_case = Clause(PatternConstructor(tensor1_var, [PatternVar(t11)]), + Match(y, [Clause(PatternConstructor(tensor1_var, [PatternVar(t12)]), + tensor1_var(op.concatenate([t11, t12], axis=0)))], + False)) + tensor2_case = Clause(PatternConstructor(tensor2_var, [PatternVar(t21)]), + Match(y, [Clause(PatternConstructor(tensor2_var, [PatternVar(t22)]), + tensor2_var(op.concatenate([t21, t22], axis=0)))], + False)) + tensor3_case = Clause(PatternConstructor(tensor3_var, [PatternVar(t31)]), + Match(y, [Clause(PatternConstructor(tensor3_var, [PatternVar(t32)]), + tensor3_var(op.concatenate([t31, t32], axis=0)))], + False)) + tensor4_case = Clause(PatternConstructor(tensor4_var, [PatternVar(t41)]), + Match(y, [Clause(PatternConstructor(tensor4_var, [PatternVar(t42)]), + tensor4_var(op.concatenate([t41, t42], axis=0)))], + False)) + # op.concatenate does not support tensor with rank higher than 4 + self.prelude.mod[concat_var] =\ + Function([x, y], Match(x, [tensor1_case, + tensor2_case, + tensor3_case, + tensor4_case], False)) + + def define_tensor_array(self): + """Defines a function to create a tensor array with size n. 
+ tensor_array(n) : Tensor[(), int32] -> list[tensor_t] + """ + tensor_array_constructor_name = self.get_name("tensor_array") + tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name) + setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) + tensor_nil_var = self.get_var('tensor_nil') + tensor_type_var = self.get_var('tensor_t') + n = Var("x", scalar_type('int32')) + body = If(equal(n, const(0)), + self.prelude.nil(), + self.prelude.cons(tensor_nil_var(), + tensor_array_constructor_var(subtract(n, const(1))))) + self.prelude.mod[tensor_array_constructor_var] = \ + Function([n], body, self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_read(self): + """Defines a function to get the head of a list. Assume the list has at least one + element. + + tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t + """ + read_name = self.get_name("tensor_array_read") + read_var = GlobalVar(read_name) + setattr(self.prelude, read_name, read_var) + tensor_type_var = self.get_var('tensor_t') + + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + self.prelude.mod[read_var] =\ + Function([tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []) + + def define_tensor_array_write(self): + """Defines a function to update a tensor array at index n with value v. + tensor_array_write(ta, n, v) : + list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t] + """ + write_name = self.get_name("tensor_array_write") + write_var = GlobalVar(write_name) + setattr(self.prelude, write_name, write_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + n = Var("x", scalar_type('int32')) + v = Var("v", tensor_type_var()) + self.prelude.mod[write_var] =\ + Function([tensor_array, n, v], self.prelude.update(tensor_array, n, v), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor1(self): + """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array. + tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor1_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + tensor_type_var = self.get_var('tensor_t') + tensor0_var = self.get_var('tensor0') + helper_body =\ + If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(tensor0_var(op.take(tensor, i)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), []) + unstack_name = self.get_name("tensor_array_unstack_tensor1") + unstack_var = GlobalVar(unstack_name) + setattr(self.prelude, unstack_name, unstack_var) + tensor1 = Var("tensor", TensorType([Any()], self.dtype)) + shape = op.shape_of(tensor1) + ndim = op.take(shape, const(0)) + self.prelude.mod[unstack_var] =\ + Function([tensor1], helper_var(const(0), ndim, tensor1), + self.prelude.l(tensor_type_var()), []) + + def define_tensor_array_unstack_tensor2(self): + """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array. 
+ + tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t] + """ + helper_name = self.get_name("tensor_array_unstack_tensor2_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor = Var("t", TensorType([Any(), Any()], self.dtype)) + up = Var("up", scalar_type('int32')) + i = Var("i", scalar_type('int32')) + + helper_body = If(equal(i, up), + self.prelude.nil(), + self.prelude.cons(self.get_var('tensor1')(op.take(tensor, i, axis=0)), + helper_var(add(i, const(1)), up, tensor))) + self.prelude.mod[helper_var] =\ + Function([i, up, tensor], helper_body, self.prelude.l(self.get_var('tensor_t')()), []) + + tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2") + tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name) + setattr(self.prelude, tensor_array_unstack_tensor2_name, tensor_array_unstack_tensor2_var) + tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype)) + shape = op.shape_of(tensor2) + ndim = op.take(shape, const(0)) + self.prelude.mod[tensor_array_unstack_tensor2_var] =\ + Function([tensor2], helper_var(const(0), ndim, tensor2), + self.prelude.l(self.get_var('tensor_t')()), []) + + def define_tensor_array_scatter(self): + """Defines a function to scatter the values of a tensor_t in indices of a tensor array. + tensor_array_scatter(ta, indices, value) : + list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] + """ + tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") + tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name) + tensor_t = self.get_var('tensor_t') + ta = Var("ta", self.prelude.l(tensor_t())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + values_ = Var('values_', self.prelude.l(tensor_t())) + write_var = self.get_var('tensor_array_write') + read_var = self.get_var('tensor_array_read') + helper_body = If(equal(current, limit), + ta, + tensor_array_scatter_helper_var( + write_var(ta, op.take(indices_, current), + read_var(values_, current)), + add(current, const(1)), + limit, indices_, values_)) + self.prelude.mod[tensor_array_scatter_helper_var] =\ + Function([ta, current, limit, indices_, values_], + helper_body, self.prelude.l(tensor_t()), []) + tensor_array_scatter_name = self.get_name("tensor_array_scatter") + tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name) + setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + indices = Var('indices', TensorType([Any()], 'int32')) + values = Var('values', self.prelude.l(tensor_t())) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) + self.prelude.mod[tensor_array_scatter_var] =\ + Function([tensor_array, indices, values], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_split(self): + """Defines a function to split the values of a tensor_t into a tensor array. 
+ tensor_array_split(ta, value, lengths) : + list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] + """ + tensor_t = self.get_var('tensor_t') + tensor_array_split_helper_name = self.get_name("ta_split_helper") + tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name) + setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) + ta1 = Var("tensor_array", self.prelude.l(tensor_t())) + value1 = Var('value1', tensor_t()) + offset1 = Var('offset1', scalar_type('int32')) + current1 = Var('current1', scalar_type('int32')) + limit1 = Var('limit1', scalar_type('int32')) + lengths1 = Var('lengths', TensorType([Any()], 'int32')) + write_var = self.get_var('tensor_array_write') + take_var = self.get_var('tensor_take') + helper1_body = If(equal(current1, limit1), + ta1, + write_var( + tensor_array_split_helper_var( + ta1, + value1, + add(offset1, op.take(lengths1, current1)), + add(current1, const(1)), + limit1, + lengths1 + ), + current1, + take_var(value1, + offset1, + add(op.take(lengths1, current1), offset1)))) + self.prelude.mod[tensor_array_split_helper_var] = \ + Function([ta1, value1, offset1, current1, limit1, lengths1], + helper1_body, self.prelude.l(tensor_t()), []) + split_name = self.get_name("tensor_array_split") + split_var = GlobalVar(split_name) + setattr(self.prelude, split_name, split_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + value = Var('value', tensor_t()) + lengths = Var('lengths', TensorType([Any()], 'int32')) + lengths_shape = op.shape_of(lengths) + lengths_limit = op.take(lengths_shape, const(0)) + body = tensor_array_split_helper_var( + tensor_array, + value, + const(0), + const(0), + lengths_limit, + lengths) + self.prelude.mod[split_var] =\ + Function([tensor_array, value, lengths], body, self.prelude.l(tensor_t()), []) + + def define_tensor_array_concat(self): + """Defines a function to return the values in the tensor array as concatenated tensor_t. + tensor_array_concat(ta) : list[tensor_t] -> tensor_t + """ + concat_name = self.get_name("tensor_array_concat") + concat_var = GlobalVar(concat_name) + setattr(self.prelude, concat_name, concat_var) + tensor_concat_var = self.get_var('tensor_concatenate') + tensor_t = self.get_var('tensor_t') + tensor_nil_var = self.get_var('tensor_nil') + tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) + hd = Var("hd") + tl = Var("tl") + nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) + cons_case = Clause(PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), + Match(tl, [ + Clause(PatternConstructor(self.prelude.nil), hd), + Clause(PatternWildcard(), + tensor_concat_var(hd, concat_var(tl))) + ], False)) + self.prelude.mod[concat_var] =\ + Function([tensor_array], + Match(tensor_array, [nil_case, cons_case], False), tensor_t(), []) + + def define_tensor_array_gather(self): + """Defines a function to return the selected values in a tensor array as tensor_t. 
+ tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t + """ + helper_name = self.get_name("tensor_array_gather_helper") + helper_var = GlobalVar(helper_name) + setattr(self.prelude, helper_name, helper_var) + tensor_type_var = self.get_var('tensor_t') + stack_var = self.get_var('tensor_array_stack') + read_var = self.get_var('tensor_array_read') + ta = Var("ta", self.prelude.l(tensor_type_var())) + accu = Var("accu", self.prelude.l(tensor_type_var())) + current = Var("current", scalar_type('int32')) + limit = Var("limit", scalar_type('int32')) + indices_ = Var('indices_', TensorType([Any()], 'int32')) + helper_body =\ + If(equal(current, const(0)), + stack_var(accu), + helper_var( + ta, + self.prelude.cons( + read_var( + ta, op.take(indices_, subtract(current, const(1)))), accu), + subtract(current, const(1)), + limit, indices_)) + self.prelude.mod[helper_var] = \ + Function([ta, accu, current, limit, indices_], helper_body, tensor_type_var(), []) + gather_name = self.get_name("tensor_array_gather") + gather_var = GlobalVar(gather_name) + setattr(self.prelude, gather_name, gather_var) + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + indices = Var('indices', TensorType([Any()], 'int32')) + indices_shape = op.shape_of(indices) + limit = op.take(indices_shape, const(0)) + body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) + self.prelude.mod[gather_var] =\ + Function([tensor_array, indices], body, tensor_type_var(), []) + + def define_tensor_array_stack(self): + """Defines a function to get the values in the tensor array as a stack tensor_t. + tensor_array_stack(l) : list[tensor_t] -> tensor_t + """ + stack_name = self.get_name("tensor_array_stack") + stack_var = GlobalVar(stack_name) + setattr(self.prelude, stack_name, stack_var) + tensor_type_var = self.get_var('tensor_t') + tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) + expand_dims_var = self.get_var('tensor_expand_dims') + concat_var = self.get_var('tensor_concatenate') + tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) + tensors = self.prelude.foldl(concat_var, + self.prelude.hd(tensor_array_expand_dims), + self.prelude.tl(tensor_array_expand_dims)) + self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), []) + + def register(self): + """Register all tensor array ops in Prelude""" + self.define_tensor_adt() + self.define_tensor_take() + self.define_tensor_expand_dims() + self.define_tensor_concat() + self.define_tensor_array() + self.define_tensor_array_read() + self.define_tensor_array_write() + self.define_tensor_array_unstack_tensor1() + self.define_tensor_array_unstack_tensor2() + self.define_tensor_array_scatter() + self.define_tensor_array_split() + self.define_tensor_array_concat() + self.define_tensor_array_stack() + # TODO(wweic): Gather fails in PartialEvaluate + # self.define_tensor_array_gather() + class Prelude: """Contains standard definitions.""" @@ -27,6 +532,17 @@ def __init__(self, mod=None): self.mod = mod self.load_prelude() + def get_name(self, canonical, dtype): + """Get name corresponding to the canonical name""" + if canonical == 'tensor_t': + return 'tensor_{}_t'.format(dtype) + return "{}_{}".format(canonical, dtype) + + def get_var(self, canonical, dtype): + """Get var corresponding to the canonical name""" + name = self.get_name(canonical, dtype) + return getattr(self, name) + def load_prelude(self): """Parses the Prelude from Relay's text format into 
a module.""" # TODO(@jroesch): we should remove this helper when we port over prelude @@ -74,3 +590,7 @@ def load_prelude(self): ] for global_def in GLOBAL_DEFS: setattr(self, global_def, self.mod.get_global_var(global_def)) + + for dtype in ['float32', 'int32']: + tensor_array_ops = TensorArrayOps(self, dtype) + tensor_array_ops.register() diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index d661be73ad02..d7b59922b89d 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -203,8 +203,12 @@ def convert_module(self): for var, func in self.mod.functions.items(): # optimize the definition so any operators used are lowered opt_func = self.optimize(func) - converted_func, _ = self.convert_func_node(opt_func, var) - defs.append(converted_func) + try: + converted_func, _ = self.convert_func_node(opt_func, var) + defs.append(converted_func) + except TypeError: + # TODO(wweic): fix conversion for Any + pass return defs diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 21f71af4eb8c..f85283094e91 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -309,7 +309,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(instr.alloc_tensor_reg.shape_register); // Save `DLDataType` and the dst register. const auto& dtype = instr.alloc_tensor.dtype; - fields.assign({dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(dtype.code); + fields.push_back(dtype.bits); + fields.push_back(dtype.lanes); fields.push_back(instr.dst); break; } diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index c2cbbff24173..3321d71a2cb8 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -60,13 +60,19 @@ def vmobj_to_list(o): result.append(vmobj_to_list(f)) return result elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): - if o.constructor.name_hint == 'cons': + if o.constructor.name_hint == 'Cons': tl = vmobj_to_list(o.fields[1]) hd = vmobj_to_list(o.fields[0]) hd.extend(tl) return hd - elif o.constructor.name_hint == 'nil': + elif o.constructor.name_hint == 'Nil': return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.data.asnumpy()] else: @@ -77,14 +83,11 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) - layout = None if target == "cuda": layout = "NCHW" target_host = None - shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} - mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, @@ -581,6 +584,111 @@ def test_forward_squeeze(): _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5]) _test_squeeze(np.arange(6).reshape((1, 2, 1, 3, 1)), [-3, -5, -1]) +def test_tensor_array_constructor(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) + t2 = 
tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype) + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) + ta2 = ta1.write(0, t) + ta3 = ta2.write(1, t2) + out = ta3.read(0) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') + run('float32') + run('int32') + +def test_tensor_array_scatter(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) + indices = tf.constant([2, 1, 0]) + ta1 = tf.TensorArray(dtype=dtype, size=3, infer_shape=False, dynamic_size=False) + ta2 = ta1.scatter(indices, t) + out0 = ta2.read(0) + out1 = ta2.read(1) + out2 = ta2.read(2) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') + run('float32') + run('int32') + +# TODO(wweic): Fix gather issue with PartialEvaluate +# def test_tensor_array_gather(): +# with tf.Graph().as_default(): +# dtype = 'float32' +# t = tf.constant([[1.0], [2.0], [3.0]]) +# scatter_indices = tf.constant([2, 1, 0]) +# gather_indices = tf.constant([1, 2]) +# ta1 = tf.TensorArray(dtype=tf.float32, size=3, infer_shape=False, dynamic_size=False) +# ta2 = ta1.scatter(scatter_indices, t) +# t1 = ta2.gather(gather_indices) +# g = tf.get_default_graph() +# compare_tf_with_tvm([], [], ['TensorArrayGatherV3:0'], mode='debug') + +def test_tensor_array_split(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) + split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) + ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) + ta2 = ta1.split(t, split_length) + out0 = ta2.read(0) + out1 = ta2.read(1) + out2 = ta2.read(2) + out3 = ta2.read(3) + g = tf.get_default_graph() + compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug') + compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug') + run('float32') + run('int32') + +def test_tensor_array_concat(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype) + split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) + ta1 = tf.TensorArray(dtype=dtype, size=4, infer_shape=False, dynamic_size=False) + ta2 = ta1.split(t, split_length) + t = ta2.concat() + compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug') + run('float32') + run('int32') + +def test_tensor_array_size(): + def run(dtype_str): + with tf.Graph().as_default(): + dtype = { + 'float32': tf.float32, + 'int32' : tf.int32 + }[dtype_str] + ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False) + out = ta1.size() + g = tf.get_default_graph() + compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') + run('float32') + run('int32') + ####################################################################### # 
ConcatV2 # -------- diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 7be7c75dfe64..390d3cd9f3c4 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -21,6 +21,8 @@ from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr +import numpy as np + mod = relay.Module() p = Prelude(mod) add_nat_definitions(p) @@ -683,6 +685,146 @@ def test_iterate(): res = intrp.evaluate(relay.Function([], expr)()) assert count(res) == 12 +def test_tensor_expand_dims(): + def run(dtype): + x = relay.var('x') + mod = relay.Module() + p = Prelude(mod) + expand_dims_func = p.get_var('tensor_expand_dims', dtype) + tensor1 = p.get_var('tensor1', dtype) + mod["main"] = relay.Function([x], expand_dims_func(tensor1(x))) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + x_np = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(x_np) + got = vmobj_to_list(result) + expected = [np.expand_dims(x_np, axis=0)] + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def test_tensor_array_constructor(): + def run(dtype): + x = relay.var('x') + mod = relay.Module() + p = Prelude(mod) + tensor_array = p.get_var('tensor_array', dtype) + mod["main"] = relay.Function([x], tensor_array(x)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(5) + got = vmobj_to_list(result) + expected = np.array([0, 0, 0, 0, 0]) + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def test_tensor_array_read(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + l = relay.var('l') + i = relay.var('i') + read_func = p.get_var('tensor_array_read', dtype) + tensor_array = p.get_var('tensor_array', dtype) + mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(10, 5) + got = vmobj_to_list(result) + expected = [0] + tvm.testing.assert_allclose(expected, got) + run('float32') + run('int32') + +def vmobj_to_list(o): + if isinstance(o, tvm.relay.backend.vmobj.Tensor): + return [o.asnumpy().tolist()] + elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): + return [o.asnumpy()] + elif isinstance(o, tvm.relay.backend.vmobj.Datatype): + result = [] + for f in o: + result.extend(vmobj_to_list(f)) + return result + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == 'Cons': + tl = vmobj_to_list(o.fields[1]) + hd = vmobj_to_list(o.fields[0]) + hd.extend(tl) + return hd + elif o.constructor.name_hint == 'Nil': + return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + +def test_tensor_array_stack(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + tensor_array = p.get_var('tensor_array', dtype) + tensor1 = p.get_var('tensor1', dtype) + write = p.get_var('tensor_array_write', dtype) + stack = p.get_var('tensor_array_stack', dtype) + l = relay.var('l') + v = relay.var('v') + init_tensor_array = tensor_array(relay.const(3)) + tensor_array1 = write(init_tensor_array, relay.const(0), 
tensor1(v)) + tensor_array2 = write(tensor_array1, relay.const(1), tensor1(v)) + tensor_array3 = write(tensor_array2, relay.const(2), tensor1(v)) + tensor_array4 = stack(tensor_array3) + mod["main"] = relay.Function([v], tensor_array4) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(t) + res = vmobj_to_list(result) + expected = [np.stack([t, t, t])] + tvm.testing.assert_allclose(expected, res) + run('float32') + run('int32') + +def test_tensor_array_unstack(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype) + v = relay.var('v') + mod["main"] = relay.Function([v], unstack_tensor1(v)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(1,)).astype(dtype) + result = ex.evaluate()(t) + res = vmobj_to_list(result) + tvm.testing.assert_allclose(t, res) + run('float32') + run('int32') + +def test_tensor_take(): + def run(dtype): + mod = relay.Module() + p = Prelude(mod) + take = p.get_var('tensor_take', dtype) + tensor2 = p.get_var('tensor2', dtype) + v = relay.var('v') + lower = relay.var('lower') + upper = relay.var('upper') + mod["main"] = relay.Function([v, lower, upper], take(tensor2(v), lower, upper)) + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + t = np.random.uniform(size=(10, 10)).astype(dtype) + result = ex.evaluate()(t, 2, 5) + res = vmobj_to_list(result) + expected = [np.take(t, range(2, 5), axis=0)] + tvm.testing.assert_allclose(expected, res) + run('float32') + run('int32') if __name__ == "__main__": test_nat_constructor() @@ -707,3 +849,9 @@ def test_iterate(): test_size() test_compose() test_iterate() + + test_tensor_expand_dims() + test_tensor_array_constructor() + test_tensor_array_read() + test_tensor_array_stack() + test_tensor_array_unstack() diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py index 8f0e90de0315..64eda9d04e7c 100644 --- a/tests/python/relay/test_feature.py +++ b/tests/python/relay/test_feature.py @@ -38,7 +38,8 @@ def test_prelude(): Feature.fLet, Feature.fIf, Feature.fConstructor, - Feature.fMatch + Feature.fMatch, + Feature.fGraph ]) From fe418ecd93a1833b3dfe256a02304a4d9c2dd2dd Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Fri, 18 Oct 2019 08:19:32 -0700 Subject: [PATCH 14/62] Fix typo (#4144) --- src/pass/lower_tvm_builtin.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 69618985d50c..79329cbe717f 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -230,7 +230,7 @@ class BuiltinLower : public IRMutator { cast(Int(32), device_type_))); return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr); } - // call packled. + // call packed. 
Expr MakeCallPacked(const Call* op, const Expr& e) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; From e7c88a99f830de30814df14eaa980547ecbd61c1 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 18 Oct 2019 09:49:37 -0700 Subject: [PATCH 15/62] [CI] Pin NNPack pthreadtools version (#4152) --- docker/install/ubuntu_install_nnpack.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docker/install/ubuntu_install_nnpack.sh b/docker/install/ubuntu_install_nnpack.sh index 4f45f130e2e5..dc51fc28d492 100755 --- a/docker/install/ubuntu_install_nnpack.sh +++ b/docker/install/ubuntu_install_nnpack.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,11 +22,14 @@ set -o pipefail apt-get update && apt-get install -y --no-install-recommends git cmake -# TODO: specific tag? git clone https://github.com/Maratyszcza/NNPACK NNPACK +git clone https://github.com/Maratyszcza/pthreadpool NNPACK/pthreadpool + +# Use specific versioning tag. (cd NNPACK && git checkout 1e005b0c2) +(cd NNPACK/pthreadpool && git checkout 13da0b4c) mkdir -p NNPACK/build cd NNPACK/build -cmake -DCMAKE_INSTALL_PREFIX:PATH=. -DNNPACK_INFERENCE_ONLY=OFF -DNNPACK_CONVOLUTION_ONLY=OFF -DNNPACK_BUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON .. && make -j4 && make install +cmake -DCMAKE_INSTALL_PREFIX:PATH=. -DNNPACK_INFERENCE_ONLY=OFF -DNNPACK_CONVOLUTION_ONLY=OFF -DNNPACK_BUILD_TESTS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DPTHREADPOOL_SOURCE_DIR=pthreadpool .. && make -j4 && make install cd - From fdb01cb6084959b1305bec16d7c9785ebc050839 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 18 Oct 2019 10:51:15 -0700 Subject: [PATCH 16/62] [QNN][TFLite] Parsing QNN Add op. Adding MobilenetV2. 
(#4142) --- python/tvm/relay/frontend/tflite.py | 66 +++++++++++++++++++- tests/python/frontend/tflite/test_forward.py | 22 ++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 35bc85e09fdd..b08dd6bf94e0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -224,6 +224,18 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor): return lhs_tensor.qnn_params['scale'] == rhs_tensor.qnn_params['scale'] and \ lhs_tensor.qnn_params['zero_point'] == rhs_tensor.qnn_params['zero_point'] + def is_quantized(self, op): + """Check if an input tensor is quantized.""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + first_tensor = input_tensors[0] + return first_tensor.qnn_params is not None + def convert_conv2d(self, op): """Convert TFLite conv2d""" return self.convert_conv(op, "conv2d") @@ -498,7 +510,25 @@ def _convert_elemwise(self, relay_op, op): rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) - out = relay_op(lhs_expr, rhs_expr) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + # If quantized, extracts qnn params and call QNN add operator. + if lhs_tensor.qnn_params: + assert rhs_tensor.qnn_params, "Both tensors should be quantized." + assert output_tensor.qnn_params, "Output tensor should be quantized." + out = relay_op(lhs=lhs_expr, + rhs=rhs_expr, + lhs_scale=lhs_tensor.qnn_params['scale'], + lhs_zero_point=lhs_tensor.qnn_params['zero_point'], + rhs_scale=rhs_tensor.qnn_params['scale'], + rhs_zero_point=rhs_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point']) + else: + out = relay_op(lhs_expr, rhs_expr) # Options (fused_activation_function) options = None @@ -517,36 +547,70 @@ def _convert_elemwise(self, relay_op, op): fused_activation_fn = options.FusedActivationFunction() # if we have activation fn if fused_activation_fn != ActivationFunctionType.NONE: + if output_tensor.qnn_params: + raise tvm.error.OpNotImplemented( + 'Elemwise operators with fused activation are not supported yet.') out = self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_add(self, op): """Convert TFLite ADD""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + return self._convert_elemwise(_qnn.op.add, op) return self._convert_elemwise(_op.add, op) def convert_sub(self, op): """Convert TFLite SUB""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized sub operator is not supported yet.') return self._convert_elemwise(_op.subtract, op) def convert_mul(self, op): """Convert TFLite MUL""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized mul operator is not supported yet.') return self._convert_elemwise(_op.multiply, op) def convert_div(self, op): """Convert TFLite DIV""" + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise 
tvm.error.OpNotImplemented( + 'TFlite quantized div operator is not supported yet.') return self._convert_elemwise(_op.divide, op) def convert_pow(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized pow operator is not supported yet.') return self._convert_elemwise(_op.power, op) def convert_maximum(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized maximum operator is not supported yet.') return self._convert_elemwise(_op.maximum, op) def convert_minimum(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized minimum operator is not supported yet.') return self._convert_elemwise(_op.minimum, op) def convert_greater(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized greater operator is not supported yet.') return self._convert_elemwise(_op.greater, op) def convert_zeros_like(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a71a24ee0a4f..29b0c87c5b32 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1037,6 +1037,26 @@ def test_forward_qnn_mobilenet_v1_net(): tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +def test_forward_qnn_mobilenet_v2_net(): + """Test the Quantized TFLite Mobilenet V2 model.""" + # MobilenetV2 + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz", + "mobilenet_v2_1.0_224_quant.tflite") + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + # Checking the labels because the requantize implementation is different between TFLite and + # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels. + np.random.seed(0) + data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tflite_predictions = np.squeeze(tflite_output) + tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') + tvm_predictions = np.squeeze(tvm_output) + tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] + tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + ####################################################################### # SSD Mobilenet # ------------- @@ -1111,6 +1131,6 @@ def test_forward_ssd_mobilenet_v1(): test_forward_ssd_mobilenet_v1() # End to End quantized - # TODO - MobilenetV2 fails for now. Remove when fixed. 
test_forward_qnn_inception_v1_net() test_forward_qnn_mobilenet_v1_net() + test_forward_qnn_mobilenet_v2_net() From 687d4a83acaa5e952dab2844995b969a0ceea107 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 18 Oct 2019 15:22:37 -0700 Subject: [PATCH 17/62] Add lift_if_then_else pass (#3865) * Add LiftIfThenElse pass * Add more comments * Rename and refactor * Add description for internal data structure * Rename a test * Minor change * Address comments * Improve update_for --- include/tvm/ir_pass.h | 7 + src/api/api_pass.cc | 1 + src/pass/hoist_if_then_else.cc | 424 ++++++++++++++++++++ tests/python/unittest/test_pass_hoist_if.py | 185 +++++++++ 4 files changed, 617 insertions(+) create mode 100644 src/pass/hoist_if_then_else.cc create mode 100644 tests/python/unittest/test_pass_hoist_if.py diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 5ac71fdce47b..03078b8be41f 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -377,6 +377,13 @@ Stmt LowerStorageAccessInfo(Stmt stmt); */ Stmt DecorateDeviceScope(Stmt stmt); +/*! + * \brief Loop invariant code motion which locates and hoists if statements. + * \param stmt The stmt to do if statement hoisting. + * \return Transformed stmt. + */ +Stmt HoistIfThenElse(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 25cd5838385f..d2352496c2b4 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -160,5 +160,6 @@ REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); +REGISTER_PASS(HoistIfThenElse); } // namespace ir } // namespace tvm diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc new file mode 100644 index 000000000000..bbdb609e9a08 --- /dev/null +++ b/src/pass/hoist_if_then_else.cc @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file hoist_if_then_else.cc + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../arithmetic/int_set.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace ir { + +using HoistMap = std::unordered_map>; +using VarMap = std::unordered_map>; + +/* + * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant. + * For example, given the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt. 
+ * Then we hoist IfThenElse stmt by one For stmt each step: + * + * Step 1: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Step 2: + * for (i = 0; i < 3; i++) + * if (likely(i*2 < 4)) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * In this pass, we only continue detecting possible hoisting chance when visiting For, + * IfThenElse or AttrStmt Node. For example, for the following block: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * Only the For with k variable will be considered and the resulting stmt would be: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * A[i + j] = A[i + j] - 1 + * if (likely(i*2 < 4)) + * for (k = 0; k < 5; k++) + * A[3*i+2j+k] = B[7*i+3j+k] + * + * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following + * block won't be optimized: + * for (i = 0; i < 3; i++) + * for (j = 0; j < 4; j++) + * for (k = 0; k < 5; k++) + * if (likely(i*2 < 4)) + * A[3*i+2j+k] = B[7*i+3j+k] + * if (likely(j > 2)) + * A[i+j+k] = B[i+j+k] + * + */ +class IfThenElseHoist { + public: + Stmt VisitAndMutate(const Stmt& stmt) { + SelectCandidates(stmt); + LocateTopFor(); + return PostOrderMutate(stmt); + } + + private: + void SelectCandidates(const Stmt& stmt); + void LocateTopFor(); + Stmt PostOrderMutate(const Stmt& stmt); + size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt); + Stmt HoistIf(const Stmt& if_stmt); + + // Map of all For nodes to all child IfThenElse nodes. + HoistMap for2if_map_; + // Map of all IfThenElse nodes to all For nodes which are loop invariant. + HoistMap if2for_map_; + // Map of highest loop invariant For to child IfThenElse. + HoistMap top_for_var_map_; + // Map of original For to list of update For nodes. + HoistMap for_tracking_map_; + // Map of all IfThenElse nodes to condition variable nodes. + VarMap cond_var_map_; + // List of For nodes added in post order DFS visiting. + std::vector ordered_for_list_; +}; + +// Check whether a given IfThenElse stmt is the first one appearing +// in a For stmt. +bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) { + std::vector if_node_list; + const For* for_node = for_stmt.as(); + CHECK(for_node); + CHECK(if_stmt.as()); + + PostOrderVisit(for_node->body, [&](const NodeRef& node) { + if (node.as()) { + if_node_list.push_back(node.get()); + } + }); + return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back(); +} + +// Update upper level For node when current For node is modified. +// With this function we only need to visit and mutate top level For node +// in the main VisitAndMutate function. +Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { + const Node* top_for_node; + const For* parent_for_node = parent_for_stmt.as(); + CHECK(parent_for_node); + CHECK(new_if_stmt.as()); + + PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) { + if (node.as()) { + top_for_node = node.get(); + } + }); + + PackedFunc replace_target_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + if (current_for.get() == top_for_node) { + *ret = new_if_stmt; + } + }); + + return IRTransform(parent_for_stmt, nullptr, replace_target_for, + {Expr("For")}); +} + +// Remove IfThenElse node from a For node. 
+// A pair of For nodes will be generated. +std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { + Stmt then_for; + Stmt else_for; + CHECK(if_stmt.as()); + + PackedFunc replace_then_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->then_case; + } + }); + + PackedFunc replace_else_case = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->else_case; + } + }); + + then_for = IRTransform(for_stmt, nullptr, replace_then_case, + {Expr("IfThenElse")}); + if (if_stmt.as()->else_case) { + else_for = IRTransform(for_stmt, nullptr, replace_else_case, + {Expr("IfThenElse")}); + } + + return std::make_pair(then_for, else_for); +} + +// Locate all For nodes and capture child IfThenElse nodes. +void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { + PostOrderVisit(stmt, [&](const NodeRef& node){ + const For* for_node = node.as(); + if (!for_node) return; + + std::queue tracker; + tracker.push(for_node->body); + Stmt for_stmt = Downcast(node); + for2if_map_.insert({for_stmt.get(), std::vector()}); + while (!tracker.empty()) { + Stmt head = tracker.front(); + tracker.pop(); + if (head->is_type()) { + for (const auto& if_stmt : for2if_map_.at(head.get())) { + for2if_map_[for_stmt.get()].push_back(if_stmt); + } + } else if (head->is_type()) { + const AttrStmt* attr_node = head.as(); + tracker.push(attr_node->body); + } else if (head->is_type()) { + for2if_map_[for_stmt.get()].push_back(head); + const IfThenElse* if_node = head.as(); + tracker.push(if_node->then_case); + if (if_node->else_case) { + tracker.push(if_node->else_case); + } + + // Record condition variables. + if (!cond_var_map_.count(head.get())) { + std::unordered_set new_var_set; + cond_var_map_.insert({head.get(), new_var_set}); + PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) { + if (cond_node.as()) { + cond_var_map_[head.get()].insert(cond_node.get()); + } + }); + } + } else { + continue; + } + } + ordered_for_list_.emplace_back(Downcast(node)); + }); +} + +// For each IfThenElse node, find the highest For node which +// meets loop invariant condition. +void IfThenElseHoist::LocateTopFor() { + std::unordered_map if_position_map; + std::unordered_set top_for_var_set; + + // Create IfThenElse -> For map. + for (const Stmt& for_stmt : ordered_for_list_) { + std::vector if_list = for2if_map_[for_stmt.get()]; + const For* for_node = for_stmt.as(); + CHECK(for_node); + top_for_var_map_.insert({for_node->loop_var.get(), if_list}); + for (const Stmt& if_stmt : if_list) { + const Node* if_node = if_stmt.get(); + if2for_map_[if_node].push_back(for_stmt); + } + } + + // Locate the highest For node which is loop invariant. 
+ for (const auto& item : if2for_map_) { + Stmt top_for; + const Node* if_stmt = item.first; + std::vector for_list = item.second; + for (size_t i = 0; i < for_list.size(); ++i) { + const Stmt& for_stmt = for_list.at(i); + const For* for_node = for_stmt.as(); + CHECK(for_node); + std::vector new_for_list{for_stmt}; + for_tracking_map_.insert({for_stmt.get(), new_for_list}); + if (cond_var_map_[if_stmt] + .count(for_node->loop_var.get())) { + std::vector updated_for_list(for_list.begin(), + for_list.begin() + i); + if2for_map_[if_stmt] = updated_for_list; + break; + } else { + top_for = for_stmt; + } + } + if (top_for.as()) { + if_position_map.insert({if_stmt, top_for}); + } + } + + for (const auto& item : if_position_map) { + top_for_var_set.insert(item.second.as()->loop_var.get()); + } + + std::vector removed_for_var_list; + for (const auto& item : top_for_var_map_) { + const Node* top_for_var = item.first; + std::vector if_list = item.second; + if (!top_for_var_set.count(top_for_var)) { + removed_for_var_list.push_back(top_for_var); + } else { + std::vector actual_if_list; + for (const Stmt& if_stmt : if_list) { + if (if_position_map.count(if_stmt.get())) { + actual_if_list.push_back(if_stmt); + } + } + top_for_var_map_[top_for_var] = actual_if_list; + } + } + for (const Node* top_for_var : removed_for_var_list) { + top_for_var_map_.erase(top_for_var); + } +} + +// When we try to mutate a For node, some child For nodes can have already +// been mutated. This function is to get the updated For node and further +// hoisting can be done based on this new node. +// We keep all For nodes tracing in for_tracking_map_. When we get a +// hoisted IfThenElse, we match it with tracing For nodes to pick +// the updated one. +size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, + const Stmt& if_stmt) { + std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; + size_t updated_for_idx = 0; + for (size_t i = 0; i < tracked_for_list.size(); ++i) { + const Stmt& current_for = + tracked_for_list.at(tracked_for_list.size() - 1 - i); + if (is_first_if(current_for, if_stmt)) { + updated_for_idx = tracked_for_list.size() - 1 - i; + break; + } + } + return updated_for_idx; +} + +// Hoist an IfThenElse node as high as possible. +// This function iterates on all candidate For nodes. For each For node, +// it first removes IfThenElse nodes. Then it generates a new IfThenElse +// node using mutated For nodes. 
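As a usage reference, the pass is exposed to Python as tvm.ir_pass.HoistIfThenElse (registered via REGISTER_PASS above). A minimal example in the style of the unit tests added further below, building a loop nest whose condition depends only on the outer loop variable:

    import tvm

    ib = tvm.ir_builder.create()
    n = tvm.var("n")
    with ib.for_range(0, 4, "i") as i:
        with ib.for_range(0, n, "j") as j:
            with ib.if_scope(ib.likely(i < 2)):   # invariant with respect to j
                ib.emit(tvm.make.Evaluate(n))

    stmt = ib.get()
    hoisted = tvm.ir_pass.HoistIfThenElse(stmt)   # the if now wraps the j loop
    print(hoisted)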
+Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { + Stmt new_if = if_stmt; + + for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { + const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); + size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); + const Stmt& updated_for_node = + for_tracking_map_[for_stmt.get()].at(updated_for_idx); + auto generated_for_pair = RemoveIf(updated_for_node, new_if); + const Stmt& then_for = generated_for_pair.first; + const Stmt& else_for = generated_for_pair.second;; + for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; + + if (else_for.get()) { + for_tracking_map_[for_stmt.get()].push_back(else_for); + } + + const IfThenElse* new_if_node = new_if.as(); + CHECK(new_if_node); + new_if = IfThenElse::make(new_if_node->condition, then_for, else_for); + if (i < if2for_map_[if_stmt.get()].size() - 1) { + const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); + const Stmt& actual_next_for = + for_tracking_map_[original_next_for.get()].at(updated_for_idx); + Stmt update_for_stmt = update_for(actual_next_for, new_if); + + for_tracking_map_[original_next_for.get()]. + at(updated_for_idx) = update_for_stmt; + } + } + return new_if; +} + +// Mutate For nodes in post order DFS manner. +Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { + PackedFunc replace_top_for = PackedFunc( + [&](TVMArgs args, TVMRetValue *ret){ + const NodeRef& current_for = args[0]; + const For* for_node = current_for.as(); + if (!for_node) return; + + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : + top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(HoistIf(if_stmt)); + } + + const IfThenElse* next_if_node; + const IfThenElse* current_if_node = + new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); + const Stmt current_if_stmt = + IfThenElse::make(current_if_node->condition, + current_if_node->then_case, + current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); + new_for = IfThenElse::make(next_if_node->condition, current_if_stmt, + next_if_node->else_case); + current_if_node = new_for.as(); + } + + if (!new_for.get()) { + const IfThenElse* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); + new_for = IfThenElse::make(first_if_node->condition, + first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + }); + return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")}); +} + +Stmt HoistIfThenElse(Stmt stmt) { + return IfThenElseHoist().VisitAndMutate(stmt); +} + +} // namespace ir +} // namespace tvm diff --git a/tests/python/unittest/test_pass_hoist_if.py b/tests/python/unittest/test_pass_hoist_if.py new file mode 100644 index 000000000000..4a28cf6b318a --- /dev/null +++ b/tests/python/unittest/test_pass_hoist_if.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + + +var_list = [] + +def verify_structure(stmt, expected_struct): + node_dict = {} + struct = {} + def _extract_vars(op): + global var_list + if isinstance(op, tvm.expr.Var): + var_list.append(op.name) + + def _visit(op): + key = op + if isinstance(op, tvm.stmt.IfThenElse): + global var_list + tvm.ir_pass.PostOrderVisit(op.condition, _extract_vars) + val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] + var_list.clear() + elif isinstance(op, tvm.stmt.For): + val = [(op.body,), ("For", op.loop_var.name)] + elif isinstance(op, tvm.stmt.AttrStmt): + val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))] + else: + return + node_dict[key] = val + + tvm.ir_pass.PostOrderVisit(stmt, _visit) + for key, val in node_dict.items(): + struct[val[1]] = tuple(node_dict[child][1] if child in node_dict + else None for child in val[0]) + + assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \ + % (expected_struct, struct) + var_list.clear() + +def test_basic(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + with ib.else_scope(): + ib.emit(tvm.make.Evaluate(n)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_no_else(): + ib = tvm.ir_builder.create() + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(ib.likely(i < 2)): + ib.emit(tvm.make.Evaluate(m)) + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_attr_stmt(): + ib = tvm.ir_builder.create() + dshape = (32, 64) + data = ib.pointer("float32", name="data") + l = tvm.var('l') + m = tvm.var('m') + n = tvm.var('n') + + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", dshape[0]) + ib.scope_attr(bx, "thread_extent", dshape[1]) + with ib.for_range(0, l, "i") as i: + with ib.for_range(0, m, "j") as j: + with ib.for_range(0, n, "k") as k: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 0.5 + with ib.else_scope(): + data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), + ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For', 'i'): (('For', 'j'),), + ('AttrStmt', 'thread_extent', 
64): (('For', 'i'),), + ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)} + verify_structure(new_stmt, expected_struct) + +def test_nested_for(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), + ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), + ('For', 'i'): (('IfThenElse', ('i',)),)} + verify_structure(new_stmt, expected_struct) + +def test_if_block(): + ib = tvm.ir_builder.create() + data = ib.pointer("float32", name="data") + n = tvm.var("n") + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.if_scope(i >= 3): + data[i * 3 + j] = data[i * 3 + j] + 0.5 + with ib.for_range(0, 15, "k") as k: + with ib.for_range(0, 20, "l") as l: + with ib.if_scope(tvm.any(i < 4, j >= 8)): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2 + with ib.else_scope(): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 + with ib.if_scope(j <5): + data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1 + + + with ib.for_range(0, 5, "i") as i: + with ib.for_range(0, 10, "j") as j: + with ib.for_range(0, 15, "k") as k: + with ib.if_scope(n >= 3): + data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 + + stmt = ib.get() + new_stmt = tvm.ir_pass.HoistIfThenElse(stmt) + expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), + ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), + ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), + ('IfThenElse', ('n',)): (('For', 'j'), None)} + verify_structure(new_stmt, expected_struct) + + +if __name__ == "__main__": + test_basic() + test_no_else() + test_attr_stmt() + test_nested_for() + test_if_block() From 3c4b7cce5462fe54f5062a096a566f1085327030 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 18 Oct 2019 16:15:52 -0700 Subject: [PATCH 18/62] [CI] Update cpu docker (#4153) --- Jenkinsfile | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 6134023f9c21..9cf6e4f01275 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -38,9 +38,15 @@ // - Tag the new version as the lates // - Periodically cleanup the old versions on local workers // + +// Hashtag in the source to build current CI docker builds +// +// - ci-cpu:v0.54: e7c88a99f830de30814df14eaa980547ecbd61c1 +// + ci_lint = "tvmai/ci-lint:v0.51" ci_gpu = "tvmai/ci-gpu:v0.54" -ci_cpu = "tvmai/ci-cpu:v0.52" +ci_cpu = "tvmai/ci-cpu:v0.54" ci_i386 = "tvmai/ci-i386:v0.52" // tvm libraries From 32aad56c3a1f9712a7fe396dcf1b583a884bb118 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sat, 19 Oct 2019 21:57:50 -0700 Subject: [PATCH 19/62] [Refactor] Rename Datatype to ADT (#4156) We think it will reduce the confusion with the meaning. 
https://discuss.tvm.ai/t/discuss-consider-rename-vm-datatype/4339 --- docs/dev/virtual_machine.rst | 10 +++---- include/tvm/runtime/object.h | 2 +- include/tvm/runtime/vm.h | 24 ++++++++-------- python/tvm/relay/backend/vm.py | 2 +- python/tvm/relay/backend/vmobj.py | 20 ++++++------- src/relay/backend/vm/compiler.cc | 8 +++--- src/runtime/vm/executable.cc | 6 ++-- src/runtime/vm/object.cc | 28 +++++++++---------- src/runtime/vm/vm.cc | 24 ++++++++-------- .../frontend/tensorflow/test_forward.py | 2 +- tests/python/relay/test_adt.py | 2 +- tests/python/relay/test_vm.py | 2 +- tests/python/relay/test_vm_object.py | 8 +++--- 13 files changed, 69 insertions(+), 69 deletions(-) diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst index 2791ee71177e..cb08cc14e56e 100644 --- a/docs/dev/virtual_machine.rst +++ b/docs/dev/virtual_machine.rst @@ -121,7 +121,7 @@ AllocTensor Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result is saved to register `dst`. -AllocDatatype +AllocADT ^^^^^^^^^^^^^ **Arguments**: :: @@ -176,7 +176,7 @@ GetTagi RegName object RegName dst -Get the object tag for Datatype object in register `object`. And saves the reult to register `dst`. +Get the object tag for ADT object in register `object`. And saves the reult to register `dst`. Fatal ^^^^^ @@ -251,9 +251,9 @@ Currently, we support 3 types of objects: tensors, data types, and closures. :: - VMObject VMTensor(const tvm::runtime::NDArray& data); - VMObject VMDatatype(size_t tag, const std::vector& fields); - VMObject VMClosure(size_t func_index, std::vector free_vars); + Object Tensor(const tvm::runtime::NDArray& data); + Object ADT(size_t tag, const std::vector& fields); + Object Closure(size_t func_index, std::vector free_vars); Stack and State diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0693b1f47b3c..7291510c16df 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -51,7 +51,7 @@ enum TypeIndex { kRoot = 0, kVMTensor = 1, kVMClosure = 2, - kVMDatatype = 3, + kVMADT = 3, kStaticIndexEnd, /*! \brief Type index is allocated during runtime. */ kDynamic = kStaticIndexEnd diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index a276c658c496..7d2df0b285b1 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -57,31 +57,31 @@ class Tensor : public ObjectRef { /*! \brief An object representing a structure or enumeration. */ -class DatatypeObj : public Object { +class ADTObj : public Object { public: /*! \brief The tag representing the constructor used. */ size_t tag; /*! \brief The fields of the structure. */ std::vector fields; - static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype; - static constexpr const char* _type_key = "vm.Datatype"; - TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object); + static constexpr const uint32_t _type_index = TypeIndex::kVMADT; + static constexpr const char* _type_key = "vm.ADT"; + TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); }; -/*! \brief reference to data type. */ -class Datatype : public ObjectRef { +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { public: - Datatype(size_t tag, std::vector fields); + ADT(size_t tag, std::vector fields); /*! * \brief construct a tuple object. * \param fields The fields of the tuple. * \return The constructed tuple type. 
*/ - static Datatype Tuple(std::vector fields); + static ADT Tuple(std::vector fields); - TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj); + TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); }; /*! \brief An object representing a closure. */ @@ -129,7 +129,7 @@ enum class Opcode { InvokePacked = 4U, AllocTensor = 5U, AllocTensorReg = 6U, - AllocDatatype = 7U, + AllocADT = 7U, AllocClosure = 8U, GetField = 9U, If = 10U, @@ -237,7 +237,7 @@ struct Instruction { /*! \brief The register to project from. */ RegName object; } get_tag; - struct /* AllocDatatype Operands */ { + struct /* AllocADT Operands */ { /*! \brief The datatype's constructor tag. */ Index constructor_tag; /*! \brief The number of fields to store in the datatype. */ @@ -294,7 +294,7 @@ struct Instruction { * \param dst The register name of the destination. * \return The allocate instruction tensor. */ - static Instruction AllocDatatype(Index tag, Index num_fields, const std::vector& fields, + static Instruction AllocADT(Index tag, Index num_fields, const std::vector& fields, RegName dst); /*! \brief Construct an allocate closure instruction. * \param func_index The index of the function table. diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 942c93b866f4..e190e3f1eb41 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -31,7 +31,7 @@ from .interpreter import Executor Tensor = _obj.Tensor -Datatype = _obj.Datatype +ADT = _obj.ADT def _convert(arg, cargs): if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): diff --git a/python/tvm/relay/backend/vmobj.py b/python/tvm/relay/backend/vmobj.py index 939b122bf510..f3fdb763209d 100644 --- a/python/tvm/relay/backend/vmobj.py +++ b/python/tvm/relay/backend/vmobj.py @@ -61,14 +61,14 @@ def asnumpy(self): return self.data.asnumpy() -@register_object("vm.Datatype") -class Datatype(Object): - """Datatype object. +@register_object("vm.ADT") +class ADT(Object): + """Algebatic data type(ADT) object. Parameters ---------- tag : int - The tag of datatype. + The tag of ADT. fields : list[Object] or tuple[Object] The source tuple. @@ -77,22 +77,22 @@ def __init__(self, tag, fields): for f in fields: assert isinstance(f, Object) self.__init_handle_by_constructor__( - _vmobj.Datatype, tag, *fields) + _vmobj.ADT, tag, *fields) @property def tag(self): - return _vmobj.GetDatatypeTag(self) + return _vmobj.GetADTTag(self) def __getitem__(self, idx): return getitem_helper( - self, _vmobj.GetDatatypeFields, len(self), idx) + self, _vmobj.GetADTFields, len(self), idx) def __len__(self): - return _vmobj.GetDatatypeNumberOfFields(self) + return _vmobj.GetADTNumberOfFields(self) def tuple_object(fields): - """Create a datatype object from source tuple. + """Create a ADT object from source tuple. Parameters ---------- @@ -101,7 +101,7 @@ def tuple_object(fields): Returns ------- - ret : Datatype + ret : ADT The created object. 
""" for f in fields: diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index f295ccd7a555..fab01bd40423 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -239,7 +239,7 @@ class VMFunctionCompiler : ExprFunctor { DLOG(INFO) << "VMCompiler::Emit: instr=" << instr; CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; switch (instr.op) { - case Opcode::AllocDatatype: + case Opcode::AllocADT: case Opcode::AllocTensor: case Opcode::AllocTensorReg: case Opcode::GetField: @@ -287,7 +287,7 @@ class VMFunctionCompiler : ExprFunctor { } // TODO(@jroesch): use correct tag - Emit(Instruction::AllocDatatype( + Emit(Instruction::AllocADT( 0, tuple->fields.size(), fields_registers, @@ -626,7 +626,7 @@ class VMFunctionCompiler : ExprFunctor { for (size_t i = arity - return_count; i < arity; ++i) { fields_registers.push_back(unpacked_arg_regs[i]); } - Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister())); + Emit(Instruction::AllocADT(0, return_count, fields_registers, NewRegister())); } } @@ -659,7 +659,7 @@ class VMFunctionCompiler : ExprFunctor { } } else if (auto constructor_node = op.as()) { auto constructor = GetRef(constructor_node); - Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers, + Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers, NewRegister())); } else if (auto var_node = op.as()) { VisitExpr(GetRef(var_node)); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index f85283094e91..32032b5a1e64 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -315,7 +315,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(instr.dst); break; } - case Opcode::AllocDatatype: { + case Opcode::AllocADT: { // Number of fields = 3 + instr.num_fields fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); @@ -551,7 +551,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { return Instruction::AllocTensorReg(shape_register, dtype, dst); } - case Opcode::AllocDatatype: { + case Opcode::AllocADT: { // Number of fields = 3 + instr.num_fields DCHECK_GE(instr.fields.size(), 3U); DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); @@ -561,7 +561,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { RegName dst = instr.fields[2]; std::vector fields = ExtractFields(instr.fields, 3, num_fields); - return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); + return Instruction::AllocADT(constructor_tag, num_fields, fields, dst); } case Opcode::AllocClosure: { // Number of fields = 3 + instr.num_freevar diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc index c20a1ce9de27..12edf511db66 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/vm/object.cc @@ -39,15 +39,15 @@ Tensor::Tensor(NDArray data) { data_ = std::move(ptr); } -Datatype::Datatype(size_t tag, std::vector fields) { - auto ptr = make_object(); +ADT::ADT(size_t tag, std::vector fields) { + auto ptr = make_object(); ptr->tag = tag; ptr->fields = std::move(fields); data_ = std::move(ptr); } -Datatype Datatype::Tuple(std::vector fields) { - return Datatype(0, fields); +ADT ADT::Tuple(std::vector fields) { + return ADT(0, fields); } Closure::Closure(size_t func_index, std::vector free_vars) { @@ -66,28 +66,28 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData") *rv = cell->data; }); 
-TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag") +TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; - const auto* cell = obj.as(); + const auto* cell = obj.as(); CHECK(cell != nullptr); *rv = static_cast(cell->tag); }); -TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields") +TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; - const auto* cell = obj.as(); + const auto* cell = obj.as(); CHECK(cell != nullptr); *rv = static_cast(cell->fields.size()); }); -TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields") +TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; - const auto* cell = obj.as(); + const auto* cell = obj.as(); CHECK(cell != nullptr); CHECK_LT(idx, cell->fields.size()); *rv = cell->fields[idx]; @@ -104,10 +104,10 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple") for (auto i = 0; i < args.size(); ++i) { fields.push_back(args[i]); } - *rv = Datatype::Tuple(fields); + *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("_vmobj.Datatype") +TVM_REGISTER_GLOBAL("_vmobj.ADT") .set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); @@ -115,11 +115,11 @@ TVM_REGISTER_GLOBAL("_vmobj.Datatype") for (int i = 1; i < args.size(); i++) { fields.push_back(args[i]); } - *rv = Datatype(tag, fields); + *rv = ADT(tag, fields); }); TVM_REGISTER_OBJECT_TYPE(TensorObj); -TVM_REGISTER_OBJECT_TYPE(DatatypeObj); +TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 78b74768b930..fd5ff64d5812 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -74,7 +74,7 @@ Instruction::Instruction(const Instruction& instr) { this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return; - case Opcode::AllocDatatype: + case Opcode::AllocADT: this->constructor_tag = instr.constructor_tag; this->num_fields = instr.num_fields; this->datatype_fields = Duplicate(instr.datatype_fields, instr.num_fields); @@ -159,7 +159,7 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return *this; - case Opcode::AllocDatatype: + case Opcode::AllocADT: this->constructor_tag = instr.constructor_tag; this->num_fields = instr.num_fields; FreeIf(this->datatype_fields); @@ -229,7 +229,7 @@ Instruction::~Instruction() { case Opcode::AllocTensor: delete this->alloc_tensor.shape; return; - case Opcode::AllocDatatype: + case Opcode::AllocADT: delete this->datatype_fields; return; case Opcode::AllocClosure: @@ -301,10 +301,10 @@ Instruction Instruction::AllocTensorReg(RegName shape_register, DLDataType dtype return instr; } -Instruction Instruction::AllocDatatype(Index tag, Index num_fields, +Instruction Instruction::AllocADT(Index tag, Index num_fields, const std::vector& datatype_fields, Index dst) { Instruction instr; - instr.op = Opcode::AllocDatatype; + instr.op = Opcode::AllocADT; instr.dst = dst; instr.constructor_tag = tag; instr.num_fields = num_fields; @@ -485,7 +485,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; } - case Opcode::AllocDatatype: { + case 
Opcode::AllocADT: { os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$" << StrJoin(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"; break; @@ -691,7 +691,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, const std::vector& args) { size_t arity = 0; for (Index i = 0; i < arg_count; i++) { - if (const auto* obj = args[i].as()) { + if (const auto* obj = args[i].as()) { arity += obj->fields.size(); } else { ++arity; @@ -703,7 +703,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, runtime::TVMArgsSetter setter(values.data(), codes.data()); int idx = 0; for (Index i = 0; i < arg_count; i++) { - if (const auto* dt_cell = args[i].as()) { + if (const auto* dt_cell = args[i].as()) { for (auto obj : dt_cell->fields) { const auto* tensor = obj.as(); CHECK(tensor != nullptr); @@ -849,7 +849,7 @@ void VirtualMachine::RunLoop() { } case Opcode::GetField: { auto object = ReadRegister(instr.object); - const auto* tuple = object.as(); + const auto* tuple = object.as(); CHECK(tuple != nullptr) << "Object is not data type object, register " << instr.object << ", Object tag " << object->type_index(); @@ -860,7 +860,7 @@ void VirtualMachine::RunLoop() { } case Opcode::GetTag: { auto object = ReadRegister(instr.get_tag.object); - const auto* data = object.as(); + const auto* data = object.as(); CHECK(data != nullptr) << "Object is not data type object, register " << instr.get_tag.object << ", Object tag " @@ -925,12 +925,12 @@ void VirtualMachine::RunLoop() { pc++; goto main_loop; } - case Opcode::AllocDatatype: { + case Opcode::AllocADT: { std::vector fields; for (Index i = 0; i < instr.num_fields; ++i) { fields.push_back(ReadRegister(instr.datatype_fields[i])); } - ObjectRef obj = Datatype(instr.constructor_tag, fields); + ObjectRef obj = ADT(instr.constructor_tag, fields); WriteRegister(instr.dst, obj); pc++; goto main_loop; diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 3321d71a2cb8..420bcb72a4a2 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -49,7 +49,7 @@ def convert_to_list(x): def vmobj_to_list(o): if isinstance(o, tvm.relay.backend.vmobj.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vmobj.Datatype): + elif isinstance(o, tvm.relay.backend.vmobj.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 390d3cd9f3c4..32bc22f9031a 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -742,7 +742,7 @@ def vmobj_to_list(o): return [o.asnumpy().tolist()] elif isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.asnumpy()] - elif isinstance(o, tvm.relay.backend.vmobj.Datatype): + elif isinstance(o, tvm.relay.backend.vmobj.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 1b40f894db08..a3b251c38e00 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -63,7 +63,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): def vmobj_to_list(o): if isinstance(o, tvm.relay.backend.vm.Tensor): return [o.asnumpy().tolist()] - elif isinstance(o, tvm.relay.backend.vm.Datatype): + elif isinstance(o, tvm.relay.backend.vm.ADT): result = [] for f in o: result.extend(vmobj_to_list(f)) diff --git 
a/tests/python/relay/test_vm_object.py b/tests/python/relay/test_vm_object.py index ad21fff8e185..12d263d1125b 100644 --- a/tests/python/relay/test_vm_object.py +++ b/tests/python/relay/test_vm_object.py @@ -28,13 +28,13 @@ def test_tensor(): assert isinstance(x.data, tvm.nd.NDArray) -def test_datatype(): +def test_adt(): arr = tvm.nd.array([1,2,3]) x = vm.Tensor(arr) - y = vm.Datatype(0, [x, x]) + y = vm.ADT(0, [x, x]) assert len(y) == 2 - assert isinstance(y, vm.Datatype) + assert isinstance(y, vm.ADT) y[0:1][-1].data == x.data assert y.tag == 0 assert isinstance(x.data, tvm.nd.NDArray) @@ -43,4 +43,4 @@ def test_datatype(): if __name__ == "__main__": test_tensor() - test_datatype() + test_adt() From 97ea31c8f5d460c5ec401b146cfb16481bef6641 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sun, 20 Oct 2019 10:40:10 -0700 Subject: [PATCH 20/62] [Runtime] Enable option to use OpenMP thread pool (#4089) --- CMakeLists.txt | 4 ++++ cmake/config.cmake | 4 ++++ cmake/modules/OpenMP.cmake | 48 ++++++++++++++++++++++++++++++++++++++ src/runtime/thread_pool.cc | 26 +++++++++++++++++++++ 4 files changed, 82 insertions(+) create mode 100644 cmake/modules/OpenMP.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 10730ac718b4..1b7d5efaf0e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF) tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF) +tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) tvm_option(USE_SGX "Build with SGX" OFF) tvm_option(USE_RTTI "Build with RTTI" ON) @@ -154,6 +155,7 @@ list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_IR_SRCS}) list(APPEND COMPILER_SRCS ${RELAY_QNN_SRCS}) + if(USE_VM_PROFILER) message(STATUS "Build compiler with Relay VM profiler support...") file(GLOB BACKEND_VM_PROFILER_SRCS src/relay/backend/vm/profiler/*.cc) @@ -233,6 +235,7 @@ include(cmake/modules/VTA.cmake) include(cmake/modules/CUDA.cmake) include(cmake/modules/OpenCL.cmake) include(cmake/modules/OpenGL.cmake) +include(cmake/modules/OpenMP.cmake) include(cmake/modules/Vulkan.cmake) include(cmake/modules/Metal.cmake) include(cmake/modules/ROCM.cmake) @@ -264,6 +267,7 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) + if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG") diff --git a/cmake/config.cmake b/cmake/config.cmake index d92c2151d9c8..6a55397692a0 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -115,6 +115,10 @@ set(USE_BLAS none) # set(USE_MKL_PATH ) if using `pip install mkl` set(USE_MKL_PATH none) +# Whether use OpenMP thread pool, choices: gnu, intel +# Note: "gnu" uses gomp library, "intel" uses iomp5 library +set(USE_OPENMP none) + # Whether use contrib.random in runtime set(USE_RANDOM OFF) diff --git a/cmake/modules/OpenMP.cmake b/cmake/modules/OpenMP.cmake new file mode 100644 index 000000000000..5dd9be508342 --- /dev/null +++ b/cmake/modules/OpenMP.cmake @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# OpenMP Module +if(USE_OPENMP STREQUAL "gnu") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenMP_CXX_LIBRARIES}) + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=1) + message(STATUS "Build with OpenMP ${OpenMP_CXX_LIBRARIES}") + else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) + message(WARNING "OpenMP cannot be found, use TVM threadpool instead.") + endif() +elseif(USE_OPENMP STREQUAL "intel") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + if (MSVC) + find_library(OMP_LIBRARY NAMES libiomp5md) + else() + find_library(OMP_LIBRARY NAMES iomp5) + endif() + list(APPEND TVM_RUNTIME_LINKER_LIBS ${OMP_LIBRARY}) + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=1) + message(STATUS "Build with OpenMP " ${OMP_LIBRARY}) + else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) + message(WARNING "OpenMP cannot be found, use TVM threadpool instead.") + endif() +else() + add_definitions(-DTVM_THREADPOOL_USE_OPENMP=0) +endif() diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 2e101364db2a..e9e6d03243e3 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -29,6 +29,9 @@ #include #include #include +#if TVM_THREADPOOL_USE_OPENMP +#include +#endif #include #include #include @@ -394,12 +397,34 @@ int TVMBackendParallelLaunch( FTVMParallelLambda flambda, void* cdata, int num_task) { +#if !TVM_THREADPOOL_USE_OPENMP int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch( flambda, cdata, num_task, 1); return res; +#else + int num_workers = tvm::runtime::threading::MaxConcurrency(); + if (num_task == 0) num_task = num_workers; + omp_set_num_threads(num_workers); + #pragma omp parallel num_threads(num_workers) + { + TVMParallelGroupEnv env; + env.num_task = num_task; + std::atomic* sync_counter = new std::atomic[num_task * tvm::runtime::kSyncStride]; + for (int i = 0; i < num_task; ++i) { + sync_counter[i * tvm::runtime::kSyncStride].store( + 0, std::memory_order_relaxed); + } + env.sync_handle = sync_counter; + (*flambda)(omp_get_thread_num(), &env, cdata); + } + return 0; +#endif } int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { +#if TVM_THREADPOOL_USE_OPENMP + #pragma omp barrier +#else using tvm::runtime::kSyncStride; int num_task = penv->num_task; std::atomic* sync_counter = @@ -415,5 +440,6 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { } } std::atomic_thread_fence(std::memory_order_acquire); +#endif return 0; } From 7895adb243ea6fbb1b434904ff3925c3a84f5693 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 20 Oct 2019 18:30:41 -0700 Subject: [PATCH 21/62] [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. 
(#4161) * [REFACTOR][NODE][RUNTIME] Move Node to the new Object protocol. This PR removes the original node system, and make node as a subclass of Object. This is a major refactor towards a better unified runtime object system. List of changes in the refactor: - We now hide data_ field, use Downcast explicitly to get a sub-class object. - Removed the node system FFI in python. - Removed the node C API, instead use PackedFunc for list and get attrs. - Change relay::Op::set_attr_type_key(attr_key_name) to relay::Op::set_attr_type(). - This change was necessary because of the new Object registration mechanism. - Subsequent changes to the op registrations - The change revealed a few previous problems that is now fixed. - Patched up a few missing node type registration. - Now we will raise an error if we register object that is not registered. - The original node.h and container.h are kept in the same location. - Calling convention: kObjectHandle now equals the old kNodeHandle, kNodeHandle is removed. - IRFunctor now dispatches on ObjectRef. - Update to the new type checking API: is_type, derived_from are replaced by IsInstance. - Removed .hash member function, instead use C++ convention hasher functors. * Address review comments --- golang/src/value.go | 4 +- include/tvm/api_registry.h | 8 +- include/tvm/arithmetic.h | 4 +- include/tvm/attrs.h | 24 +- include/tvm/base.h | 16 +- include/tvm/buffer.h | 4 +- include/tvm/build_module.h | 16 +- include/tvm/c_dsl_api.h | 98 ------ include/tvm/channel.h | 4 +- include/tvm/data_layout.h | 8 +- include/tvm/expr.h | 27 +- include/tvm/ir.h | 6 +- include/tvm/ir_functor_ext.h | 18 +- include/tvm/ir_mutator.h | 4 +- include/tvm/ir_visitor.h | 4 +- include/tvm/lowered_func.h | 9 +- include/tvm/node/container.h | 219 ++++++------- include/tvm/node/ir_functor.h | 50 +-- include/tvm/node/memory.h | 77 ----- include/tvm/node/node.h | 300 +++--------------- include/tvm/operation.h | 2 +- include/tvm/packed_func_ext.h | 183 +++++------ include/tvm/relay/adt.h | 2 +- include/tvm/relay/base.h | 10 +- include/tvm/relay/expr.h | 11 +- include/tvm/relay/expr_functor.h | 8 +- include/tvm/relay/interpreter.h | 4 +- include/tvm/relay/module.h | 6 +- include/tvm/relay/op.h | 19 +- include/tvm/relay/pattern_functor.h | 8 +- include/tvm/relay/transform.h | 8 +- include/tvm/relay/type.h | 7 +- include/tvm/runtime/c_runtime_api.h | 3 +- include/tvm/runtime/memory.h | 2 +- include/tvm/runtime/node_base.h | 259 --------------- include/tvm/runtime/object.h | 248 +++++++++++++-- include/tvm/runtime/packed_func.h | 65 ++-- include/tvm/schedule.h | 20 +- include/tvm/tensor.h | 16 +- include/tvm/tensor_intrin.h | 4 +- .../main/native/ml_dmlc_tvm_native_c_api.cc | 6 +- nnvm/include/nnvm/compiler/util.h | 6 +- nnvm/src/compiler/compile_engine.cc | 7 +- nnvm/src/compiler/compile_engine.h | 6 +- nnvm/src/compiler/graph_runtime.h | 5 +- nnvm/src/compiler/packed_func_ext.cc | 6 +- nnvm/src/top/tensor/transform.cc | 6 +- python/tvm/_ffi/_ctypes/function.py | 17 +- python/tvm/_ffi/_ctypes/node.py | 102 ------ python/tvm/_ffi/_ctypes/object.py | 13 +- python/tvm/_ffi/_cython/base.pxi | 17 +- python/tvm/_ffi/_cython/core.pyx | 2 +- python/tvm/_ffi/_cython/function.pxi | 23 +- python/tvm/_ffi/_cython/node.pxi | 110 ------- python/tvm/_ffi/_cython/object.pxi | 12 +- python/tvm/_ffi/node.py | 59 +--- python/tvm/_ffi/object.py | 23 +- python/tvm/_ffi/runtime_ctypes.py | 3 +- python/tvm/error.py | 1 + python/tvm/relay/backend/profiler_vm.py | 4 + python/tvm/relay/debug.py | 4 - 
rust/common/src/packed_func.rs | 6 +- rust/frontend/src/function.rs | 2 +- src/api/api_arith.cc | 3 +- src/api/api_base.cc | 11 +- src/api/api_codegen.cc | 6 +- src/api/api_ir.cc | 1 - src/api/api_lang.cc | 93 +++--- src/api/api_pass.cc | 8 +- src/api/api_schedule.cc | 5 +- src/api/dsl_api.cc | 134 +++----- src/arithmetic/analyzer.cc | 7 +- src/arithmetic/canonical_simplify.cc | 6 +- src/arithmetic/const_int_bound.cc | 2 +- src/arithmetic/detect_linear_equation.cc | 2 +- src/arithmetic/int_set.cc | 4 +- src/arithmetic/ir_mutator_with_analyzer.cc | 2 +- src/arithmetic/ir_visitor_with_analyzer.h | 2 +- src/arithmetic/modular_set.cc | 2 +- src/codegen/build_module.cc | 24 +- src/codegen/codegen_c.cc | 2 +- src/codegen/llvm/codegen_llvm.cc | 2 +- src/codegen/spirv/codegen_spirv.cc | 2 +- src/contrib/hybrid/codegen_hybrid.cc | 4 +- src/contrib/hybrid/codegen_hybrid.h | 1 - src/lang/attr_functor.h | 80 ++--- src/lang/attrs.cc | 52 +-- src/lang/data_layout.cc | 8 +- src/lang/expr.cc | 4 +- src/lang/ir.cc | 8 +- src/lang/reflection.cc | 105 +++--- src/node/node.cc | 76 ----- src/op/compute_op.cc | 8 +- src/op/hybrid_op.cc | 4 +- src/op/op_util.cc | 2 +- src/op/tensorize.cc | 2 +- src/pass/arg_binder.cc | 2 +- src/pass/coproc_sync.cc | 6 +- src/pass/hoist_if_then_else.cc | 7 +- src/pass/inject_copy_intrin.cc | 10 +- src/pass/inject_double_buffer.cc | 2 +- src/pass/inject_prefetch.cc | 2 +- src/pass/inject_virtual_thread.cc | 5 +- src/pass/ir_mutator.cc | 2 +- src/pass/lift_attr_scope.cc | 6 +- src/pass/lower_thread_allreduce.cc | 6 +- src/pass/lower_warp_memory.cc | 6 +- src/pass/make_api.cc | 4 +- src/pass/narrow_channel_access.cc | 2 +- src/pass/remap_thread_axis.cc | 6 +- src/pass/split_host_device.cc | 10 +- src/pass/split_pipeline.cc | 8 +- src/pass/storage_access.cc | 12 +- src/pass/storage_flatten.cc | 16 +- src/pass/storage_rewrite.cc | 13 +- src/pass/storage_sync.cc | 6 +- src/pass/unroll_loop.cc | 3 +- src/pass/vectorize_loop.cc | 3 +- src/pass/verify_memory.cc | 6 +- src/relay/backend/compile_engine.cc | 8 +- src/relay/backend/compile_engine.h | 18 +- src/relay/backend/graph_runtime_codegen.cc | 6 +- src/relay/ir/alpha_equal.cc | 15 +- src/relay/ir/expr_functor.cc | 7 +- src/relay/ir/hash.cc | 21 +- src/relay/ir/module.cc | 9 +- src/relay/ir/op.cc | 12 +- src/relay/ir/pretty_printer.cc | 18 +- src/relay/ir/type_functor.h | 14 +- src/relay/op/algorithm/argsort.cc | 2 +- src/relay/op/algorithm/topk.cc | 2 +- src/relay/op/debug.cc | 13 +- src/relay/op/image/resize.cc | 6 +- src/relay/op/nn/bitserial.cc | 38 +-- src/relay/op/nn/convolution.cc | 20 +- src/relay/op/nn/nn.cc | 26 +- src/relay/op/nn/pad.cc | 4 +- src/relay/op/nn/pooling.cc | 16 +- src/relay/op/nn/sparse.cc | 4 +- src/relay/op/nn/upsampling.cc | 2 +- src/relay/op/tensor/reduce.cc | 18 +- src/relay/op/tensor/transform.cc | 75 +++-- src/relay/op/tensor/unary.cc | 10 +- src/relay/op/vision/multibox_op.cc | 8 +- src/relay/op/vision/yolo.cc | 2 +- src/relay/pass/alter_op_layout.cc | 15 +- src/relay/pass/device_annotation.cc | 10 +- src/relay/pass/eta_expand.cc | 4 +- src/relay/pass/fold_constant.cc | 2 +- src/relay/pass/fold_scale_axis.cc | 10 +- src/relay/pass/partial_eval.cc | 14 +- src/relay/pass/pass_manager.cc | 7 +- src/relay/pass/quantize/annotate.cc | 4 +- src/relay/pass/quantize/partition.cc | 3 + src/relay/pass/quantize/quantize.cc | 2 +- src/relay/pass/quantize/quantize.h | 8 +- src/relay/pass/quantize/realize.cc | 22 +- src/relay/pass/type_infer.cc | 21 +- src/relay/pass/type_solver.cc | 2 +- 
src/relay/qnn/op/concatenate.cc | 2 +- src/relay/qnn/op/convolution.cc | 2 +- src/relay/qnn/op/dense.cc | 2 +- src/relay/qnn/op/dequantize.cc | 2 +- src/relay/qnn/op/quantize.cc | 2 +- src/relay/qnn/op/requantize.cc | 2 +- src/runtime/c_dsl_api.cc | 91 ------ src/runtime/c_runtime_api.cc | 2 +- src/runtime/dsl_api.h | 59 ---- src/runtime/object.cc | 21 +- src/schedule/graph.cc | 2 +- src/schedule/schedule_dataflow_rewrite.cc | 18 +- src/schedule/schedule_lang.cc | 24 +- src/schedule/schedule_ops.cc | 6 +- tests/cpp/expr_test.cc | 4 +- tests/cpp/ir_functor_test.cc | 2 +- tests/cpp/object_protocol_test.cc | 6 +- tests/cpp/packed_func_test.cc | 2 +- tests/python/unittest/test_lang_schedule.py | 6 +- .../unittest/test_runtime_vm_profiler.py | 2 + topi/include/topi/cuda/pooling.h | 2 +- topi/include/topi/cuda/reduction.h | 2 +- topi/include/topi/detail/constant_utils.h | 15 +- topi/include/topi/generic/extern.h | 2 +- topi/src/topi.cc | 5 +- web/tvm_runtime.js | 8 +- 185 files changed, 1442 insertions(+), 2387 deletions(-) delete mode 100644 include/tvm/c_dsl_api.h delete mode 100644 include/tvm/node/memory.h delete mode 100644 include/tvm/runtime/node_base.h delete mode 100644 python/tvm/_ffi/_ctypes/node.py delete mode 100644 python/tvm/_ffi/_cython/node.pxi delete mode 100644 src/node/node.cc delete mode 100644 src/runtime/c_dsl_api.cc delete mode 100644 src/runtime/dsl_api.h diff --git a/golang/src/value.go b/golang/src/value.go index 576331a8cfa0..5e0f78270eaa 100644 --- a/golang/src/value.go +++ b/golang/src/value.go @@ -44,8 +44,8 @@ var KTVMType = int32(C.kTVMType) var KTVMContext = int32(C.kTVMContext) // KArrayHandle is golang type code for TVM kArrayHandle. var KArrayHandle = int32(C.kArrayHandle) -// KNodeHandle is golang type code for TVM kNodeHandle. -var KNodeHandle = int32(C.kNodeHandle) +// KObjectHandle is golang type code for TVM kObjectHandle. +var KObjectHandle = int32(C.kObjectHandle) // KModuleHandle is gonag type code for TVM kModuleHandle. var KModuleHandle = int32(C.kModuleHandle) // KFuncHandle is gonalg type code for TVM kFuncHandle. diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index e12d841519ca..dbd097293593 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -79,7 +79,7 @@ class EnvFunc : public NodeRef { explicit EnvFunc(NodePtr n) : NodeRef(n) {} /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function. @@ -124,19 +124,19 @@ class TypedEnvFunc : public NodeRef { /*! \brief short hand for this function type */ using TSelf = TypedEnvFunc; TypedEnvFunc() {} - explicit TypedEnvFunc(NodePtr n) : NodeRef(n) {} + explicit TypedEnvFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Assign global function to a TypedEnvFunc * \param other Another global function. * \return reference to self. */ TSelf& operator=(const EnvFunc& other) { - this->node_ = other.node_; + ObjectRef::operator=(other); return *this; } /*! \return The internal global function pointer */ const EnvFuncNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief Invoke the function. diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 8be1c3604813..e81fa0afd254 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -362,7 +362,7 @@ class IntSet : public NodeRef { /*! \brief constructor */ IntSet() {} // constructor from not container. 
- explicit IntSet(NodePtr n) : NodeRef(n) {} + explicit IntSet(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -692,7 +692,7 @@ Array DetectClipBound(const Expr& e, // implementation inline const IntSetNode* IntSet::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace arith } // namespace tvm diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index 3b64d1f961e2..fb8927a75613 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -163,7 +163,7 @@ class AttrsEqual { return lhs == rhs; } // node comparator - TVM_DLL bool operator()(const NodeRef& lhs, const NodeRef& rhs) const; + TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; protected: friend class AttrsEqualHandler; @@ -203,7 +203,7 @@ class AttrsHash { (static_cast(value.bits()) << 8) | (static_cast(value.lanes()) << 16)); } - TVM_DLL size_t operator()(const NodeRef& value) const; + TVM_DLL size_t operator()(const ObjectRef& value) const; private: friend class AttrsHashHandler; @@ -260,7 +260,7 @@ class BaseAttrsNode : public Node { * \return The comparison result. */ TVM_DLL virtual bool ContentEqual( - const Node* other, AttrsEqual equal) const = 0; + const Object* other, AttrsEqual equal) const = 0; /*! * \brief Content aware hash. * \param hasher The hasher to run the hash. @@ -290,7 +290,7 @@ class Attrs : public NodeRef { private: /*! \return the internal attribute node */ const BaseAttrsNode* ptr() const { - return static_cast(node_.get()); + return static_cast(get()); } }; @@ -315,7 +315,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; - bool ContentEqual(const Node* other, AttrsEqual equal) const final; + bool ContentEqual(const Object* other, AttrsEqual equal) const final; size_t ContentHash(AttrsHash hasher) const final; // type info static constexpr const char* _type_key = "DictAttrs"; @@ -369,7 +369,7 @@ class AttrsEqualVisitor { public: bool result_{true}; // constructor - AttrsEqualVisitor(const Node* lhs, const Node* rhs, const AttrsEqual& equal) + AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal) : lhs_(lhs), rhs_(rhs), equal_(equal) { } template @@ -387,8 +387,8 @@ class AttrsEqualVisitor { } private: - const Node* lhs_; - const Node* rhs_; + const Object* lhs_; + const Object* rhs_; const AttrsEqual& equal_; }; @@ -488,7 +488,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect int value, but get " << expr->type_key(); + LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey(); } } } @@ -521,7 +521,7 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } else if (const ir::UIntImm* op = expr.as()) { *ptr = static_cast(op->value); } else { - LOG(FATAL) << "Expect float value, but get " << expr->type_key(); + LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); } } } @@ -827,7 +827,7 @@ class AttrsNode : public BaseAttrsNode { return visitor.fields_; } - bool ContentEqual(const Node* other, AttrsEqual equal) const final { + bool ContentEqual(const Object* other, AttrsEqual equal) const final { DerivedType* pself = self(); if (pself == other) return true; if (other == nullptr) return false; @@ -839,7 +839,7 @@ 
class AttrsNode : public BaseAttrsNode { size_t ContentHash(AttrsHash hasher) const final { ::tvm::detail::AttrsHashVisitor visitor(hasher); - visitor.result_ = std::hash()(this->type_key()); + visitor.result_ = this->GetTypeKeyHash(); self()->__VisitAttrs__(visitor); return visitor.result_; } diff --git a/include/tvm/base.h b/include/tvm/base.h index f358f7f5d447..a42de10abef2 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -47,9 +47,10 @@ using ::tvm::AttrVisitor; */ #define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : BaseTypeName(n) {} \ const NodeName* operator->() const { \ - return static_cast(node_.get()); \ + return static_cast(data_.get()); \ } \ operator bool() const { return this->defined(); } \ using ContainerType = NodeName; @@ -75,12 +76,12 @@ using ::tvm::AttrVisitor; */ #define TVM_DEFINE_NODE_REF_COW(NodeName) \ NodeName* CopyOnWrite() { \ - CHECK(node_ != nullptr); \ - if (!node_.unique()) { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ NodePtr n = make_node(*(operator->())); \ - NodePtr(std::move(n)).swap(node_); \ + ObjectPtr(std::move(n)).swap(data_); \ } \ - return static_cast(node_.get()); \ + return static_cast(data_.get()); \ } /*! \brief Macro to make it easy to define node ref type given node */ @@ -160,7 +161,7 @@ std::string SaveJSON(const NodeRef& node); * * \return The shared_ptr of the Node. */ -NodePtr LoadJSON_(std::string json_str); +ObjectPtr LoadJSON_(std::string json_str); /*! * \brief Load the node from json string. @@ -233,6 +234,7 @@ struct NodeFactoryReg { * \note This is necessary to enable serialization of the Node. */ #define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_OBJECT_TYPE(TypeName); \ static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \ .set_creator([](const std::string&) { return ::tvm::make_node(); }) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 1233e9b0b89b..f18ed9206db3 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -51,7 +51,7 @@ enum BufferType : int { class Buffer : public NodeRef { public: Buffer() {} - explicit Buffer(NodePtr n) : NodeRef(n) {} + explicit Buffer(ObjectPtr n) : NodeRef(n) {} /*! * \brief Return a new buffer that is equivalent with current one * but always add stride field. @@ -171,7 +171,7 @@ class BufferNode : public Node { }; inline const BufferNode* Buffer::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 1d57d82e66c6..c985fbe17546 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -93,7 +93,7 @@ class TargetNode : public Node { class Target : public NodeRef { public: Target() {} - explicit Target(NodePtr n) : NodeRef(n) {} + explicit Target(ObjectPtr n) : NodeRef(n) {} /*! 
* \brief Create a Target given a string * \param target_str the string to parse @@ -110,7 +110,7 @@ class Target : public NodeRef { TVM_DLL static tvm::Target Current(bool allow_not_defined = true); const TargetNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = TargetNode; @@ -256,12 +256,12 @@ class BuildConfigNode : public Node { class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} - explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} + explicit BuildConfig(ObjectPtr n) : NodeRef(n) {} const BuildConfigNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } BuildConfigNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! * \brief Construct a BuildConfig containing a empty build config node. @@ -371,7 +371,7 @@ class GenericFuncNode; class GenericFunc : public NodeRef { public: GenericFunc() {} - explicit GenericFunc(NodePtr n) : NodeRef(n) {} + explicit GenericFunc(ObjectPtr n) : NodeRef(n) {} /*! * \brief Set the default function implementaiton. @@ -478,10 +478,10 @@ class GenericFuncNode : public Node { }; inline GenericFuncNode* GenericFunc::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } -#define TVM_GENERIC_FUNC_REG_VAR_DEF \ +#define TVM_GENERIC_FUNC_REG_VAR_DEF \ static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM /*! diff --git a/include/tvm/c_dsl_api.h b/include/tvm/c_dsl_api.h deleted file mode 100644 index bbbb84926e8e..000000000000 --- a/include/tvm/c_dsl_api.h +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/c_dsl_api.h - * - * \brief TVM DSL Node C API, used to interact to DSL compilation. - * - * These are only a few functions needed for DSL construction time. - * These function are only available when link libtvm. - * If only TVM runtime is linked, calling these function will trigger error. - * - * \note Most API functions are registerd as PackedFunc and - * can be grabbed via TVMFuncGetGlobal - */ -#ifndef TVM_C_DSL_API_H_ -#define TVM_C_DSL_API_H_ - -#include "runtime/c_runtime_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*! \brief handle to node */ -typedef void* NodeHandle; - -/*! - * \brief free the node handle - * \param handle The node handle to be freed. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeFree(NodeHandle handle); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeTypeKey2Index(const char* type_key, - int* out_index); - -/*! 
- * \brief Get runtime type index of the node. - * \param handle the node handle. - * \param out_index the corresponding type index. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index); - -/*! - * \brief get attributes given key - * \param handle The node handle - * \param key The attribute name - * \param out_value The attribute value - * \param out_type_code The type code of the attribute. - * \param out_success Whether get is successful. - * \return 0 when success, -1 when failure happens - * \note API calls always exchanges with type bits=64, lanes=1 - */ -TVM_DLL int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* out_type_code, - int* out_success); - -/*! - * \brief get attributes names in the node. - * \param handle The node handle - * \param out_size The number of functions - * \param out_array The array of function names. - * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array); -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif -#endif // TVM_C_DSL_API_H_ diff --git a/include/tvm/channel.h b/include/tvm/channel.h index 143d4295f3e3..346291a6b06a 100644 --- a/include/tvm/channel.h +++ b/include/tvm/channel.h @@ -35,7 +35,7 @@ class Channel : public NodeRef { public: /*! \brief default constructor */ Channel() {} - explicit Channel(NodePtr n) : NodeRef(n) {} + explicit Channel(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -67,7 +67,7 @@ struct ChannelNode : public Node { // Inline implementations inline const ChannelNode* Channel::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_CHANNEL_H_ diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index c2ae572de818..ad3da6b347af 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -127,7 +127,7 @@ class LayoutNode : public Node { */ class Layout : public NodeRef { public: - explicit Layout(NodePtr n) : NodeRef(n) {} + explicit Layout(ObjectPtr n) : NodeRef(n) {} /*! \brief default constructor */ Layout() = default; @@ -152,7 +152,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ const LayoutNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! @@ -160,7 +160,7 @@ class Layout : public NodeRef { * \return the pointer to the internal node container */ LayoutNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } /*! @@ -369,7 +369,7 @@ class BijectiveLayout : public NodeRef { }; inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 201a2b485aa6..d884a4d61748 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -49,7 +49,7 @@ class ExprNode : public Node { class Expr : public NodeRef { public: Expr() {} - explicit Expr(NodePtr ptr) : NodeRef(ptr) {} + explicit Expr(ObjectPtr ptr) : NodeRef(ptr) {} /*! * \brief construct from integer. * \param value The value to be constructed. @@ -122,7 +122,7 @@ class Variable : public ExprNode { /*! 
\brief a named variable in TVM */ class Var : public Expr { public: - explicit Var(NodePtr n) : Expr(n) {} + explicit Var(ObjectPtr n) : Expr(n) {} TVM_DLL explicit Var(std::string name_hint = "v", Type t = Int(32)); /*! @@ -145,7 +145,7 @@ class Var : public Expr { * \return the corresponding Variable. */ const Variable* get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = Variable; @@ -187,7 +187,7 @@ class Integer : public Expr { /*! * \brief constructor from node. */ - explicit Integer(NodePtr node) : Expr(node) {} + explicit Integer(ObjectPtr node) : Expr(node) {} /*! * \brief Construct integer from int value. */ @@ -197,7 +197,7 @@ class Integer : public Expr { * \param other another expression. */ Integer& operator=(const Integer& other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -205,13 +205,13 @@ class Integer : public Expr { * \return the content of the integer. */ const IntImm* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } /*! * \brief convert to int64_t */ operator int64_t() const { - CHECK(node_ != nullptr) + CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } @@ -346,7 +346,7 @@ class IterVar : public NodeRef { // construct a new iter var without a domain IterVar() {} // construct from shared ptr. - explicit IterVar(NodePtr n) : NodeRef(n) {} + explicit IterVar(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -423,7 +423,7 @@ class IterVarNode : public Node { // inline implementations inline const IterVarNode* IterVar::operator->() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline IterVar::operator Expr() const { @@ -481,11 +481,11 @@ class IRPrinter { : stream(stream) {} /*! \brief The node to be printed. */ - TVM_DLL void Print(const NodeRef& node); + TVM_DLL void Print(const ObjectRef& node); /*! \brief Print indent to the stream */ TVM_DLL void PrintIndent(); // Allow registration to be printer. - using FType = IRFunctor; + using FType = IRFunctor; TVM_DLL static FType& vtable(); }; @@ -498,10 +498,7 @@ inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT namespace std { template <> -struct hash<::tvm::IterVar> { - std::size_t operator()(const ::tvm::IterVar& k) const { - return k.hash(); - } +struct hash<::tvm::IterVar> : public ::tvm::NodeHash { }; } #endif // TVM_EXPR_H_ diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 079f05f5a7f2..b90804983cfb 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -664,10 +664,10 @@ class CommReducerNode : public Node { }; inline const CommReducerNode* CommReducer::get() const { - return static_cast(node_.get()); + return static_cast(data_.get()); } inline const CommReducerNode* CommReducer::operator->() const { - return static_cast(node_.get()); + return get(); } /*! 
\brief Reduction operator operator */ @@ -1576,7 +1576,7 @@ namespace std { template <> struct hash<::tvm::ir::TensorKey> { std::size_t operator()(const ::tvm::ir::TensorKey& k) const { - size_t lhs = k.f.hash(); + size_t lhs = ::tvm::NodeHash()(k.f); size_t rhs = static_cast(k.value_index); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index a7d91eacf851..54a5eff6846b 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -84,19 +84,19 @@ class StmtFunctor; } #define STMT_FUNCTOR_DEFAULT { \ return VisitStmtDefault_(op, std::forward(args)...); \ -} + } #define IR_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), \ std::forward(args)...); \ }); \ #define IR_STMT_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitStmt_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStmt_(static_cast(n.get()), \ std::forward(args)...); \ }); \ @@ -104,7 +104,7 @@ template class ExprFunctor { private: using TSelf = ExprFunctor; - using FType = IRFunctor; + using FType = IRFunctor; public: /*! \brief the result type of this functor */ @@ -164,7 +164,7 @@ class ExprFunctor { virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args ...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -213,7 +213,7 @@ template class StmtFunctor { private: using TSelf = StmtFunctor; - using FType = IRFunctor; + using FType = IRFunctor; public: /*! \brief the result type of this functor */ @@ -255,7 +255,7 @@ class StmtFunctor { virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Node* op, Args ...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index b82a19d4689c..c910a48620c8 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -65,9 +65,9 @@ class TVM_DLL IRMutator { /*! \brief destructor */ virtual ~IRMutator() {} /*! \brief functor type of expr mutation */ - using FMutateExpr = IRFunctor; + using FMutateExpr = IRFunctor; /*! \brief functor type of stmt mutation */ - using FMutateStmt = IRFunctor; + using FMutateStmt = IRFunctor; /*! \return internal vtable of expr */ static FMutateExpr& vtable_expr(); // NOLINT(*) /*! \return internal stmt of expr */ diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index f20b91368587..bebf94585ed6 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -49,7 +49,7 @@ namespace ir { * // The use case is to count number of Variables in the ir tree. 
* class MyCounter : public IRVisitor { * public: - * int Count(const NodeRef& n) { + * int Count(const ObjectRef& n) { * ret_ = 0; * this->Visit(n); * return ret_; @@ -94,7 +94,7 @@ class TVM_DLL IRVisitor { /*! \brief destructor */ virtual ~IRVisitor() {} /*! \brief functor type of visitor */ - using FVisit = IRFunctor; + using FVisit = IRFunctor; /*! \return internal vtable*/ static FVisit& vtable(); // overloadable visit function. diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index 4da93b80c2ab..e2147d036587 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -44,7 +44,7 @@ class LoweredFuncNode; class LoweredFunc : public ir::FunctionRef { public: LoweredFunc() {} - explicit LoweredFunc(NodePtr n) : FunctionRef(n) {} + explicit LoweredFunc(ObjectPtr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -136,17 +136,14 @@ class LoweredFuncNode : public ir::FunctionBaseNode { // Implementations of inline functions inline const LoweredFuncNode* LoweredFunc::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm namespace std { template <> -struct hash<::tvm::LoweredFunc> { - std::size_t operator()(const ::tvm::LoweredFunc& k) const { - return k.hash(); - } +struct hash<::tvm::LoweredFunc> : public tvm::NodeHash { }; } diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index c2c639e374f5..2e1a978f4806 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -38,14 +38,14 @@ namespace tvm { class ArrayNode : public Node { public: /*! \brief the data content */ - std::vector > data; + std::vector data; void VisitAttrs(AttrVisitor* visitor) final { // Visitor to array have no effect. } static constexpr const char* _type_key = "Array"; - TVM_DECLARE_NODE_TYPE_INFO(ArrayNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Node); }; /*! \brief map node content */ @@ -54,32 +54,17 @@ class MapNode : public Node { void VisitAttrs(AttrVisitor* visitor) final { // Visitor to map have no effect. } - // hash function - struct Hash { - size_t operator()(const NodePtr& n) const { - return std::hash()(n.get()); - } - }; - // comparator - struct Equal { - bool operator()( - const NodePtr& a, - const NodePtr& b) const { - return a.get() == b.get(); - } - }; - /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map< - NodePtr, - NodePtr, - Hash, Equal>; + ObjectRef, + ObjectRef, + ObjectHash, ObjectEqual>; /*! \brief the data content */ ContainerType data; static constexpr const char* _type_key = "Map"; - TVM_DECLARE_NODE_TYPE_INFO(MapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Node); }; @@ -90,15 +75,13 @@ class StrMapNode : public Node { // Visitor to map have no effect. } /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map< - std::string, - NodePtr >; + using ContainerType = std::unordered_map; /*! \brief the data content */ ContainerType data; static constexpr const char* _type_key = "StrMap"; - TVM_DECLARE_NODE_TYPE_INFO(StrMapNode, Node); + TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Node); }; /*! 
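The container nodes above now store plain ObjectRef (with ObjectHash/ObjectEqual keying MapNode), and the Array/Map reference classes in the hunks that follow are updated to match: they push the reference itself instead of value.node_ and downcast on read. A minimal usage sketch under that refactored layout; the Example function and its variable names are illustrative only, not part of this patch:

    #include <tvm/expr.h>
    #include <tvm/node/container.h>

    void Example() {
      using tvm::Array;
      using tvm::Expr;
      using tvm::Map;
      using tvm::Var;

      Var x("x"), y("y");
      Array<Expr> arr{x, y};     // ArrayNode::data is now std::vector<ObjectRef>
      arr.push_back(Expr(42));   // stores the reference itself, not expr.node_
      Expr first = arr[0];       // operator[] does DowncastNoCheck<Expr>

      Map<Var, Expr> binding;    // MapNode keyed by ObjectRef via ObjectHash/ObjectEqual
      binding.Set(x, Expr(1));   // copy-on-write, then data[key] = value
      if (binding.count(y) == 0) {
        binding.Set(y, Expr(2));
      }
      Expr bound = binding[x];   // DowncastNoCheck<Expr> on read
      (void)first;
      (void)bound;
    }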
@@ -111,9 +94,9 @@ template::difference_type; - using value_type = typename std::iterator_traits::value_type; - using pointer = typename std::iterator_traits::pointer; - using reference = typename std::iterator_traits::reference; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) using iterator_category = typename std::iterator_traits::iterator_category; explicit IterAdapter(TIter iter) : iter_(iter) {} @@ -138,7 +121,7 @@ class IterAdapter { inline bool operator!=(IterAdapter other) const { return !(*this == other); } - inline const typename Converter::ResultType operator*() const { + inline const value_type operator*() const { return Converter::convert(*iter_); } @@ -162,26 +145,27 @@ class Array : public NodeRef { * \brief default constructor */ Array() { - node_ = make_node(); + data_ = make_node(); } /*! * \brief move constructor * \param other source */ Array(Array && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Array(const Array &other) : NodeRef(other.node_) { // NOLINT(*) + Array(const Array &other) { // NOLINT(*) + data_ = std::move(other.data_); } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Array(NodePtr n) : NodeRef(n) {} + explicit Array(ObjectPtr n) : NodeRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -214,9 +198,9 @@ class Array : public NodeRef { explicit Array(size_t n, const T& val) { auto tmp_node = make_node(); for (size_t i = 0; i < n; ++i) { - tmp_node->data.push_back(val.node_); + tmp_node->data.push_back(val); } - node_ = std::move(tmp_node); + data_ = std::move(tmp_node); } /*! * \brief move assign operator @@ -224,7 +208,7 @@ class Array : public NodeRef { * \return reference to self. */ Array& operator=(Array && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } /*! @@ -233,7 +217,7 @@ class Array : public NodeRef { * \return reference to self. */ Array& operator=(const Array & other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -246,9 +230,9 @@ class Array : public NodeRef { void assign(IterType begin, IterType end) { auto n = make_node(); for (IterType it = begin; it != end; ++it) { - n->data.push_back((*it).node_); + n->data.push_back(T(*it)); } - node_ = std::move(n); + data_ = std::move(n); } /*! * \brief Read i-th element from array. @@ -256,12 +240,13 @@ class Array : public NodeRef { * \return the i-th element. */ inline const T operator[](size_t i) const { - return T(static_cast(node_.get())->data[i]); + return DowncastNoCheck( + static_cast(data_.get())->data[i]); } /*! \return The size of the array */ inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } /*! 
* \brief copy on write semantics @@ -272,12 +257,12 @@ class Array : public NodeRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline ArrayNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! * \brief push a new item to the back of the list @@ -285,7 +270,7 @@ class Array : public NodeRef { */ inline void push_back(const T& item) { ArrayNode* n = this->CopyOnWrite(); - n->data.push_back(item.node_); + n->data.push_back(item); } /*! * \brief set i-th element of the array. @@ -294,7 +279,7 @@ class Array : public NodeRef { */ inline void Set(size_t i, const T& value) { ArrayNode* n = this->CopyOnWrite(); - n->data[i] = value.node_; + n->data[i] = value; } /*! \return whether array is empty */ inline bool empty() const { @@ -303,34 +288,34 @@ class Array : public NodeRef { /*! \brief specify container node */ using ContainerType = ArrayNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = T; - static inline T convert(const NodePtr& n) { - return T(n); + static inline T convert(const ObjectRef& n) { + return DowncastNoCheck(n); } }; - using iterator = IterAdapter >::const_iterator>; + using iterator = IterAdapter::const_iterator>; using reverse_iterator = IterAdapter< - Ptr2NodeRef, - std::vector >::const_reverse_iterator>; + ValueConverter, + std::vector::const_reverse_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! \return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return rbegin iterator */ inline reverse_iterator rbegin() const { - return reverse_iterator(static_cast(node_.get())->data.rbegin()); + return reverse_iterator(static_cast(data_.get())->data.rbegin()); } /*! \return rend iterator */ inline reverse_iterator rend() const { - return reverse_iterator(static_cast(node_.get())->data.rend()); + return reverse_iterator(static_cast(data_.get())->data.rend()); } }; @@ -355,26 +340,26 @@ class Map : public NodeRef { * \brief default constructor */ Map() { - node_ = make_node(); + data_ = make_node(); } /*! * \brief move constructor * \param other source */ Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) } /*! * \brief constructor from pointer * \param n the container pointer */ - explicit Map(NodePtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : NodeRef(n) {} /*! * \brief constructor from iterator * \param begin begin of iterator @@ -406,7 +391,7 @@ class Map : public NodeRef { * \return reference to self. */ Map& operator=(Map && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } /*! @@ -415,7 +400,7 @@ class Map : public NodeRef { * \return reference to self. 
*/ Map& operator=(const Map & other) { - node_ = other.node_; + data_ = other.data_; return *this; } /*! @@ -428,10 +413,9 @@ class Map : public NodeRef { void assign(IterType begin, IterType end) { NodePtr n = make_node(); for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first.node_, - i->second.node_)); + n->data.emplace(std::make_pair(i->first, i->second)); } - node_ = std::move(n); + data_ = std::move(n); } /*! * \brief Read element from map. @@ -439,7 +423,8 @@ class Map : public NodeRef { * \return the corresonding element. */ inline const V operator[](const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } /*! * \brief Read element from map. @@ -447,17 +432,18 @@ class Map : public NodeRef { * \return the corresonding element. */ inline const V at(const K& key) const { - return V(static_cast(node_.get())->data.at(key.node_)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } /*! \return The size of the array */ inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } /*! \return The number of elements of the key */ inline size_t count(const K& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key.node_); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.count(key); } /*! * \brief copy on write semantics @@ -468,12 +454,12 @@ class Map : public NodeRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline MapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } /*! * \brief set the Map. @@ -482,7 +468,7 @@ class Map : public NodeRef { */ inline void Set(const K& key, const V& value) { MapNode* n = this->CopyOnWrite(); - n->data[key.node_] = value.node_; + n->data[key] = value; } /*! \return whether array is empty */ @@ -492,29 +478,31 @@ class Map : public NodeRef { /*! \brief specify container node */ using ContainerType = MapNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = std::pair; static inline ResultType convert(const std::pair< - NodePtr, - NodePtr >& n) { - return std::make_pair(K(n.first), V(n.second)); + ObjectRef, + ObjectRef>& n) { + return std::make_pair(DowncastNoCheck(n.first), + DowncastNoCheck(n.second)); } }; using iterator = IterAdapter< - Ptr2NodeRef, MapNode::ContainerType::const_iterator>; + ValueConverter, MapNode::ContainerType::const_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! \return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! 
\return begin iterator */ inline iterator find(const K& key) const { - return iterator(static_cast(node_.get())->data.find(key.node_)); + return iterator( + static_cast(data_.get())->data.find(key)); } }; @@ -524,14 +512,14 @@ class Map : public NodeRef { public: // for code reuse Map() { - node_ = make_node(); + data_ = make_node(); } Map(Map && other) { // NOLINT(*) - node_ = std::move(other.node_); + data_ = std::move(other.data_); } - Map(const Map &other) : NodeRef(other.node_) { // NOLINT(*) + Map(const Map &other) : NodeRef(other.data_) { // NOLINT(*) } - explicit Map(NodePtr n) : NodeRef(n) {} + explicit Map(ObjectPtr n) : NodeRef(n) {} template Map(IterType begin, IterType end) { assign(begin, end); @@ -545,76 +533,77 @@ class Map : public NodeRef { assign(init.begin(), init.end()); } Map& operator=(Map && other) { - node_ = std::move(other.node_); + data_ = std::move(other.data_); return *this; } Map& operator=(const Map & other) { - node_ = other.node_; + data_ = other.data_; return *this; } template void assign(IterType begin, IterType end) { auto n = make_node(); for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first, - i->second.node_)); + n->data.emplace(std::make_pair(i->first, i->second)); } - node_ = std::move(n); + data_ = std::move(n); } inline const V operator[](const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline const V at(const std::string& key) const { - return V(static_cast(node_.get())->data.at(key)); + return DowncastNoCheck( + static_cast(data_.get())->data.at(key)); } inline size_t size() const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.size(); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.size(); } inline size_t count(const std::string& key) const { - if (node_.get() == nullptr) return 0; - return static_cast(node_.get())->data.count(key); + if (data_.get() == nullptr) return 0; + return static_cast(data_.get())->data.count(key); } inline StrMapNode* CopyOnWrite() { - if (node_.get() == nullptr || !node_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { NodePtr n = make_node(); - n->data = static_cast(node_.get())->data; - NodePtr(std::move(n)).swap(node_); + n->data = static_cast(data_.get())->data; + ObjectPtr(std::move(n)).swap(data_); } - return static_cast(node_.get()); + return static_cast(data_.get()); } inline void Set(const std::string& key, const V& value) { StrMapNode* n = this->CopyOnWrite(); - n->data[key] = value.node_; + n->data[key] = value; } inline bool empty() const { return size() == 0; } using ContainerType = StrMapNode; - struct Ptr2NodeRef { + struct ValueConverter { using ResultType = std::pair; static inline ResultType convert(const std::pair< - std::string, - NodePtr >& n) { - return std::make_pair(n.first, V(n.second)); + std::string, + ObjectRef>& n) { + return std::make_pair(n.first, DowncastNoCheck(n.second)); } }; using iterator = IterAdapter< - Ptr2NodeRef, StrMapNode::ContainerType::const_iterator>; + ValueConverter, StrMapNode::ContainerType::const_iterator>; /*! \return begin iterator */ inline iterator begin() const { - return iterator(static_cast(node_.get())->data.begin()); + return iterator(static_cast(data_.get())->data.begin()); } /*! 
\return end iterator */ inline iterator end() const { - return iterator(static_cast(node_.get())->data.end()); + return iterator(static_cast(data_.get())->data.end()); } /*! \return begin iterator */ inline iterator find(const std::string& key) const { - return iterator(static_cast(node_.get())->data.find(key)); + return iterator(static_cast(data_.get())->data.find(key)); } }; diff --git a/include/tvm/node/ir_functor.h b/include/tvm/node/ir_functor.h index 23c5a3fafdab..e902e8fb6d44 100644 --- a/include/tvm/node/ir_functor.h +++ b/include/tvm/node/ir_functor.h @@ -34,10 +34,10 @@ namespace tvm { /*! - * \brief A dynamically dispatched functor on NodeRef in the first argument. + * \brief A dynamically dispatched functor on ObjectRef in the first argument. * * \code - * IRFunctor tostr; + * IRFunctor tostr; * tostr.set_dispatch([](const Add* op, std::string prefix) { * return prefix + "Add"; * }); @@ -60,10 +60,10 @@ template class IRFunctor; template -class IRFunctor { +class IRFunctor { private: - using Function = std::function; - using TSelf = IRFunctor; + using Function = std::function; + using TSelf = IRFunctor; /*! \brief internal function table */ std::vector func_; @@ -75,8 +75,8 @@ class IRFunctor { * \param n The node to be dispatched * \return Whether dispatching function is registered for n's type. */ - inline bool can_dispatch(const NodeRef& n) const { - uint32_t type_index = n.type_index(); + inline bool can_dispatch(const ObjectRef& n) const { + uint32_t type_index = n->type_index(); return type_index < func_.size() && func_[type_index] != nullptr; } /*! @@ -85,12 +85,12 @@ class IRFunctor { * \param args The additional arguments * \return The result. */ - inline R operator()(const NodeRef& n, Args... args) const { - uint32_t type_index = n.type_index(); + inline R operator()(const ObjectRef& n, Args... args) const { + uint32_t type_index = n->type_index(); CHECK(type_index < func_.size() && func_[type_index] != nullptr) << "IRFunctor calls un-registered function on type " - << Node::TypeIndex2Key(type_index); + << n->GetTypeKey(); return func_[type_index](n, std::forward(args)...); } /*! @@ -101,19 +101,19 @@ class IRFunctor { */ template inline TSelf& set_dispatch(Function f) { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + uint32_t tindex = TNode::RuntimeTypeIndex(); if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } CHECK(func_[tindex] == nullptr) - << "Dispatch for " << Node::TypeIndex2Key(tindex) + << "Dispatch for " << TNode::_type_key << " is already set"; func_[tindex] = f; return *this; } /*! * \brief set the dispacher for type TNode - * This allows f to used detailed const Node pointer to replace NodeRef + * This allows f to used detailed const Node pointer to replace ObjectRef * * \param f The function to be set. * \tparam TNode the type of Node to be dispatched. @@ -121,8 +121,8 @@ class IRFunctor { */ template inline TSelf& set_dispatch(std::function f) { // NOLINT(*) - Function fun = [f](const NodeRef& n, Args... args) { - return f(static_cast(n.node_.get()), + Function fun = [f](const ObjectRef& n, Args... 
args) { + return f(static_cast(n.get()), std::forward(args)...); }; return this->set_dispatch(fun); @@ -135,7 +135,7 @@ class IRFunctor { */ template inline TSelf& clear_dispatch() { // NOLINT(*) - uint32_t tindex = Node::TypeKey2Index(TNode::_type_key); + uint32_t tindex = TNode::RuntimeTypeIndex(); CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; func_[tindex] = nullptr; return *this; @@ -172,7 +172,7 @@ class IRFunctor { * f(e, this); * } * - * using FType = IRFunctor; + * using FType = IRFunctor; * // function to return global function table * static FType& vtable(); * }; @@ -232,15 +232,15 @@ template class IRFunctorStaticRegistry; template -class IRFunctorStaticRegistry { +class IRFunctorStaticRegistry { private: - IRFunctor *irf_; + IRFunctor *irf_; std::shared_ptr free_list; - using TSelf = IRFunctorStaticRegistry; + using TSelf = IRFunctorStaticRegistry; public: - IRFunctorStaticRegistry(IRFunctor *irf) { + IRFunctorStaticRegistry(IRFunctor *irf) { irf_ = irf; free_list = std::make_shared(); } @@ -261,12 +261,12 @@ class IRFunctorStaticRegistry { * the compiler to deduce the template types. */ template -IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( - IRFunctor *irf) { - return IRFunctorStaticRegistry(irf); +IRFunctorStaticRegistry MakeIRFunctorStaticRegistry( + IRFunctor *irf) { + return IRFunctorStaticRegistry(irf); } -#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ +#define TVM_AUTO_REGISTER_VAR_DEF(ClsName) \ static TVM_ATTRIBUTE_UNUSED auto __make_functor ## _ ## ClsName /*! diff --git a/include/tvm/node/memory.h b/include/tvm/node/memory.h deleted file mode 100644 index 1bba57144e19..000000000000 --- a/include/tvm/node/memory.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/node/memory.h - * \brief Node memory management. - */ -#ifndef TVM_NODE_MEMORY_H_ -#define TVM_NODE_MEMORY_H_ - -#include -#include "node.h" - -namespace tvm { -/*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. - */ -template -inline NodePtr make_node(Args&&... args); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. -// -template -class SimpleNodeAllocator { - public: - template - static T* New(Args&&... 
args) { - return new T(std::forward(args)...); - } - static NodeBase::FDeleter Deleter() { - return Deleter_; - } - - private: - static void Deleter_(NodeBase* ptr) { - delete static_cast(ptr); - } -}; - -template -inline NodePtr make_node(Args&&... args) { - using Allocator = SimpleNodeAllocator; - static_assert(std::is_base_of::value, - "make_node can only be used to create NodeBase"); - T* node = Allocator::New(std::forward(args)...); - node->deleter_ = Allocator::Deleter(); - return NodePtr(node); -} - -} // namespace tvm -#endif // TVM_NODE_MEMORY_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index cb18e46e9a5c..8203ee69f686 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -25,7 +25,9 @@ #include #include -#include +#include +#include +#include #include #include #include @@ -38,13 +40,6 @@ class DataType; class Node; class NodeRef; -namespace runtime { -// forward declaration -class NDArray; -// forward declaration -class ObjectRef; -} // namespace runtime - /*! * \brief Visitor class to each node content. * The content is going to be called for each field. @@ -74,15 +69,17 @@ class TVM_DLL AttrVisitor { //! \endcond }; +/*! \brief Reuse the type index in he runtime. */ +using TypeIndex = runtime::TypeIndex; + /*! * \brief base class of node container in DSL AST. */ -class TVM_DLL Node : public NodeBase { +class Node : public runtime::Object { public: /*! \brief virtual destructor */ virtual ~Node() {} - /*! \return The unique type key of the node */ - virtual const char* type_key() const = 0; + /*! * \brief Apply visitor to each field of the Node * Visitor could mutate the content of the node. @@ -90,272 +87,79 @@ class TVM_DLL Node : public NodeBase { * \param visitor The visitor */ virtual void VisitAttrs(AttrVisitor* visitor) {} - /*! \return the type index of the node */ - virtual uint32_t type_index() const = 0; - /*! - * \brief Whether this node derives from node with type_index=tid. - * Implemented by TVM_DECLARE_NODE_TYPE_INFO - * - * \param tid The type index. - * \return the check result. - */ - virtual bool _DerivedFrom(uint32_t tid) const; - /*! - * \brief get a runtime unique type index given a type key - * \param type_key Type key of a type. - * \return the corresponding type index. - */ - static uint32_t TypeKey2Index(const char* type_key); - /*! - * \brief get type key from type index. - * \param index The type index - * \return the corresponding type key. - */ - static const char* TypeIndex2Key(uint32_t index); - /*! - * \return whether the type is derived from - */ - template - inline bool derived_from() const; - /*! - * \return whether the node is of type T - * \tparam The type to be checked. - */ - template - inline bool is_type() const; - /*! - * \brief Get a NodePtr that holds reference to this Node. - * \return the NodePtr - */ - inline NodePtr GetNodePtr() const; - // node ref can see this - friend class NodeRef; + static constexpr const char* _type_key = "Node"; + static constexpr uint32_t _type_index = TypeIndex::kDynamic; + + TVM_DECLARE_BASE_OBJECT_INFO(Node, runtime::Object); }; -/*! \brief Base class of all node reference object */ -class NodeRef { + +/*! + * \brief Base class of all node reference object + * NodeRef is just a alias of ObjectRef. + */ +class NodeRef : public runtime::ObjectRef { public: /*! \brief type indicate the container type */ using ContainerType = Node; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. 
- */ - inline bool operator==(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool same_as(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator<(const NodeRef& other) const; - /*! - * \brief Comparator - * \param other Another node ref. - * \return the compare result. - */ - inline bool operator!=(const NodeRef& other) const; - /*! \return the hash function for NodeRef */ - inline size_t hash() const; - /*! \return whether the expression is null */ - inline bool defined() const; - /*! \return the internal type index of IRNode */ - inline uint32_t type_index() const; + /*! \return the internal node pointer */ - inline const Node* get() const; + const Node* get() const { + return static_cast(ObjectRef::get()); + } /*! \return the internal node pointer */ - inline const Node* operator->() const; - /*! - * \brief Downcast this ir node to its actual type (e.g. Add, or - * Select). This returns nullptr if the node is not of the requested - * type. Example usage: - * - * if (const Add *add = node->as()) { - * // This is an add node - * } - * \tparam T the target type, must be subtype of IRNode - */ - template - inline const T *as() const; + const Node* operator->() const { + return get(); + } /*! * \brief A more powerful version of as that also works with * intermediate base types. * \tparam T the target type, must be subtype of IRNode */ template - inline const T *as_derived() const; + const T *as_derived() const { + return as(); + } /*! \brief default constructor */ NodeRef() = default; - explicit NodeRef(NodePtr node) : node_(node) {} - /*! \brief the internal node object, do not touch */ - NodePtr node_; + explicit NodeRef(runtime::ObjectPtr ptr) : ObjectRef(ptr) {} }; -/*! - * \brief Get a reference type from a Node ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the node alive beyond the scope of the function. - * - * \param ptr The node pointer - * \tparam RefType The reference type - * \tparam NodeType The node type - * \return The corresponding RefType - */ -template -inline RefType GetRef(const NodeType* ptr); - -/*! - * \brief Downcast a base reference type to a more specific type. - * - * \param ref The inptut reference - * \return The corresponding SubRef. - * \tparam SubRef The target specific reference type. - * \tparam BaseRef the current reference type. - */ -template -inline SubRef Downcast(BaseRef ref); - /*! * \brief helper macro to declare type information in a base node. */ -#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ - bool _DerivedFrom(uint32_t tid) const override { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } +#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \ + TVM_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) /*! 
* \brief helper macro to declare type information in a terminal node */ -#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ - const char* type_key() const final { \ - return TypeName::_type_key; \ - } \ - uint32_t type_index() const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - return tidx; \ - } \ - bool _DerivedFrom(uint32_t tid) const final { \ - static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ - if (tidx == tid) return true; \ - return Parent::_DerivedFrom(tid); \ - } - -// implementations of inline functions after this -template -inline bool Node::derived_from() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return this->_DerivedFrom(type_id); -} - - -template -inline bool Node::is_type() const { - // use static field so query only happens once. - static uint32_t type_id = Node::TypeKey2Index(T::_type_key); - return type_id == this->type_index(); -} +#define TVM_DECLARE_NODE_TYPE_INFO(TypeName, Parent) \ + TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); -inline NodePtr Node::GetNodePtr() const { - return NodePtr(const_cast(this)); -} +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; +using runtime::GetRef; +using runtime::Downcast; +using runtime::make_object; +using runtime::ObjectHash; +using runtime::ObjectEqual; -template -inline RefType GetRef(const NodeType* ptr) { - static_assert(std::is_base_of::value, - "Can only cast to the ref of same container type"); - return RefType(ptr->GetNodePtr()); -} - -template -inline SubRef Downcast(BaseRef ref) { - CHECK(ref->template is_type() || - ref->template derived_from()) - << "Downcast from " << ref->type_key() << " to " - << SubRef::ContainerType::_type_key << " failed."; - return SubRef(std::move(ref.node_)); -} - -inline const Node* NodeRef::get() const { - return node_.get(); -} - -inline const Node* NodeRef::operator->() const { - return node_.get(); -} - -inline bool NodeRef::defined() const { - return node_.get() != nullptr; -} - -inline bool NodeRef::operator==(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} +using NodeHash = ObjectHash; +using NodeEqual = ObjectEqual; -inline bool NodeRef::same_as(const NodeRef& other) const { - return node_.get() == other.node_.get(); -} - -inline bool NodeRef::operator<(const NodeRef& other) const { - return node_.get() < other.node_.get(); -} - -inline bool NodeRef::operator!=(const NodeRef& other) const { - return node_.get() != other.node_.get(); -} - -inline size_t NodeRef::hash() const { - return std::hash()(node_.get()); -} - -inline uint32_t NodeRef::type_index() const { - CHECK(node_.get() != nullptr) - << "null type"; - return get()->type_index(); -} - -template -inline const T* NodeRef::as() const { - const Node* ptr = static_cast(get()); - if (ptr && ptr->is_type()) { - return static_cast(ptr); - } - return nullptr; -} - -template -inline const T* NodeRef::as_derived() const { - const Node* ptr = static_cast(get()); - if (ptr && (ptr->is_type() || ptr->derived_from())) { - return static_cast(ptr); - } - return nullptr; +/*! + * \brief Allocate a node object. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + */ +template +inline NodePtr make_node(Args&&... args) { + return runtime::make_object(std::forward(args)...); } - -/*! 
\brief The hash function for nodes */ -struct NodeHash { - size_t operator()(const NodeRef& a) const { - return a.hash(); - } -}; - -/*! \brief The equal comparator for nodes */ -struct NodeEqual { - bool operator()(const NodeRef& a, const NodeRef& b) const { - return a.get() == b.get(); - } -}; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index b950aa952f04..b942464d4907 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -651,7 +651,7 @@ inline Tensor compute(Array shape, // inline function. inline const OperationNode* Operation::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_OPERATION_H_ diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 5951594b873c..48d46fdf2fc6 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -37,6 +37,7 @@ #include "runtime/packed_func.h" namespace tvm { + using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; @@ -47,86 +48,82 @@ namespace runtime { * \tparam T the type to be checked. */ template -struct NodeTypeChecker { - static inline bool Check(Node* sptr) { - // This is the only place in the project where RTTI is used - // It can be turned off, but will make non strict checking. - // TODO(tqchen) possibly find alternative to turn of RTTI +struct ObjectTypeChecker { + static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; - // always allow nullptr. - if (sptr == nullptr) return true; - return sptr->derived_from(); + if (ptr == nullptr) return true; + return ptr->IsInstance(); } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + static void PrintName(std::ostream& os) { // NOLINT(*) using ContainerType = typename T::ContainerType; os << ContainerType::_type_key; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - ArrayNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const ArrayNode* n = static_cast(ptr); for (const auto& p : n->data) { - if (!NodeTypeChecker::Check(p.get())) { + if (!ObjectTypeChecker::Check(p.get())) { return false; } } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "array<"; - NodeTypeChecker::PrintName(os); - os << ">"; + static void PrintName(std::ostream& os) { // NOLINT(*) + os << "List["; + ObjectTypeChecker::PrintName(os); + os << "]"; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if (!sptr->is_type()) return false; - StrMapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const StrMapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -struct NodeTypeChecker > { - static inline bool Check(Node* sptr) { - if (sptr == nullptr) return true; - if 
(!sptr->is_type()) return false; - MapNode* n = static_cast(sptr); +struct ObjectTypeChecker > { + static bool Check(const Object* ptr) { + if (ptr == nullptr) return true; + if (!ptr->IsInstance()) return false; + const MapNode* n = static_cast(ptr); for (const auto& kv : n->data) { - if (!NodeTypeChecker::Check(kv.first.get())) return false; - if (!NodeTypeChecker::Check(kv.second.get())) return false; + if (!ObjectTypeChecker::Check(kv.first.get())) return false; + if (!ObjectTypeChecker::Check(kv.second.get())) return false; } return true; } - static inline void PrintName(std::ostringstream& os) { // NOLINT(*) - os << "map<"; - NodeTypeChecker::PrintName(os); + static void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "Map["; + ObjectTypeChecker::PrintName(os); os << ','; - NodeTypeChecker::PrintName(os); - os << '>'; + ObjectTypeChecker::PrintName(os); + os << ']'; } }; template -inline std::string NodeTypeName() { +inline std::string ObjectTypeName() { std::ostringstream os; - NodeTypeChecker::PrintName(os); + ObjectTypeChecker::PrintName(os); return os.str(); } @@ -138,12 +135,12 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { std::is_base_of::value, "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(NodePtr(nullptr)); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return TNodeRef(sptr); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TNodeRef(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Expr() const { @@ -156,18 +153,20 @@ inline TVMArgValue::operator tvm::Expr() const { if (type_code_ == kDLFloat) { return Expr(static_cast(value_.v_float64)); } - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - if (sptr->is_type()) { - return IterVar(sptr)->var; + + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + + if (ptr->IsInstance()) { + return IterVar(ObjectPtr(ptr))->var; } - if (sptr->is_type()) { - return Tensor(sptr)(); + if (ptr->IsInstance()) { + return Tensor(ObjectPtr(ptr))(); } - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Expr(sptr); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Expr(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Integer() const { @@ -177,68 +176,36 @@ inline TVMArgValue::operator tvm::Integer() const { CHECK_GE(value_.v_int64, std::numeric_limits::min()); return Integer(static_cast(value_.v_int64)); } - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return Integer(sptr); -} - -inline NodePtr& TVMArgValue::node_sptr() { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - return *ptr >(); + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return Integer(ObjectPtr(ptr)); } - template -inline bool TVMArgValue::IsNodeType() const { - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& 
sptr = - *ptr >(); - return NodeTypeChecker::Check(sptr.get()); +inline bool TVMPODValue_::IsObjectRef() const { + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); + Object* ptr = static_cast(value_.v_handle); + return ObjectTypeChecker::Check(ptr); } // extensions for TVMRetValue -inline TVMRetValue& TVMRetValue::operator=( - const NodePtr& other) { - if (other.get() == nullptr) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other); - } - return *this; -} - -inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { - if (!other.defined()) { - SwitchToPOD(kNull); - } else { - SwitchToClass >(kNodeHandle, other.node_); - } - return *this; -} - template inline TNodeRef TVMRetValue::AsNodeRef() const { static_assert( std::is_base_of::value, "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(); - TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - NodePtr& sptr = *ptr >(); - CHECK(NodeTypeChecker::Check(sptr.get())) - << "Expected type " << NodeTypeName() - << " but get " << sptr->type_key(); - return TNodeRef(sptr); -} + TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); -inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) - if (other.defined()) { - values_[i].v_handle = const_cast*>(&(other.node_)); - type_codes_[i] = kNodeHandle; - } else { - type_codes_[i] = kNull; - } + Object* ptr = static_cast(value_.v_handle); + + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() + << " but get " << ptr->GetTypeKey(); + return TNodeRef(ObjectPtr(ptr)); } // type related stuffs diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 4329c438e8a0..e54d88d5a393 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -52,7 +52,7 @@ class PatternNode : public RelayNode { class Pattern : public NodeRef { public: Pattern() {} - explicit Pattern(NodePtr p) : NodeRef(p) {} + explicit Pattern(ObjectPtr p) : NodeRef(p) {} using ContainerType = PatternNode; }; diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f94ba5e26068..15330b00e961 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -83,10 +83,12 @@ using NodeEqual = ::tvm::NodeEqual; #define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \ class TypeName : public NodeRefBase { \ public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRefBase(n) {} \ + TypeName() {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : NodeRefBase(n) { \ + } \ const NodeName* operator->() const { \ - return static_cast(node_.get()); \ + return static_cast(get()); \ } \ operator bool() { return this->defined(); } \ using ContainerType = NodeName; \ @@ -127,7 +129,7 @@ class SourceName : public NodeRef { * \return the pointer to the internal node container */ inline const SourceNameNode* operator->() const { - return static_cast(this->node_.get()); + return static_cast(get()); } /*! 
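The angle-bracketed template arguments in the packed_func_ext.h hunks above were lost in transit, so the checker specialization reads oddly. A minimal reconstruction sketch of what the Map case most likely expands to, assuming the container node type is MapNode:

    template<typename K, typename V>
    struct ObjectTypeChecker<Map<K, V> > {
      static bool Check(const Object* ptr) {
        if (ptr == nullptr) return true;
        if (!ptr->IsInstance<MapNode>()) return false;
        const MapNode* n = static_cast<const MapNode*>(ptr);
        for (const auto& kv : n->data) {
          if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
          if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
        }
        return true;
      }
      static void PrintName(std::ostringstream& os) {  // NOLINT(*)
        os << "Map[";
        ObjectTypeChecker<K>::PrintName(os);
        os << ',';
        ObjectTypeChecker<V>::PrintName(os);
        os << ']';
      }
    };

The new IsObjectRef<T>() helper pairs with these checked conversions; a hedged usage sketch from inside a packed function body (the Tensor type and the argument index are illustrative only):

    void ExampleBody(tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
      CHECK(args[0].IsObjectRef<tvm::Tensor>())
          << "expected a Tensor, got "
          << tvm::runtime::TypeCode2Str(args[0].type_code());
      tvm::Tensor t = args[0];  // conversion is validated by ObjectTypeChecker<Tensor>
      *rv = t;
    }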
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b1b8d6a7154e..281b99297e78 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -541,10 +541,11 @@ RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); // implementataions inline const Type& ExprNode::checked_type() const { - CHECK(checked_type_.defined()) << "internal error: the type checker has " - "not populated the checked_type " - "field for " - << GetRef(this); + CHECK(checked_type_.defined()) + << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " + << GetRef(this); return this->checked_type_; } @@ -557,7 +558,7 @@ inline const TTypeNode* ExprNode::type_as() const { const TTypeNode* node = checked_type_.as(); CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->type_key(); + << ", but get " << checked_type_->GetTypeKey(); return node; } diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e0d940c5d1a5..8bc87a27f66f 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -57,8 +57,8 @@ class ExprFunctor; #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), \ std::forward(args)...); \ }); @@ -66,7 +66,7 @@ template class ExprFunctor { private: using TSelf = ExprFunctor; - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -117,7 +117,7 @@ class ExprFunctor { virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index d05099f781ac..a0422fa7f446 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -78,9 +78,9 @@ class ValueNode : public RelayNode { class Value : public NodeRef { public: Value() {} - explicit Value(NodePtr n) : NodeRef(n) {} + explicit Value(ObjectPtr n) : NodeRef(n) {} const ValueNode* operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } using ContainerType = ValueNode; diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 8b17020a1132..10d72349d0f5 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -281,10 +281,10 @@ class ModuleNode : public RelayNode { struct Module : public NodeRef { Module() {} - explicit Module(NodePtr p) : NodeRef(p) {} + explicit Module(ObjectPtr<::tvm::Object> p) : NodeRef(p) {} - inline ModuleNode* operator->() const { - return static_cast(node_.get()); + ModuleNode* operator->() const { + return static_cast(get_mutable()); } using ContainerType = ModuleNode; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 0a6d3725655f..572c194bc269 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -138,7 +138,7 @@ class Op : public relay::Expr { /*! \brief default constructor */ Op() {} /*! 
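Across these relay headers the accessor convention is now uniform: const operator-> goes through ObjectRef::get(), while mutating accessors go through the protected get_mutable(). A minimal sketch with a hypothetical node/reference pair (the names are illustrative; the macros are the pre-existing node ones):

    class MyNode : public tvm::Node {
     public:
      int counter{0};
      void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("counter", &counter); }
      static constexpr const char* _type_key = "example.MyNode";
      TVM_DECLARE_NODE_TYPE_INFO(MyNode, tvm::Node);
    };

    class MyRef : public tvm::NodeRef {
     public:
      MyRef() {}
      explicit MyRef(::tvm::ObjectPtr<::tvm::Object> n) : tvm::NodeRef(n) {}
      const MyNode* operator->() const {
        return static_cast<const MyNode*>(get());      // shared, read-only view
      }
      MyNode* operator->() {
        return static_cast<MyNode*>(get_mutable());    // mutation uses the protected helper
      }
      using ContainerType = MyNode;
    };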
\brief constructor from node pointer */ - explicit Op(NodePtr n) : Expr(n) {} + explicit Op(ObjectPtr n) : Expr(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -221,11 +221,12 @@ class OpRegistry { const Attrs&, const TypeReporter&)> type_rel_func); /*! - * \brief Set the type key of attributes. - * \param type_key The type of of the attrs field. + * \brief Set the the attrs type key and index to be AttrsType. + * \tparam AttrsType the attribute type to b set. * \return reference to self. */ - inline OpRegistry& set_attrs_type_key(const std::string& type_key); + template + inline OpRegistry& set_attrs_type(); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -397,7 +398,7 @@ class OpMap { // implementations inline const OpNode* Op::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } template @@ -496,10 +497,10 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) return *this; } -inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) - const std::string& type_key) { - get()->attrs_type_key = type_key; - get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str()); +template +inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) + get()->attrs_type_key = AttrsType::_type_key; + get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); return *this; } diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 7f1c47e03592..c15523cb25de 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -57,8 +57,8 @@ class PatternFunctor; #define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitPattern_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.get()), \ std::forward(args)...); \ }); @@ -66,7 +66,7 @@ template class PatternFunctor { private: using TSelf = PatternFunctor; - using FType = tvm::IRFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -103,7 +103,7 @@ class PatternFunctor { virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Node* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->type_key(); + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a2119c90f750..08ea3075cb83 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -134,16 +134,16 @@ class PassContext : public NodeRef { * \return const access pointer. */ const PassContextNode* operator->() const { - CHECK(node_.get() != nullptr); - return static_cast(node_.get()); + CHECK(get() != nullptr); + return static_cast(get()); } /*! * \brief mutable accessor. * \return mutable access pointer. */ PassContextNode* operator->() { - CHECK(node_.get() != nullptr); - return static_cast(node_.get()); + CHECK(get() != nullptr); + return static_cast(get_mutable()); } /*! * \brief Construct a PassContext containing the default configurations. 
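The switch from set_attrs_type_key(std::string) to the templated set_attrs_type<AttrsType>() records the key and the runtime type index in one step. A hedged usage sketch (the op name and attrs struct are illustrative only):

    // before:  .set_attrs_type_key("relay.attrs.MyOpAttrs")
    RELAY_REGISTER_OP("my.op")
        .describe("Illustrative op for this sketch only.")
        .set_num_inputs(1)
        .set_attrs_type<MyOpAttrs>();  // stores MyOpAttrs::_type_key and its RuntimeTypeIndex()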
diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 16e36785c533..a5cc3c83383e 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -58,7 +58,7 @@ class TypeNode : public RelayNode { class Type : public NodeRef { public: Type() {} - explicit Type(NodePtr p) : NodeRef(p) {} + explicit Type(ObjectPtr p) : NodeRef(p) {} using ContainerType = TypeNode; }; @@ -430,10 +430,11 @@ class TypeReporterNode : public Node { class TypeReporter : public NodeRef { public: TypeReporter() {} - explicit TypeReporter(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { + explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : NodeRef(n) { } TypeReporterNode* operator->() const { - return static_cast(node_.get()); + return const_cast( + static_cast(get())); } using ContainerType = TypeReporterNode; }; diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index b058fd63a2f5..267504beb11a 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -98,13 +98,12 @@ typedef enum { kTVMType = 5U, kTVMContext = 6U, kArrayHandle = 7U, - kNodeHandle = 8U, + kObjectHandle = 8U, kModuleHandle = 9U, kFuncHandle = 10U, kStr = 11U, kBytes = 12U, kNDArrayContainer = 13U, - kObjectHandle = 14U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 6b4f01e4ac9b..01c08d324fcb 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -69,7 +69,7 @@ class ObjAllocatorBase { "make_node can only be used to create NodeBase"); T* ptr = Handler::New(static_cast(this), std::forward(args)...); - ptr->type_index_ = T::type_index(); + ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); } diff --git a/include/tvm/runtime/node_base.h b/include/tvm/runtime/node_base.h deleted file mode 100644 index 8b47c18a09a7..000000000000 --- a/include/tvm/runtime/node_base.h +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/node_base.h - * \brief Base data structure for Node. - * - * \note Node is not a runtime feature. - * This file only exposes the signature of NodePtr for PackedFunc. - */ -#ifndef TVM_RUNTIME_NODE_BASE_H_ -#define TVM_RUNTIME_NODE_BASE_H_ - -#include -#include - -namespace tvm { - -// forward declarations -template -class NodePtr; -class Node; -class NodeRef; - -/*! - * \brief Base class of Node for runtime destructor purposes. - * - * Node is a reference counted object which is used to construct AST. 
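With kNodeHandle (8) removed and kObjectHandle renumbered into its slot, C-API callers only have to deal with a single handle code. A minimal sketch of a packed C callback after the merge, assuming the usual TVMPackedCFunc signature from c_runtime_api.h:

    int MyPackedCFunc(TVMValue* args, int* type_codes, int num_args,
                      TVMRetValueHandle ret, void* resource_handle) {
      for (int i = 0; i < num_args; ++i) {
        if (type_codes[i] == kObjectHandle) {
          // arguments that used to arrive as kNodeHandle (8) or the old
          // kObjectHandle (14) now all carry this one code
        }
      }
      return 0;
    }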
- * Each node is backed by a custom deleter, which deletes the object. - * Do not call create raw Node pointer, always use tvm::make_node. - * - * \note In most cases, please inheritate tvm::Node. - * \sa Node, NodePtr, make_node - */ -class NodeBase { - public: - /*! - * \brief type of NodeBase deleter - * \param self pointer to the NodeBase. - */ - typedef void (*FDeleter)(NodeBase* self); - - protected: - // default constructor and copy constructor - NodeBase() {} - // override the copy and assign constructors to do nothing. - // This is to make sure only contents, but not deleter and ref_counter - // are copied when a child class copies itself. - NodeBase(const NodeBase& other) { // NOLINT(*) - } - NodeBase(NodeBase&& other) { // NOLINT(*) - } - NodeBase& operator=(const NodeBase& other) { //NOLINT(*) - return *this; - } - NodeBase& operator=(NodeBase&& other) { //NOLINT(*) - return *this; - } - - private: - /*! \brief Internal reference counter */ - std::atomic ref_counter_{0}; - /*! - * \brief deleter of this object to enable customized allocation. - * If the deleter is nullptr, no deletion will be performed. - * The creator of the Node must always set the deleter field properly. - */ - FDeleter deleter_ = nullptr; - // reference counting functions - void IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); - } - void DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter_ != nullptr) { - (*this->deleter_)(this); - } - } - } - int use_count() const { - return ref_counter_.load(std::memory_order_relaxed); - } - // friend declaration - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; - -/*! - * \brief Smart pointer for Node containers, - * must be subclass of NodeBase - * \tparam T the content data type. - */ -template -class NodePtr { - public: - /*! \brief default constructor */ - NodePtr() {} - /*! \brief default constructor */ - NodePtr(std::nullptr_t) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other The value to be moved - */ - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - } - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - NodePtr(const NodePtr& other) // NOLINT(*) - : NodePtr(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - NodePtr(NodePtr&& other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of::value, - "can only assign of child class NodePtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~NodePtr() { - this->reset(); - } - /*! - * \brief Swap this array with another NDArray - * \param other The other NDArray - */ - void swap(NodePtr& other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T* get() const { - return static_cast(data_); - } - /*! - * \return The pointer - */ - T* operator->() const { - return get(); - } - /*! - * \return The reference - */ - T& operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignmemt - * \param other The value to be assigned. - * \return reference to self. 
- */ - NodePtr& operator=(const NodePtr& other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - NodePtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignmemt - * \param other The value to be assigned. - * \return reference to self. - */ - NodePtr& operator=(NodePtr&& other) { // NOLINT(*) - // copy-and-swap idiom - NodePtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_ != nullptr ? data_->use_count() : 0; - } - /*! \return whether the reference is unique */ - bool unique() const { - return data_ != nullptr && data_->use_count() == 1; - } - /*! \return Whether two NodePtr do not equals each other */ - bool operator==(const NodePtr& other) const { - return data_ == other.data_; - } - /*! \return Whether two NodePtr equals each other */ - bool operator!=(const NodePtr& other) const { - return data_ != other.data_; - } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { - return data_ == nullptr; - } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return data_ != nullptr; - } - - private: - /*! \brief internal pointer field */ - NodeBase* data_{nullptr}; - /*! - * \brief constructor from NodeBase - * \param data The node base pointer - */ - explicit NodePtr(NodeBase* data) - : data_(data) { - if (data != nullptr) { - data_->IncRef(); - } - } - // friend declaration - friend class Node; - template - friend class NodePtr; - template - friend NodePtr make_node(Args&&...); -}; -} // namespace tvm - -#endif // TVM_RUNTIME_NODE_BASE_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 7291510c16df..143f3bb35220 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -65,7 +65,7 @@ enum TypeIndex { * - _type_index: * Static type index of the object, if assigned to TypeIndex::kDynamic * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::type_index(); + * Runtime type index can be accessed by ObjectType::TypeIndex(); * - _type_key: * The unique string identifier of tyep type. * - _type_final: @@ -147,10 +147,23 @@ class Object { * \param self pointer to the Object. */ typedef void (*FDeleter)(Object* self); - /*! \return The internal type index of the object. */ + /*! \return The internal runtime type index of the object. */ uint32_t type_index() const { return type_index_; } + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. + */ + std::string GetTypeKey() const { + return TypeIndex2Key(type_index_); + } + /*! + * \return A hash value of the return of GetTypeKey. + */ + size_t GetTypeKeyHash() const { + return TypeIndex2KeyHash(type_index_); + } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. @@ -159,6 +172,25 @@ class Object { template inline bool IsInstance() const; + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + * \return the result. + */ + TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); + /*! + * \brief Get the type key hash of the corresponding index from runtime. 
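ObjectPtr takes over the role of the deleted NodePtr one for one (the NodePtr alias later in this patch keeps old spellings compiling). A reference-counting sketch, assuming a node type MyNode and the make_object<T>() helper from include/tvm/runtime/memory.h:

    void RefCountSketch() {
      tvm::runtime::ObjectPtr<MyNode> a = tvm::runtime::make_object<MyNode>();  // use_count() == 1
      tvm::runtime::ObjectPtr<MyNode> b = a;                                    // IncRef -> 2
      a.reset();                                                                // DecRef -> 1
      CHECK(b.unique());  // the deleter stored at construction runs once b releases the object
    }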
+ * \param tindex The type index. + * \return the related key-hash. + */ + TVM_DLL static size_t TypeIndex2KeyHash(uint32_t tindex); + /*! + * \brief Get the type index of the corresponding key from runtime. + * \param key The type key. + * \return the result. + */ + TVM_DLL static uint32_t TypeKey2Index(const char* key); + #if TVM_OBJECT_ATOMIC_REF_COUNTER using RefCounterType = std::atomic; #else @@ -170,9 +202,30 @@ class Object { static constexpr bool _type_final = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; - static const uint32_t _GetOrAllocRuntimeTypeIndex() { + static uint32_t _GetOrAllocRuntimeTypeIndex() { return 0; } + static uint32_t RuntimeTypeIndex() { + return 0; + } + + // Default constructor and copy constructor + Object() {} + // Override the copy and assign constructors to do nothing. + // This is to make sure only contents, but not deleter and ref_counter + // are copied when a child class copies itself. + // This will enable us to use make_object(*obj_ptr) + // to copy an existing object. + Object(const Object& other) { // NOLINT(*) + } + Object(Object&& other) { // NOLINT(*) + } + Object& operator=(const Object& other) { //NOLINT(*) + return *this; + } + Object& operator=(Object&& other) { //NOLINT(*) + return *this; + } protected: // The fields of the base object cell. @@ -215,18 +268,6 @@ class Object { uint32_t type_child_slots, bool type_child_slots_can_overflow); - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - */ - TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); - - /*! - * \brief Get the type index of the corresponding key from runtime. - * \param key The type key. - */ - TVM_DLL static uint32_t TypeKey2Index(const char* key); - private: // reference counter related operations /*! \brief developer function, increases reference counter. */ @@ -256,6 +297,32 @@ class Object { friend class TVMObjectCAPI; }; +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the node alive beyond the scope of the function. + * + * \param ptr The node pointer + * \tparam RefType The reference type + * \tparam ObjectType The node type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const ObjectType* ptr); + +/*! + * \brief Downcast a base reference type to a more specific type. + * + * \param ref The inptut reference + * \return The corresponding SubRef. + * \tparam SubRef The target specific reference type. + * \tparam BaseRef the current reference type. + */ +template +inline SubRef Downcast(BaseRef ref); + /*! * \brief A custom smart pointer for Object. * \tparam T the content data type. @@ -389,7 +456,7 @@ class ObjectPtr { /*! \brief internal pointer field */ Object* data_{nullptr}; /*! - * \brief constructor from NodeBase + * \brief constructor from Object * \param data The data pointer */ explicit ObjectPtr(Object* data) : data_(data) { @@ -400,6 +467,7 @@ class ObjectPtr { // friend classes friend class Object; friend class ObjectRef; + friend struct ObjectHash; template friend class ObjectPtr; template @@ -407,6 +475,9 @@ class ObjectPtr { friend class TVMPODValue_; friend class TVMArgsSetter; friend class TVMRetValue; + friend class TVMArgValue; + template + friend RefType GetRef(const ObjType* ptr); }; /*! 
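The new GetTypeKey/GetTypeKeyHash helpers make type reporting easy to write; the string lookup is the expensive path and is meant for error messages, while type_index() stays a plain integer read. A short sketch:

    void DebugPrint(const tvm::runtime::Object* obj) {
      LOG(INFO) << "type key = "  << obj->GetTypeKey()      // string lookup, fine on error paths
                << ", key hash = " << obj->GetTypeKeyHash()  // hash of the same key
                << ", index = "    << obj->type_index();     // cheap integer read
    }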
\brief Base class of all object reference */ @@ -416,10 +487,54 @@ class ObjectRef { ObjectRef() = default; /*! \brief Constructor from existing object ptr */ explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator==(const ObjectRef& other) const { + return data_ == other.data_; + } + /*! + * \brief Comparator + * \param other Another node ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef& other) const { + return data_ != other.data_; + } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef& other) const { + return data_.get() < other.data_.get(); + } + /*! \return whether the expression is null */ + bool defined() const { + return data_ != nullptr; + } /*! \return the internal object pointer */ - inline const Object* get() const; + const Object* get() const { + return data_.get(); + } /*! \return the internal node pointer */ - inline const Object* operator->() const; + const Object* operator->() const { + return get(); + } + /*! \return whether the reference is unique */ + bool unique() const { + return data_.unique(); + } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. @@ -434,25 +549,81 @@ class ObjectRef { template inline const ObjectType* as() const; - /*! \brief type indicate the container type */ + /*! \brief type indicate the container type. */ using ContainerType = Object; protected: /*! \brief Internal pointer that backs the reference. */ ObjectPtr data_; + /*! \return return a mutable internal ptr, can be used by sub-classes. */ + Object* get_mutable() const { + return data_.get(); + } + /*! + * \brief Internal helper function downcast a ref without check. + * \note Only used for internal dev purposes. + * \tparam T The target reference type. + * \return The casted result. + */ + template + static T DowncastNoCheck(ObjectRef ref) { + return T(std::move(ref.data_)); + } + /*! + * \brief Internal helper function get data_ as ObjectPtr of ObjectType. + * \note only used for internal dev purpose. + * \tparam ObjectType The corresponding object type. + * \return the corresponding type. + */ + template + static ObjectPtr GetDataPtr(const ObjectRef& ref) { + return ObjectPtr(ref.data_.data_); + } // friend classes. + friend struct ObjectHash; friend class TVMRetValue; friend class TVMArgsSetter; + template + friend SubRef Downcast(BaseRef ref); }; + +/*! \brief ObjectRef hash functor */ +struct ObjectHash { + size_t operator()(const ObjectRef& a) const { + return operator()(a.data_); + } + + template + size_t operator()(const ObjectPtr& a) const { + return std::hash()(a.get()); + } +}; + + +/*! \brief ObjectRef equal functor */ +struct ObjectEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { + return a.same_as(b); + } + + template + size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { + return a == b; + } +}; + + /*! * \brief helper macro to declare a base object type that can be inheritated. * \param TypeName The name of the current type. 
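ObjectHash and ObjectEqual give reference-identity semantics (same_as) when handles are used as keys in standard containers; a usage sketch with an arbitrary ObjectRef handle:

    void TrackName(const tvm::runtime::ObjectRef& expr) {
      std::unordered_map<tvm::runtime::ObjectRef, std::string,
                         tvm::runtime::ObjectHash, tvm::runtime::ObjectEqual> names;
      names[expr] = "conv0";           // hashes the underlying pointer
      CHECK_EQ(names.count(expr), 1);  // equality is same_as, not structural equality
    }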
* \param ParentType The name of the ParentType */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static const uint32_t type_index() { \ - if (_type_index != TypeIndex::kDynamic) return _type_index; \ + static const uint32_t RuntimeTypeIndex() { \ + if (_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ + return _type_index; \ + } \ return _GetOrAllocRuntimeTypeIndex(); \ } \ static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ @@ -551,11 +722,11 @@ inline bool Object::IsInstance() const { if (TargetType::_type_final) { // if the target type is a final type // then we only need to check the equivalence. - return self->type_index_ == TargetType::type_index(); + return self->type_index_ == TargetType::RuntimeTypeIndex(); } else { // if target type is a non-leaf type // Check if type index falls into the range of reserved slots. - uint32_t begin = TargetType::type_index(); + uint32_t begin = TargetType::RuntimeTypeIndex(); // The condition will be optimized by constant-folding. if (TargetType::_type_child_slots != 0) { uint32_t end = begin + TargetType::_type_child_slots; @@ -565,22 +736,15 @@ inline bool Object::IsInstance() const { } if (!TargetType::_type_child_slots_can_overflow) return false; // Invariance: parent index is always smaller than the child. - if (self->type_index_ < TargetType::type_index()) return false; + if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; // The rare slower-path, check type hierachy. - return self->DerivedFrom(TargetType::type_index()); + return self->DerivedFrom(TargetType::RuntimeTypeIndex()); } } else { return false; } } -inline const Object* ObjectRef::get() const { - return data_.data_; -} - -inline const Object* ObjectRef::operator->() const { - return get(); -} template inline const ObjectType* ObjectRef::as() const { @@ -591,7 +755,27 @@ inline const ObjectType* ObjectRef::as() const { return nullptr; } } + +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, + "Can only cast to the ref of same container type"); + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); +} + +template +inline SubRef Downcast(BaseRef ref) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + return SubRef(std::move(ref.data_)); +} + } // namespace runtime + +template +using NodePtr = runtime::ObjectPtr; + } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 2bfa3323e4f1..649a5058a9a5 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -40,7 +40,6 @@ #include "module.h" #include "ndarray.h" #include "object.h" -#include "node_base.h" // Whether use TVM runtime in header only mode. 
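The slot-based IsInstance check above is what makes downcasting cheap: a base class reserves _type_child_slots consecutive indices, so the common case is a constant-time range test on type_index_. A hedged sketch of a hierarchy that exercises it; the FINAL companion macro and the exact member spellings are assumed, not taken from this patch:

    class BaseObj : public tvm::runtime::Object {
     public:
      static constexpr const uint32_t _type_index = tvm::runtime::TypeIndex::kDynamic;
      static constexpr const char* _type_key = "test.BaseObj";
      static constexpr uint32_t _type_child_slots = 1;
      TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, tvm::runtime::Object);
    };

    class LeafObj : public BaseObj {
     public:
      static constexpr const char* _type_key = "test.LeafObj";
      TVM_DECLARE_FINAL_OBJECT_INFO(LeafObj, BaseObj);
    };

    // With the indices allocated, ptr->IsInstance<BaseObj>() is a range check on
    // type_index_, GetRef<ObjectRef>(leaf_ptr) recovers a handle from a raw
    // pointer, and Downcast<SomeRef>(ref) fails loudly with both type keys.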
#ifndef TVM_RUNTIME_HEADER_ONLY @@ -52,6 +51,8 @@ namespace tvm { class Integer; class DataType; class Expr; +class Node; +class NodeRef; namespace runtime { @@ -490,9 +491,12 @@ class TVMPODValue_ { return NDArray(static_cast(value_.v_handle)); } operator ObjectRef() const { - if (type_code_ == kNull) return ObjectRef(ObjectPtr(nullptr)); + if (type_code_ == kNull) { + return ObjectRef(ObjectPtr(nullptr)); + } TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); - return ObjectRef(ObjectPtr(static_cast(value_.v_handle))); + return ObjectRef( + ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); @@ -512,9 +516,14 @@ class TVMPODValue_ { CHECK_LT(type_code_, kExtEnd); return static_cast(value_.v_handle)[0]; } + template::value>::type> + inline bool IsObjectRef() const; int type_code() const { return type_code_; } + /*! * \brief return handle as specific pointer type. * \tparam T the data type. @@ -567,6 +576,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. operator std::string() const { @@ -616,15 +626,9 @@ class TVMArgValue : public TVMPODValue_ { typename = typename std::enable_if< std::is_class::value>::type> inline operator T() const; - template::value>::type> - inline bool IsNodeType() const; inline operator tvm::DataType() const; inline operator tvm::Expr() const; inline operator tvm::Integer() const; - // get internal node ptr, if it is node - inline NodePtr& node_sptr(); }; /*! @@ -663,6 +667,8 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator ObjectRef; + using TVMPODValue_::IsObjectRef; + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } @@ -760,11 +766,19 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(ObjectRef other) { - this->Clear(); - type_code_ = kObjectHandle; - // move the handle out - value_.v_handle = other.data_.data_; - other.data_.data_ = nullptr; + return operator=(std::move(other.data_)); + } + template + TVMRetValue& operator=(ObjectPtr other) { + if (other.data_ != nullptr) { + this->Clear(); + type_code_ = kObjectHandle; + // move the handle out + value_.v_handle = other.data_; + other.data_ = nullptr; + } else { + SwitchToPOD(kNull); + } return *this; } TVMRetValue& operator=(PackedFunc f) { @@ -814,7 +828,7 @@ class TVMRetValue : public TVMPODValue_ { } /*! 
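Return values follow the same ownership model: assigning an ObjectRef into a TVMRetValue now moves the underlying ObjectPtr handle out instead of boxing a shared NodePtr. A minimal sketch; ref can be any ObjectRef-derived handle:

    void SetReturn(tvm::runtime::TVMRetValue* rv, tvm::runtime::ObjectRef ref) {
      *rv = std::move(ref);  // operator=(ObjectRef) forwards to operator=(ObjectPtr<Object>)
    }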
\return The value field, if the data is POD */ const TVMValue& value() const { - CHECK(type_code_ != kNodeHandle && + CHECK(type_code_ != kObjectHandle && type_code_ != kFuncHandle && type_code_ != kModuleHandle && type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; @@ -827,8 +841,6 @@ class TVMRetValue : public TVMPODValue_ { inline operator T() const; template inline TNodeRef AsNodeRef() const; - inline TVMRetValue& operator=(const NodeRef& other); - inline TVMRetValue& operator=(const NodePtr& other); // type related inline operator tvm::DataType() const; inline TVMRetValue& operator=(const tvm::DataType& other); @@ -857,11 +869,6 @@ class TVMRetValue : public TVMPODValue_ { *this = other.operator NDArray(); break; } - case kNodeHandle: { - SwitchToClass >( - kNodeHandle, *other.template ptr >()); - break; - } case kObjectHandle: { *this = other.operator ObjectRef(); break; @@ -908,7 +915,6 @@ class TVMRetValue : public TVMPODValue_ { case kStr: delete ptr(); break; case kFuncHandle: delete ptr(); break; case kModuleHandle: delete ptr(); break; - case kNodeHandle: delete ptr >(); break; case kNDArrayContainer: { static_cast(value_.v_handle)->DecRef(); break; @@ -939,7 +945,6 @@ inline const char* TypeCode2Str(int type_code) { case kBytes: return "bytes"; case kHandle: return "handle"; case kNull: return "NULL"; - case kNodeHandle: return "NodeHandle"; case kArrayHandle: return "ArrayHandle"; case kTVMType: return "TVMType"; case kTVMContext: return "TVMContext"; @@ -1057,8 +1062,6 @@ inline PackedFunc::FType PackedFunc::body() const { return body_; } - - // internal namespace namespace detail { @@ -1163,8 +1166,12 @@ class TVMArgsSetter { type_codes_[i] = kNDArrayContainer; } void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kObjectHandle; + if (value.defined()) { + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kObjectHandle; + } else { + type_codes_[i] = kNull; + } } void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) if (value.type_code() == kStr) { @@ -1181,8 +1188,6 @@ class TVMArgsSetter { typename = typename std::enable_if< extension_type_info::code != 0>::type> inline void operator()(size_t i, const T& value) const; - // NodeRef related extenstions: in tvm/packed_func_ext.h - inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const tvm::DataType& t) const; private: diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index af3e929ac3fa..36265667e5b6 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -56,7 +56,7 @@ enum AttachType : int { class Stage : public NodeRef { public: Stage() {} - explicit Stage(NodePtr n) : NodeRef(n) {} + explicit Stage(ObjectPtr n) : NodeRef(n) {} /*! * \brief create a new schedule for op. * \param op The operator in the schedule @@ -280,7 +280,7 @@ class Stage : public NodeRef { class Schedule : public NodeRef { public: Schedule() {} - explicit Schedule(NodePtr n) : NodeRef(n) {} + explicit Schedule(ObjectPtr n) : NodeRef(n) {} /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -403,7 +403,7 @@ class Schedule : public NodeRef { class IterVarRelation : public NodeRef { public: IterVarRelation() {} - explicit IterVarRelation(NodePtr n) : NodeRef(n) {} + explicit IterVarRelation(ObjectPtr n) : NodeRef(n) {} /*! 
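One behavioural detail worth calling out in the TVMArgsSetter change above: an undefined ObjectRef is now passed as kNull instead of as an object handle, so callees keep their usual null checks. A sketch, with f standing for an arbitrary PackedFunc:

    void CallWithUndef(const tvm::runtime::PackedFunc& f) {
      tvm::runtime::ObjectRef undef;  // defined() == false
      f(undef);                       // the callee sees type code kNull for this slot
    }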
* \brief access the internal node container * \return the pointer to the internal node container @@ -417,7 +417,7 @@ class IterVarRelation : public NodeRef { class IterVarAttr : public NodeRef { public: IterVarAttr() {} - explicit IterVarAttr(NodePtr n) : NodeRef(n) {} + explicit IterVarAttr(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -745,25 +745,25 @@ class SingletonNode : public IterVarRelationNode { // implementations inline const StageNode* Stage::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline StageNode* Stage::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const ScheduleNode* Schedule::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline ScheduleNode* Schedule::operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } inline const IterVarRelationNode* IterVarRelation::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline const IterVarAttrNode* IterVarAttr::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm #endif // TVM_SCHEDULE_H_ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index f37cc7bed7d1..6471c9c69a62 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -50,7 +50,7 @@ class Tensor : public NodeRef { public: /*! \brief default constructor, used internally */ Tensor() {} - explicit Tensor(NodePtr n) : NodeRef(n) {} + explicit Tensor(ObjectPtr n) : NodeRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -141,7 +141,7 @@ class Operation : public ir::FunctionRef { public: /*! \brief default constructor */ Operation() {} - explicit Operation(NodePtr n) : FunctionRef(n) {} + explicit Operation(ObjectPtr n) : FunctionRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -189,7 +189,7 @@ class TensorNode : public Node { // Implementations of inline functions inline const TensorNode* Tensor::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } inline size_t Tensor::ndim() const { @@ -250,19 +250,17 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::Operation> { - std::size_t operator()(const ::tvm::Operation& k) const { - return k.hash(); - } +struct hash<::tvm::Operation> : public ::tvm::NodeHash { }; template <> struct hash<::tvm::Tensor> { std::size_t operator()(const ::tvm::Tensor& k) const { + ::tvm::NodeHash hasher; if (k.defined() && k->op.defined()) { - return k->op.hash(); + return hasher(k->op); } else{ - return k.hash(); + return hasher(k); } } }; diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index b5ca6eb4358b..152a27f6e2a9 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -112,7 +112,7 @@ class TensorIntrinNode : public Node { }; inline const TensorIntrinNode* TensorIntrin::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } // Internal node container of tensor intrinsic calling. 
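The std::hash specializations above now route through NodeHash, so operations and tensors keep working as keys in standard containers without bespoke hash() calls. A sketch, with op standing for any defined Operation:

    void Visit(const tvm::Operation& op) {
      std::unordered_set<tvm::Operation> visited;
      visited.insert(op);        // hashing goes through NodeHash
      CHECK(visited.count(op));  // equality is ObjectRef::operator== (pointer identity)
    }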
@@ -170,7 +170,7 @@ class TensorIntrinCallNode : public Node { }; inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { - return static_cast(node_.get()); + return static_cast(get()); } } // namespace tvm diff --git a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc index 1eff6c45e1fc..b4bfd4270775 100644 --- a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc @@ -242,7 +242,7 @@ extern "C" int funcInvokeCallback(TVMValue *args, for (int i = 0; i < numArgs; ++i) { TVMValue arg = args[i]; int tcode = typeCodes[i]; - if (tcode == kNodeHandle || tcode == kFuncHandle || tcode == kModuleHandle) { + if (tcode == kObjectHandle || tcode == kFuncHandle || tcode == kModuleHandle) { TVMCbArgToReturn(&arg, tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); @@ -259,8 +259,8 @@ extern "C" int funcInvokeCallback(TVMValue *args, reinterpret_cast(resourceHandle), jargs); TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); - const int prevNumStrArg = e->tvmFuncArgPushedStrs.size(); - const int prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); + const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); + const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); // convert returned (java) TVMValue to (C) TVMValue env->CallStaticVoidMethod(clsFunc, pushArgToStack, jretValue); diff --git a/nnvm/include/nnvm/compiler/util.h b/nnvm/include/nnvm/compiler/util.h index fa8b69f9b70a..9555c0e7b3ea 100644 --- a/nnvm/include/nnvm/compiler/util.h +++ b/nnvm/include/nnvm/compiler/util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -56,7 +56,7 @@ inline tvm::Array ShapeToArray(TShape shape) { * \return An Array of Expr, where each element is a constant int32 */ inline tvm::Array ShapeToIntArray(TShape shape) { - return tvm::Array(ShapeToArray(shape).node_); + return tvm::Downcast >(ShapeToArray(shape)); } } // namespace compiler } // namespace nnvm diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index 542455969b8b..5ce78d1d58d6 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -388,6 +388,9 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs") *rv = ret; }); +TVM_REGISTER_NODE_TYPE(GraphFuncNode); +TVM_REGISTER_NODE_TYPE(GraphCacheEntryNode); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const GraphFuncNode *op, IRPrinter *p) { p->stream << "GraphFunc(name=" << op->func_name diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h index 35287f5a9358..e8d33cb4be7e 100644 --- a/nnvm/src/compiler/compile_engine.h +++ b/nnvm/src/compiler/compile_engine.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -92,7 +92,7 @@ class GraphCacheEntry : public ::tvm::NodeRef { GraphCacheEntry() {} explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} GraphCacheEntryNode* operator->() { - return static_cast(node_.get()); + return static_cast(get_mutable()); } using ContainerType = GraphCacheEntryNode; }; diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h index 3a847de83d9f..7b324ba100ad 100644 --- a/nnvm/src/compiler/graph_runtime.h +++ b/nnvm/src/compiler/graph_runtime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include diff --git a/nnvm/src/compiler/packed_func_ext.cc b/nnvm/src/compiler/packed_func_ext.cc index bbcc62a99ad8..45f1451663e6 100644 --- a/nnvm/src/compiler/packed_func_ext.cc +++ b/nnvm/src/compiler/packed_func_ext.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
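The new TVM_REGISTER_NODE_TYPE calls in compile_engine.cc above appear to be needed because, under the unified object protocol, a node class that crosses the FFI must be registered so its type key and runtime index resolve at runtime. A hedged sketch with a purely illustrative node:

    struct MyGraphNode : public ::tvm::Node {
      std::string func_name;
      void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("func_name", &func_name); }
      static constexpr const char* _type_key = "example.MyGraphNode";
      TVM_DECLARE_NODE_TYPE_INFO(MyGraphNode, ::tvm::Node);
    };
    TVM_REGISTER_NODE_TYPE(MyGraphNode);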
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -115,7 +115,7 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute") const Array& out_info) -> Array { TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info); - if ((*ret.ptr<::tvm::NodePtr >())->derived_from()) { + if (ret.IsObjectRef()) { return {ret.operator Tensor()}; } else { return ret; diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 5496a4c674f6..c48ae0061f9e 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -1237,7 +1237,7 @@ Array GetIntArray(Array arr) { CHECK(!arr[i].defined() || arr[i].as()) << "Expect an int array"; } - return Array(arr.node_); + return Downcast >(arr); } NNVM_REGISTER_OP(slice_like) diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 22fb6c335dcc..2f0b5babda4d 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement +# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import """Function configuration API.""" from __future__ import absolute_import @@ -32,9 +32,8 @@ from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .node import NodeBase +from .object import ObjectBase, _set_class_node from . import object as _object -from . 
import node as _node FunctionHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p @@ -108,9 +107,9 @@ def _make_tvm_args(args, temp_args): values = (TVMValue * num_args)() type_codes = (ctypes.c_int * num_args)() for i, arg in enumerate(args): - if isinstance(arg, NodeBase): + if isinstance(arg, ObjectBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE elif arg is None: values[i].v_handle = None type_codes[i] = TypeCode.NULL @@ -148,7 +147,7 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.NODE_HANDLE + type_codes[i] = TypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle @@ -164,9 +163,6 @@ def _make_tvm_args(args, temp_args): values[i].v_handle = arg.handle type_codes[i] = TypeCode.FUNC_HANDLE temp_args.append(arg) - elif isinstance(arg, _CLASS_OBJECT): - values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE else: raise TypeError("Don't know how to handle type %s" % type(arg)) return values, type_codes, num_args @@ -226,7 +222,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE) + assert ret_tcode.value == TypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -247,7 +243,6 @@ def _handle_return_func(x): return _CLASS_FUNCTION(handle, False) # setup return handle for function type -_node.__init_by_constructor__ = __init_handle_by_constructor__ _object.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/node.py deleted file mode 100644 index 39fe0ef35525..000000000000 --- a/python/tvm/_ffi/_ctypes/node.py +++ /dev/null @@ -1,102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, protected-access -# pylint: disable=no-member, missing-docstring, not-callable -from __future__ import absolute_import - -import ctypes -from ..base import _LIB, check_call, c_str -from ..node_generic import _set_class_node_base -from .types import TVMValue, TypeCode -from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func - -NodeHandle = ctypes.c_void_p -__init_by_constructor__ = None - -"""Maps node type to its constructor""" -NODE_TYPE = {} - -def _register_node(index, cls): - """register node class""" - NODE_TYPE[index] = cls - -def _return_node(x): - """Return node function""" - handle = x.v_handle - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - tindex = ctypes.c_int() - check_call(_LIB.TVMNodeGetTypeIndex(handle, ctypes.byref(tindex))) - cls = NODE_TYPE.get(tindex.value, NodeBase) - # Avoid calling __init__ of cls, instead directly call __new__ - # This allows child class to implement their own __init__ - node = cls.__new__(cls) - node.handle = handle - return node - - -RETURN_SWITCH[TypeCode.NODE_HANDLE] = _return_node -C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( - _return_node, TypeCode.NODE_HANDLE) - - -class NodeBase(object): - __slots__ = ["handle"] - # pylint: disable=no-member - def __del__(self): - if _LIB is not None: - check_call(_LIB.TVMNodeFree(self.handle)) - - def __getattr__(self, name): - ret_val = TVMValue() - ret_type_code = ctypes.c_int() - ret_success = ctypes.c_int() - check_call(_LIB.TVMNodeGetAttr( - self.handle, c_str(name), - ctypes.byref(ret_val), - ctypes.byref(ret_type_code), - ctypes.byref(ret_success))) - if not ret_success.value: - raise AttributeError( - "'%s' object has no attribute '%s'" % (str(type(self)), name)) - return RETURN_SWITCH[ret_type_code.value](ret_val) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. 
- """ - # assign handle first to avoid error raising - self.handle = None - handle = __init_by_constructor__(fconstructor, args) - if not isinstance(handle, NodeHandle): - handle = NodeHandle(handle) - self.handle = handle - -_set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 5ddceb166677..c3ae56822198 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -21,6 +21,7 @@ import ctypes from ..base import _LIB, check_call from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func +from ..node_generic import _set_class_node_base ObjectHandle = ctypes.c_void_p @@ -29,6 +30,13 @@ """Maps object type to its constructor""" OBJECT_TYPE = {} +_CLASS_NODE = None + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class + + def _register_object(index, cls): """register object class""" OBJECT_TYPE[index] = cls @@ -40,7 +48,7 @@ def _return_object(x): handle = ObjectHandle(handle) tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) - cls = OBJECT_TYPE.get(tindex.value, ObjectBase) + cls = OBJECT_TYPE.get(tindex.value, _CLASS_NODE) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) @@ -83,3 +91,6 @@ def __init_handle_by_constructor__(self, fconstructor, *args): if not isinstance(handle, ObjectHandle): handle = ObjectHandle(handle) self.handle = handle + + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 76fa96376b47..4b7b2c88ffa5 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -31,13 +31,12 @@ cdef enum TVMTypeCode: kTVMType = 5 kTVMContext = 6 kArrayHandle = 7 - kNodeHandle = 8 + kObjectHandle = 8 kModuleHandle = 9 kFuncHandle = 10 kStr = 11 kBytes = 12 kNDArrayContainer = 13 - kObjectHandle = 14 kExtBegin = 15 cdef extern from "tvm/runtime/c_runtime_api.h": @@ -78,7 +77,7 @@ ctypedef void* TVMStreamHandle ctypedef void* TVMRetValueHandle ctypedef void* TVMFunctionHandle ctypedef void* ObjectHandle -ctypedef void* NodeHandle + ctypedef struct TVMNDArrayContainer: DLTensor dl_tensor @@ -134,18 +133,6 @@ cdef extern from "tvm/runtime/c_runtime_api.h": int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) -cdef extern from "tvm/c_dsl_api.h": - int TVMNodeFree(NodeHandle handle) - int TVMNodeTypeKey2Index(const char* type_key, - int* out_index) - int TVMNodeGetTypeIndex(NodeHandle handle, - int* out_index) - int TVMNodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* out_value, - int* out_type_code, - int* out_success) - cdef inline py_str(const char* x): if PY_MAJOR_VERSION < 3: return x diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index a9349338fc6a..cbf9d5859046 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -17,7 +17,7 @@ include "./base.pxi" include "./object.pxi" -include "./node.pxi" +# include "./node.pxi" include "./function.pxi" include "./ndarray.pxi" diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index ceacf7407170..a2360427b6c7 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -41,10 +41,9 @@ cdef int tvm_callback(TVMValue* args, for i in range(num_args): value = args[i] tcode = type_codes[i] - if (tcode == kNodeHandle or 
+ if (tcode == kObjectHandle or tcode == kFuncHandle or tcode == kModuleHandle or - tcode == kObjectHandle or tcode > kExtBegin): CALL(TVMCbArgToReturn(&value, tcode)) @@ -98,9 +97,9 @@ cdef inline int make_arg(object arg, list temp_args) except -1: """Pack arguments into c args tvm call accept""" cdef unsigned long long ptr - if isinstance(arg, NodeBase): - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle + if isinstance(arg, ObjectBase): + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle elif isinstance(arg, NDArrayBase): value[0].v_handle = (arg).chandle tcode[0] = (kNDArrayContainer if @@ -152,12 +151,9 @@ cdef inline int make_arg(object arg, temp_args.append(tstr) elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_node(arg) - value[0].v_handle = (arg).chandle - tcode[0] = kNodeHandle - temp_args.append(arg) - elif isinstance(arg, _CLASS_OBJECT): value[0].v_handle = (arg).chandle tcode[0] = kObjectHandle + temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): value[0].v_handle = c_handle(arg.handle) tcode[0] = kModuleHandle @@ -188,9 +184,7 @@ cdef inline bytearray make_ret_bytes(void* chandle): cdef inline object make_ret(TVMValue value, int tcode): """convert result to return value.""" - if tcode == kNodeHandle: - return make_ret_node(value.v_handle) - elif tcode == kObjectHandle: + if tcode == kObjectHandle: return make_ret_object(value.v_handle) elif tcode == kNull: return None @@ -314,6 +308,7 @@ cdef class FunctionBase: _CLASS_FUNCTION = None _CLASS_MODULE = None _CLASS_OBJECT = None +_CLASS_NODE = None def _set_class_module(module_class): """Initialize the module.""" @@ -327,3 +322,7 @@ def _set_class_function(func_class): def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class + +def _set_class_node(node_class): + global _CLASS_NODE + _CLASS_NODE = node_class diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/node.pxi deleted file mode 100644 index 5e0c366e5600..000000000000 --- a/python/tvm/_ffi/_cython/node.pxi +++ /dev/null @@ -1,110 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from ... 
import _api_internal -from ..base import string_types -from ..node_generic import _set_class_node_base - -"""Maps node type to its constructor""" -NODE_TYPE = [] - -def _register_node(int index, object cls): - """register node class""" - while len(NODE_TYPE) <= index: - NODE_TYPE.append(None) - NODE_TYPE[index] = cls - - -cdef inline object make_ret_node(void* chandle): - global NODE_TYPE - cdef int tindex - cdef list node_type - cdef object cls - node_type = NODE_TYPE - CALL(TVMNodeGetTypeIndex(chandle, &tindex)) - if tindex < len(node_type): - cls = node_type[tindex] - if cls is not None: - obj = cls.__new__(cls) - else: - obj = NodeBase.__new__(NodeBase) - else: - obj = NodeBase.__new__(NodeBase) - (obj).chandle = chandle - return obj - - -cdef class NodeBase: - cdef void* chandle - - cdef _set_handle(self, handle): - cdef unsigned long long ptr - if handle is None: - self.chandle = NULL - else: - ptr = handle.value - self.chandle = (ptr) - - property handle: - def __get__(self): - if self.chandle == NULL: - return None - else: - return ctypes_handle(self.chandle) - - def __set__(self, value): - self._set_handle(value) - - def __dealloc__(self): - CALL(TVMNodeFree(self.chandle)) - - def __getattr__(self, name): - cdef TVMValue ret_val - cdef int ret_type_code, ret_succ - CALL(TVMNodeGetAttr(self.chandle, c_str(name), - &ret_val, &ret_type_code, &ret_succ)) - if ret_succ == 0: - raise AttributeError( - "'%s' object has no attribute '%s'" % (type(self), name)) - return make_ret(ret_val, ret_type_code) - - def __init_handle_by_constructor__(self, fconstructor, *args): - """Initialize the handle by calling constructor function. - - Parameters - ---------- - fconstructor : Function - Constructor function. - - args: list of objects - The arguments to the constructor - - Note - ---- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ - # avoid error raised during construction. - self.chandle = NULL - cdef void* chandle - ConstructorCall( - (fconstructor).chandle, - kNodeHandle, args, &chandle) - self.chandle = chandle - -_set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 90be6a9c5b74..9561eab94ea2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -16,6 +16,8 @@ # under the License. 
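# The hunk below wires up the same late-binding pattern used in the ctypes
# backend earlier in this patch: the FFI core keeps a module-level class slot
# that higher-level code fills in later, so returned handles can be wrapped in
# NodeBase without a circular import. A minimal sketch of the idea, using the
# names that appear elsewhere in this patch:
#
#     _CLASS_NODE = None
#
#     def _set_class_node(node_class):
#         global _CLASS_NODE
#         _CLASS_NODE = node_class
#
#     # when materializing a returned handle:
#     cls = OBJECT_TYPE.get(tindex, _CLASS_NODE)  # fall back to the node class
#     obj = cls.__new__(cls)                      # skip __init__, then attach the C handle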
"""Maps object type to its constructor""" +from ..node_generic import _set_class_node_base + OBJECT_TYPE = [] def _register_object(int index, object cls): @@ -27,6 +29,7 @@ def _register_object(int index, object cls): cdef inline object make_ret_object(void* chandle): global OBJECT_TYPE + global _CLASS_NODE cdef unsigned tindex cdef list object_type cdef object cls @@ -39,9 +42,11 @@ cdef inline object make_ret_object(void* chandle): if cls is not None: obj = cls.__new__(cls) else: - obj = ObjectBase.__new__(ObjectBase) + # default use node base class + # TODO(tqchen) change to object after Node unifies with Object + obj = _CLASS_NODE.__new__(_CLASS_NODE) else: - obj = ObjectBase.__new__(ObjectBase) + obj = _CLASS_NODE.__new__(_CLASS_NODE) (obj).chandle = chandle return obj @@ -94,3 +99,6 @@ cdef class ObjectBase: (fconstructor).chandle, kObjectHandle, args, &chandle) self.chandle = chandle + + +_set_class_node_base(ObjectBase) diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index baca89d628b8..c6c151af9053 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -21,21 +21,8 @@ import ctypes import sys from .. import _api_internal +from .object import Object, register_object, _set_class_node from .node_generic import NodeGeneric, convert_to_node, const -from .base import _LIB, check_call, c_str, py_str, _FFI_MODE - -IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError -try: - # pylint: disable=wrong-import-position - if _FFI_MODE == "ctypes": - raise ImportError() - if sys.version_info >= (3, 0): - from ._cy3.core import _register_node, NodeBase as _NodeBase - else: - from ._cy2.core import _register_node, NodeBase as _NodeBase -except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.node import _register_node, NodeBase as _NodeBase def _new_object(cls): @@ -43,20 +30,22 @@ def _new_object(cls): return cls.__new__(cls) -class NodeBase(_NodeBase): +class NodeBase(Object): """NodeBase is the base class of all TVM language AST object.""" def __repr__(self): return _api_internal._format_str(self) def __dir__(self): - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - check_call(_LIB.TVMNodeListAttrNames( - self.handle, ctypes.byref(size), ctypes.byref(plist))) - names = [] - for i in range(size.value): - names.append(py_str(plist[i])) - return names + fnames = _api_internal._NodeListAttrNames(self) + size = fnames(-1) + return [fnames(i) for i in range(size)] + + def __getattr__(self, name): + try: + return _api_internal._NodeGetAttr(self, name) + except AttributeError: + raise AttributeError( + "%s has no attribute %s" % (str(type(self)), name)) def __hash__(self): return _api_internal._raw_ptr(self) @@ -95,24 +84,6 @@ def same_as(self, other): return self.__hash__() == other.__hash__() -def register_node(type_key=None): - """register node type - - Parameters - ---------- - type_key : str or cls - The type key of the node - """ - node_name = type_key if isinstance(type_key, str) else type_key.__name__ - - def register(cls): - """internal register function""" - tindex = ctypes.c_int() - ret = _LIB.TVMNodeTypeKey2Index(c_str(node_name), ctypes.byref(tindex)) - if ret == 0: - _register_node(tindex.value, cls) - return cls - - if isinstance(type_key, str): - return register - return register(type_key) +# pylint: disable=invalid-name +register_node = register_object +_set_class_node(NodeBase) diff --git a/python/tvm/_ffi/object.py b/python/tvm/_ffi/object.py index be8b086a50f9..002fd27af0fd 100644 --- 
a/python/tvm/_ffi/object.py +++ b/python/tvm/_ffi/object.py @@ -20,25 +20,25 @@ import sys import ctypes -from .base import _FFI_MODE, check_call, _LIB, c_str +from .base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError try: - # pylint: disable=wrong-import-position + # pylint: disable=wrong-import-position,unused-import if _FFI_MODE == "ctypes": raise ImportError() if sys.version_info >= (3, 0): - from ._cy3.core import _set_class_object + from ._cy3.core import _set_class_object, _set_class_node from ._cy3.core import ObjectBase as _ObjectBase from ._cy3.core import _register_object else: - from ._cy2.core import _set_class_object + from ._cy2.core import _set_class_object, _set_class_node from ._cy2.core import ObjectBase as _ObjectBase from ._cy2.core import _register_object except IMPORT_EXCEPT: - # pylint: disable=wrong-import-position - from ._ctypes.function import _set_class_object + # pylint: disable=wrong-import-position,unused-import + from ._ctypes.function import _set_class_object, _set_class_node from ._ctypes.object import ObjectBase as _ObjectBase from ._ctypes.object import _register_object @@ -75,8 +75,15 @@ def register(cls): tindex = cls._type_index else: tidx = ctypes.c_uint() - check_call(_LIB.TVMObjectTypeKey2Index( - c_str(object_name), ctypes.byref(tidx))) + if not _RUNTIME_ONLY: + check_call(_LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + else: + # directly skip unknown objects during runtime. + ret = _LIB.TVMObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx)) + if ret != 0: + return cls tindex = tidx.value _register_object(tindex, cls) return cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 00e19459df76..2dbb67dfbf73 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -36,13 +36,12 @@ class TypeCode(object): TVM_TYPE = 5 TVM_CONTEXT = 6 ARRAY_HANDLE = 7 - NODE_HANDLE = 8 + OBJECT_HANDLE = 8 MODULE_HANDLE = 9 FUNC_HANDLE = 10 STR = 11 BYTES = 12 NDARRAY_CONTAINER = 13 - OBJECT_HANDLE = 14 EXT_BEGIN = 15 diff --git a/python/tvm/error.py b/python/tvm/error.py index b5a7ed2374b7..a6d4f701d2a6 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -49,6 +49,7 @@ def __init__(self, msg): register_error("ValueError", ValueError) register_error("TypeError", TypeError) +register_error("AttributeError", AttributeError) @register_error diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index b36715249f0a..ded5d0d13bd7 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -62,6 +62,10 @@ def compile(mod, target=None, target_host=None, params=None): compiler._compile(mod, target, target_host) return vm.Executable(compiler._get_exec()) +def enabled(): + """Whether vm profiler is enabled.""" + return hasattr(_vm, "_VMCompilerProfiler") + class VMCompilerProfiler(vm.VMCompiler): """Build Relay module to run on VM runtime.""" def __init__(self): diff --git a/python/tvm/relay/debug.py b/python/tvm/relay/debug.py index ee30f25d88c1..8887a7eb3c7c 100644 --- a/python/tvm/relay/debug.py +++ b/python/tvm/relay/debug.py @@ -17,12 +17,8 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" from __future__ import absolute_import -from .base import NodeBase, register_relay_node from ..api import register_func 
-@register_relay_node -class InterpreterState(NodeBase): - pass # pylint: disable=unused-argument def _debugger_init(expr, stack): diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index d9399492264b..848d5c00ab3f 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -71,7 +71,7 @@ macro_rules! TVMPODValue { Context(TVMContext), Handle(*mut c_void), ArrayHandle(TVMArrayHandle), - NodeHandle(*mut c_void), + ObjectHandle(*mut c_void), ModuleHandle(TVMModuleHandle), FuncHandle(TVMFunctionHandle), NDArrayContainer(*mut c_void), @@ -92,7 +92,7 @@ macro_rules! TVMPODValue { TVMTypeCode_kTVMContext => Context($value.v_ctx), TVMTypeCode_kHandle => Handle($value.v_handle), TVMTypeCode_kArrayHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kNodeHandle => NodeHandle($value.v_handle), + TVMTypeCode_kObjectHandle => ObjectHandle($value.v_handle), TVMTypeCode_kModuleHandle => ModuleHandle($value.v_handle), TVMTypeCode_kFuncHandle => FuncHandle($value.v_handle), TVMTypeCode_kNDArrayContainer => NDArrayContainer($value.v_handle), @@ -124,7 +124,7 @@ macro_rules! TVMPODValue { TVMTypeCode_kArrayHandle, ) }, - NodeHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNodeHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kObjectHandle), ModuleHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle), FuncHandle(val) => ( diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 948711276304..01d0c58cfc5d 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -264,7 +264,7 @@ unsafe extern "C" fn tvm_callback( for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMTypeCode_kNodeHandle as c_int + if tcode == ffi::TVMTypeCode_kObjectHandle as c_int || tcode == ffi::TVMTypeCode_kFuncHandle as c_int || tcode == ffi::TVMTypeCode_kModuleHandle as c_int { diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index f31f02b1eaf4..c57e2afaa8eb 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -117,8 +117,7 @@ TVM_REGISTER_API("arith._CreateAnalyzer") }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - auto& sptr = args[1].node_sptr(); - if (sptr->is_type()) { + if (args[1].IsObjectRef()) { self->Bind(args[0], args[1].operator Range()); } else { self->Bind(args[0], args[1].operator Expr()); diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 28ebb4d65005..c25c35f636e6 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,7 +30,7 @@ namespace tvm { TVM_REGISTER_API("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); + CHECK(args[0].type_code() == kObjectHandle); std::ostringstream os; os << args[0].operator NodeRef(); *ret = os.str(); @@ -38,9 +38,8 @@ TVM_REGISTER_API("_format_str") TVM_REGISTER_API("_raw_ptr") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(args[0].type_code() == kNodeHandle); - *ret = reinterpret_cast( - args[0].node_sptr().get()); + CHECK(args[0].type_code() == kObjectHandle); + *ret = reinterpret_cast(args[0].value().v_handle); }); TVM_REGISTER_API("_save_json") diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index 73e26719cf15..f2ca67e6e2f9 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,7 +33,7 @@ namespace codegen { TVM_REGISTER_API("codegen._Build") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Build({args[0]}, args[1]); } else { *ret = Build(args[0], args[1]); diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index b8ee1441fe12..9312c5532302 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -18,7 +18,6 @@ */ /*! 
- * Copyright (c) 2016 by Contributors * Implementation of API functions related to IR build * \file api_ir.cc */ diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index aa0ce47b4a37..f3d6c5f6ab62 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -57,25 +57,26 @@ TVM_REGISTER_API("_str") TVM_REGISTER_API("_Array") .set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector > data; + std::vector data; for (int i = 0; i < args.size(); ++i) { if (args[i].type_code() != kNull) { - data.push_back(args[i].node_sptr()); + data.push_back(args[i].operator ObjectRef()); } else { - data.push_back(NodePtr(nullptr)); + data.push_back(ObjectRef(nullptr)); } } auto node = make_node(); node->data = std::move(data); - *ret = node; + *ret = runtime::ObjectRef(node); }); TVM_REGISTER_API("_ArrayGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { int64_t i = args[1]; - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); CHECK_LT(static_cast(i), n->data.size()) << "out of bound of array"; *ret = n->data[static_cast(i)]; @@ -83,10 +84,11 @@ TVM_REGISTER_API("_ArrayGetItem") TVM_REGISTER_API("_ArraySize") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); *ret = static_cast( - static_cast(sptr.get())->data.size()); + static_cast(ptr)->data.size()); }); TVM_REGISTER_API("_Map") @@ -98,10 +100,10 @@ TVM_REGISTER_API("_Map") for (int i = 0; i < args.num_args; i += 2) { CHECK(args[i].type_code() == kStr) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kNodeHandle) + CHECK(args[i + 1].type_code() == kObjectHandle) << "value of the map to be NodeRef"; data.emplace(std::make_pair(args[i].operator std::string(), - args[i + 1].node_sptr())); + args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); @@ -110,12 +112,12 @@ TVM_REGISTER_API("_Map") // Container node. 
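// The rewritten container handlers in this file all follow one access pattern
// introduced by this patch: verify kObjectHandle, pull the raw Object* out of
// the TVMValue, then branch with IsInstance<...>() instead of the old
// node_sptr()/is_type() pair. Roughly (MapNode/StrMapNode assumed as the
// concrete payload types):
//
//   CHECK_EQ(args[0].type_code(), kObjectHandle);
//   Object* ptr = static_cast<Object*>(args[0].value().v_handle);
//   if (ptr->IsInstance<StrMapNode>()) { /* string-keyed map */ }
//   else { CHECK(ptr->IsInstance<MapNode>()); /* node-keyed map */ }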
MapNode::ContainerType data; for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kNodeHandle) + CHECK(args[i].type_code() == kObjectHandle) << "key of str map need to be str"; - CHECK(args[i + 1].type_code() == kNodeHandle) + CHECK(args[i + 1].type_code() == kObjectHandle) << "value of map to be NodeRef"; - data.emplace(std::make_pair(args[i].node_sptr(), - args[i + 1].node_sptr())); + data.emplace(std::make_pair(args[i].operator ObjectRef(), + args[i + 1].operator ObjectRef())); } auto node = make_node(); node->data = std::move(data); @@ -125,31 +127,33 @@ TVM_REGISTER_API("_Map") TVM_REGISTER_API("_MapSize") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); *ret = static_cast(n->data.size()); } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); *ret = static_cast(n->data.size()); } }); TVM_REGISTER_API("_MapGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK(args[0].type_code() == kNodeHandle); - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - CHECK(args[1].type_code() == kNodeHandle); - auto* n = static_cast(sptr.get()); - auto it = n->data.find(args[1].node_sptr()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + CHECK(args[1].type_code() == kObjectHandle); + auto* n = static_cast(ptr); + auto it = n->data.find(args[1].operator ObjectRef()); CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; *ret = (*it).second; } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); auto it = n->data.find(args[1].operator std::string()); CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; @@ -159,16 +163,17 @@ TVM_REGISTER_API("_MapGetItem") TVM_REGISTER_API("_MapCount") .set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK(args[0].type_code() == kNodeHandle); - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); - CHECK(args[1].type_code() == kNodeHandle); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); + CHECK_EQ(args[0].type_code(), kObjectHandle); *ret = static_cast( - n->data.count(args[1].node_sptr())); + n->data.count(args[1].operator ObjectRef())); } else { - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); *ret = static_cast( n->data.count(args[1].operator std::string())); } @@ -176,9 +181,11 @@ TVM_REGISTER_API("_MapCount") TVM_REGISTER_API("_MapItems") .set_body([](TVMArgs args, TVMRetValue* ret) { - auto& sptr = args[0].node_sptr(); - if (sptr->is_type()) { - auto* n = static_cast(sptr.get()); + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + + if (ptr->IsInstance()) { + auto* n = static_cast(ptr); auto rkvs = make_node(); for (const auto& kv : n->data) { rkvs->data.push_back(kv.first); @@ -186,10 +193,10 @@ TVM_REGISTER_API("_MapItems") } *ret = rkvs; } else { - auto* n = static_cast(sptr.get()); + auto* n = static_cast(ptr); 
auto rkvs = make_node(); for (const auto& kv : n->data) { - rkvs->data.push_back(ir::StringImm::make(kv.first).node_); + rkvs->data.push_back(ir::StringImm::make(kv.first)); rkvs->data.push_back(kv.second); } *ret = rkvs; @@ -426,7 +433,7 @@ TVM_REGISTER_API("_ScheduleCacheRead") TVM_REGISTER_API("_ScheduleCacheWrite") .set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[1].IsNodeType()) { + if (args[1].IsObjectRef()) { *ret = args[0].operator Schedule() .cache_write(args[1].operator Tensor(), args[2]); } else { diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index d2352496c2b4..dd0415afd9eb 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -35,7 +35,7 @@ namespace ir { TVM_REGISTER_API("ir_pass.Simplify") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { if (args.size() > 1) { *ret = Simplify(args[0].operator Stmt(), args[1]); } else { @@ -52,7 +52,7 @@ TVM_REGISTER_API("ir_pass.Simplify") TVM_REGISTER_API("ir_pass.CanonicalSimplify") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { if (args.size() > 1) { *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); } else { @@ -69,7 +69,7 @@ TVM_REGISTER_API("ir_pass.CanonicalSimplify") TVM_REGISTER_API("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); } else { *ret = Substitute(args[0].operator Expr(), args[1].operator Map()); @@ -78,7 +78,7 @@ TVM_REGISTER_API("ir_pass.Substitute") TVM_REGISTER_API("ir_pass.Equal") .set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsNodeType()) { + if (args[0].IsObjectRef()) { *ret = Equal(args[0].operator Stmt(), args[1].operator Stmt()); } else { *ret = Equal(args[0].operator Expr(), args[1].operator Expr()); diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 177360bf2ebb..cf0e0f3c6b7a 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * Implementation of API functions related to schedule pass. * \file api_schedule.cc */ diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 89e999f73edb..64805c9e8aa0 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -18,36 +18,18 @@ */ /*! - * Copyright (c) 2016 by Contributors * Implementation of DSL API * \file dsl_api.cc */ -#include #include -#include #include #include +#include #include #include -#include -#include "../runtime/dsl_api.h" namespace tvm { namespace runtime { -/*! \brief entry to to easily hold returning information */ -struct TVMAPIThreadLocalEntry { - /*! \brief result holder for returning strings */ - std::vector ret_vec_str; - /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; - /*! \brief result holder for retruning string */ - std::string ret_str; -}; - -/*! \brief Thread local store that can be used to hold return values. 
*/ -typedef dmlc::ThreadLocalStore TVMAPIThreadLocalStore; - -using TVMAPINode = NodePtr; struct APIAttrGetter : public AttrVisitor { std::string skey; @@ -138,93 +120,71 @@ struct APIAttrDir : public AttrVisitor { } }; -class DSLAPIImpl : public DSLAPI { - public: - void NodeFree(NodeHandle handle) const final { - delete static_cast(handle); - } - void NodeTypeKey2Index(const char* type_key, - int* out_index) const final { - *out_index = static_cast(Node::TypeKey2Index(type_key)); - } - void NodeGetTypeIndex(NodeHandle handle, - int* out_index) const final { - *out_index = static_cast( - (*static_cast(handle))->type_index()); - } - void NodeGetAttr(NodeHandle handle, - const char* key, - TVMValue* ret_val, - int* ret_type_code, - int* ret_success) const final { - TVMRetValue rv; +struct NodeAPI { + static void GetAttr(TVMArgs args, TVMRetValue* ret) { + NodeRef ref = args[0]; + Node* tnode = const_cast(ref.get()); APIAttrGetter getter; - TVMAPINode* tnode = static_cast(handle); - getter.skey = key; - getter.ret = &rv; + getter.skey = args[1].operator std::string(); + getter.ret = ret; + + bool success; if (getter.skey == "type_key") { - ret_val->v_str = (*tnode)->type_key(); - *ret_type_code = kStr; - *ret_success = 1; - return; - } else if (!(*tnode)->is_type()) { - (*tnode)->VisitAttrs(&getter); - *ret_success = getter.found_ref_object || rv.type_code() != kNull; + *ret = tnode->GetTypeKey(); + success = true; + } else if (!tnode->IsInstance()) { + tnode->VisitAttrs(&getter); + success = getter.found_ref_object || ret->type_code() != kNull; } else { // specially handle dict attr - DictAttrsNode* dnode = static_cast(tnode->get()); - auto it = dnode->dict.find(key); + DictAttrsNode* dnode = static_cast(tnode); + auto it = dnode->dict.find(getter.skey); if (it != dnode->dict.end()) { - *ret_success = 1; - rv = (*it).second; + success = true; + *ret = (*it).second; } else { - *ret_success = 0; + success = false; } } - if (*ret_success) { - if (rv.type_code() == kStr || - rv.type_code() == kTVMType) { - TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get(); - e->ret_str = rv.operator std::string(); - *ret_type_code = kStr; - ret_val->v_str = e->ret_str.c_str(); - } else { - rv.MoveToCHost(ret_val, ret_type_code); - } + if (!success) { + LOG(FATAL) << "AttributeError: " << tnode->GetTypeKey() + << " object has no attributed " << getter.skey; } } - void NodeListAttrNames(NodeHandle handle, - int *out_size, - const char*** out_array) const final { - TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); - ret->ret_vec_str.clear(); - TVMAPINode* tnode = static_cast(handle); + + static void ListAttrNames(TVMArgs args, TVMRetValue* ret) { + NodeRef ref = args[0]; + Node* tnode = const_cast(ref.get()); + auto names = std::make_shared >(); APIAttrDir dir; - dir.names = &(ret->ret_vec_str); + dir.names = names.get(); - if (!(*tnode)->is_type()) { - (*tnode)->VisitAttrs(&dir); + if (!tnode->IsInstance()) { + tnode->VisitAttrs(&dir); } else { // specially handle dict attr - DictAttrsNode* dnode = static_cast(tnode->get()); + DictAttrsNode* dnode = static_cast(tnode); for (const auto& kv : dnode->dict) { - ret->ret_vec_str.push_back(kv.first); + names->push_back(kv.first); } } - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { - ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); - } - *out_array = dmlc::BeginPtr(ret->ret_vec_charp); - *out_size = static_cast(ret->ret_vec_str.size()); + + *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) { + 
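// Calling convention of this returned PackedFunc (relied on by
// NodeBase.__dir__ in node.py earlier in the patch): an argument of -1 yields
// the number of attribute names, and an index i in [0, size) yields the i-th
// name. Rough Python-side usage:
//
//   fnames = _api_internal._NodeListAttrNames(node)
//   size = fnames(-1)
//   names = [fnames(i) for i in range(size)]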
int64_t i = args[0]; + if (i == -1) { + *rv = static_cast(names->size()); + } else { + *rv = (*names)[i]; + } + }); } }; -TVM_REGISTER_GLOBAL("dsl_api.singleton") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static DSLAPIImpl impl; - void* ptr = &impl; - *rv = ptr; - }); +TVM_REGISTER_GLOBAL("_NodeGetAttr") +.set_body(NodeAPI::GetAttr); + +TVM_REGISTER_GLOBAL("_NodeListAttrNames") +.set_body(NodeAPI::ListAttrNames); + } // namespace runtime } // namespace tvm diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index acd964935c25..98e25742592d 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -36,9 +36,7 @@ Analyzer::Analyzer() int_set(this) { } -void Analyzer::Bind(const VarExpr& v, const Expr& expr) { - Var var(v.node_); - +void Analyzer::Bind(const VarExpr& var, const Expr& expr) { Expr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); @@ -49,9 +47,8 @@ void Analyzer::Bind(const VarExpr& v, const Expr& expr) { this->canonical_simplify.Update(var, new_expr); } -void Analyzer::Bind(const VarExpr& v, const Range& range) { +void Analyzer::Bind(const VarExpr& var, const Range& range) { CHECK(range.defined()); - Var var(v.node_); if (is_one(range->extent)) { this->Bind(var, range->min); } else { diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index d80e4969d5c2..02e8079c9c7b 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -629,7 +629,7 @@ Mutate_(const Mul* op, const Expr& self) { } if (const auto* bconst = b.as()) { if (a.as()) { - SumExpr ret(std::move(a.node_)); + SumExpr ret = Downcast(std::move(a)); ret.CopyOnWrite()->MulToSelf(bconst->value); return std::move(ret); } else { @@ -931,7 +931,7 @@ Mutate_(const Mod* op, const Expr& self) { int64_t new_base = psum->base % cval; if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { - SumExpr sum_expr(std::move(a.node_)); + SumExpr sum_expr = Downcast(a); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); } @@ -992,7 +992,7 @@ Mutate_(const FloorMod* op, const Expr& self) { // Simplify the offset constant if necessary. 
// floormod(x - 5, 3) => floormod(x + 1, 3) int64_t new_base = floormod(psum->base, cval); - SumExpr sum_expr(std::move(a.node_)); + SumExpr sum_expr = Downcast(std::move(a)); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv); } else { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index d5c012d302dc..168486ee0018 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -39,7 +39,7 @@ ConstIntBound::ConstIntBound( auto node = make_node(); node->min_value = min_value; node->max_value = max_value; - node_ = std::move(node); + data_ = std::move(node); } inline void PrintBoundValue(std::ostream& os, int64_t val) { diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 3c5f12a7379e..7da020efc42a 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -176,7 +176,7 @@ bool DetectClipBound( if (const Variable* v = n.as()) { if (bmap->count(v)) { if (flag == 0) { - var = Var(n.node_); + var = Downcast(n); flag = 1; } else if (flag == 1) { if (!var.same_as(n)) { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 0e24714daf1f..313b34ded034 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -40,7 +40,7 @@ IntervalSet::IntervalSet(Expr min_value, Expr max_value) { auto node = make_node(); node->min_value = std::move(min_value); node->max_value = std::move(max_value); - node_ = std::move(node); + data_ = std::move(node); } IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) { @@ -506,7 +506,7 @@ class IntervalSetEvaluator : } IntervalSet VisitExprDefault_(const Node* op) final { - DLOG(WARNING) << "cannot evaluate set type " << op->type_key(); + DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey(); return IntervalSet::Everything(); } diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index 04e166ae52c0..cda9d585ace1 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -87,7 +87,7 @@ Stmt IRMutatorWithAnalyzer:: Mutate_(const AttrStmt* op, const Stmt& s) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); diff --git a/src/arithmetic/ir_visitor_with_analyzer.h b/src/arithmetic/ir_visitor_with_analyzer.h index 71eea50e4c72..918f2e89501f 100644 --- a/src/arithmetic/ir_visitor_with_analyzer.h +++ b/src/arithmetic/ir_visitor_with_analyzer.h @@ -47,7 +47,7 @@ class IRVisitorWithAnalyzer final : public IRVisitor { void Visit_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 08454dd0ef5a..9e363e7cf99a 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -41,7 +41,7 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { node->coeff = coeff; node->base = base; // finish construction. 
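// The assignment below changes because reference types now keep their payload
// in ObjectRef::data_ rather than the old NodeRef::node_ field; the
// constructor is otherwise unchanged. Roughly (ModularSetNode assumed as the
// payload type):
//
//   ModularSet::ModularSet(int64_t coeff, int64_t base) {
//     auto node = make_node<ModularSetNode>();
//     node->coeff = coeff;
//     node->base = base;
//     data_ = std::move(node);  // previously: node_ = std::move(node);
//   }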
- node_ = std::move(node); + data_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 3f1c32243a23..66340e9c9021 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -34,6 +34,7 @@ namespace tvm { TVM_REGISTER_NODE_TYPE(TargetNode); +TVM_REGISTER_NODE_TYPE(GenericFuncNode); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const TargetNode *op, IRPrinter *p) { @@ -51,9 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) */ Target CreateTarget(const std::string& target_name, const std::vector& options) { - auto target = Target(make_node()); - auto t = static_cast(target.node_.get()); - + auto t = make_node(); t->target_name = target_name; std::string libs_flag = "-libs="; @@ -137,7 +136,7 @@ Target CreateTarget(const std::string& target_name, return target::stackvm(); } - return target; + return Target(t); } TVM_REGISTER_API("_TargetCreate") @@ -674,7 +673,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); struct GenericFunc::Manager { - std::unordered_map > fmap; + std::unordered_map fmap; // mutex std::mutex mutex; @@ -694,10 +693,11 @@ GenericFunc GenericFunc::Get(const std::string& name) { if (it == m->fmap.end()) { auto f = make_node(); f->name_ = name; - m->fmap[name] = f; - return GenericFunc(f); + auto gf = GenericFunc(f); + m->fmap[name] = gf; + return gf; } else { - return GenericFunc(it->second); + return it->second; } } @@ -707,12 +707,12 @@ void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) auto it = m->fmap.find(name); CHECK(it == m->fmap.end()) << "GenericFunc already registered " << name; func->name_ = name; - m->fmap[name] = func.node_; + m->fmap[name] = func; } GenericFunc& GenericFunc::set_default(const PackedFunc value, - bool allow_override) { - auto node = static_cast(node_.get()); + bool allow_override) { + auto node = static_cast(operator->()); if (!allow_override) { CHECK(node->generic_func_ == nullptr) << "Generic function already registered for " << node->name_; @@ -736,7 +736,7 @@ GenericFunc& GenericFunc::register_func(const std::vector& tags, } void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { - auto node = static_cast(node_.get()); + auto node = static_cast(get()); auto target = Target::Current(true); PackedFunc func; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index ecf62ab0cfac..ab203f2aa28a 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -806,7 +806,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) { void CodeGenC::VisitStmt_(const AttrStmt* op) { if (op->attr_key == ir::attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_idmap_.count(iv->var.get())) { BindThreadIndex(iv); diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index d009290bb2fe..de54e242ff40 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -1173,7 +1173,7 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) { void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv); diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 7caf3a258b6f..6a3b0571c9ab 100644 --- 
a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -606,7 +606,7 @@ void CodeGenSPIRV::VisitStmt_(const Allocate* op) { void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent) { - IterVar iv(op->node.node_); + IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 54616adc214e..778b6b1a7811 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -300,7 +300,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) { PrintStmt(op->body); indent_ -= tab_; } else if (op->attr_key == ir::attr::realize_scope) { - auto v = FunctionRef(op->node.node_); + auto v = Downcast(op->node); alloc_storage_scope_[v] = op->value.as()->value; PrintStmt(op->body); } else { @@ -408,7 +408,7 @@ void CodeGenHybrid::PrintIndent() { std::string CodeGenHybrid::GetVarID(const Variable *v) { if (binds_.count(v)) return binds_[v]; - auto key = std::make_pair(v->GetNodePtr().get(), 0); + auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 498838fc908f..866756996f8d 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file codegen_hybrid.h * \brief Common utilities to generated C style code. */ diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 995dfb392e87..b9391e4895b9 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -44,17 +44,17 @@ class AttrFunctor; #define ATTR_FUNCTOR_DISPATCH(OP) \ vtable.template set_dispatch( \ - [](const NodeRef& n, TSelf* self, Args... args) { \ - return self->VisitAttr_(static_cast(n.node_.get()), \ + [](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitAttr_(static_cast(n.get()), \ std::forward(args)...); \ }); \ // A functor for common attribute information. template -class AttrFunctor { +class AttrFunctor { private: - using TSelf = AttrFunctor; - using FType = tvm::IRFunctor; + using TSelf = AttrFunctor; + using FType = tvm::IRFunctor; public: /*! \brief the result type of this functor */ @@ -65,7 +65,7 @@ class AttrFunctor { * \param args Additional arguments. * \return The result of the call */ - virtual R VisitAttr(const NodeRef& n, Args... args) { + virtual R VisitAttr(const ObjectRef& n, Args... args) { static FType vtable = InitVTable(); if (vtable.can_dispatch(n)) { return vtable(n, this, std::forward(args)...); @@ -73,7 +73,7 @@ class AttrFunctor { return VisitAttrDefault_(n.get(), std::forward(args)...); } } - virtual R VisitAttrDefault_(const Node* node, Args... args) = 0; + virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -143,60 +143,60 @@ class AttrFunctor { }; class AttrsEqualHandler : - protected AttrFunctor { + protected AttrFunctor { public: /*! * \brief Check if lhs equals rhs * \param lhs The left operand. * \param rhs The right operand. 
*/ - bool Equal(const NodeRef& lhs, const NodeRef& rhs); + bool Equal(const ObjectRef& lhs, const ObjectRef& rhs); protected: - bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final; - bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final; - bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::IntImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::UIntImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloatImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::StringImm* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::GT* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::LT* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::LE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::EQ* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::NE* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::And* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Or* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Not* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Cast* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Call* lhs, const NodeRef& other) final; - bool VisitAttr_(const ir::Select* lhs, const NodeRef& other) final; + bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::IntImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::UIntImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloatImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::StringImm* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Add* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Sub* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Mul* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Div* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Mod* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Max* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::GE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::GT* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::LT* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::LE* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::EQ* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::NE* 
lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::And* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Or* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Not* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Cast* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Call* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::Select* lhs, const ObjectRef& other) final; }; class AttrsHashHandler : - protected AttrFunctor { + protected AttrFunctor { public: /*! * \brief Get hash value of node * \param node The node to be hashed. */ - size_t Hash(const NodeRef& node) { + size_t Hash(const ObjectRef& node) { if (!node.defined()) return 0; return this->VisitAttr(node); } protected: - size_t VisitAttrDefault_(const Node* lhs) final; + size_t VisitAttrDefault_(const Object* lhs) final; size_t VisitAttr_(const ir::IntImm* lhs) final; size_t VisitAttr_(const ir::UIntImm* lhs) final; size_t VisitAttr_(const ir::FloatImm* lhs) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index c5b14ac577ec..a299e17996e0 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -40,7 +40,7 @@ void DictAttrsNode::InitByPackedArgs( for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; - if (val.type_code() == kNodeHandle) { + if (val.type_code() == kObjectHandle) { dict.Set(key, val.operator NodeRef()); } else if (val.type_code() == kStr) { dict.Set(key, Expr(val.operator std::string())); @@ -72,14 +72,14 @@ TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); using namespace ir; // Equal handler. -bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) { +bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) { if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; return this->VisitAttr(lhs, rhs); } -bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) { - if (lhs->derived_from()) { +bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) { + if (lhs->IsInstance()) { AttrsEqual equal; equal.handler_ = this; return static_cast(lhs)->ContentEqual( @@ -88,58 +88,58 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) return lhs == other.get(); } -bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->value == rhs->value; } return false; } -bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { if 
(rhs->data.size() != lhs->data.size()) return false; for (size_t i = 0; i < lhs->data.size(); ++i) { - if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false; + if (!Equal(lhs->data[i], rhs->data[i])) return false; } } return true; } -bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { if (rhs->data.size() != lhs->data.size()) return false; for (const auto& kv : lhs->data) { auto it = rhs->data.find(kv.first); if (it == rhs->data.end()) return false; - if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false; + if (!Equal(kv.second, it->second)) return false; } } return true; } #define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \ - bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \ + bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \ if (const auto* rhs = other.as()) { \ if (!Equal(lhs->a, rhs->a)) return false; \ if (!Equal(lhs->b, rhs->b)) return false; \ @@ -167,7 +167,7 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(NE); TVM_DEFINE_ATTRS_BINOP_EQUAL(And); TVM_DEFINE_ATTRS_BINOP_EQUAL(Or); -bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return Equal(lhs->a, rhs->a); } else { @@ -175,7 +175,7 @@ bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) { } } -bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { if (lhs->type != rhs->type) return false; return Equal(lhs->value, rhs->value); @@ -184,7 +184,7 @@ bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) { } } -bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const ObjectRef& other) { if (const auto* rhs = other.as()) { return lhs->name == rhs->name && @@ -196,7 +196,7 @@ bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) { } } -bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) { +bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const ObjectRef& other) { if (const auto* rhs = other.as