From 41f6a7f8b7ea8e3aa0726cd38715d82fe24f0fba Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Wed, 23 Apr 2025 10:24:21 +0530 Subject: [PATCH 01/13] add instance norm support --- include/tvm/relax/attrs/nn.h | 16 +++ .../torch/exported_program_translator.py | 36 ++++++ python/tvm/relax/op/nn/__init__.py | 1 + python/tvm/relax/op/nn/nn.py | 114 ++++++++++++++++++ python/tvm/relax/op/op_attrs.py | 3 + python/tvm/relax/transform/legalize_ops/nn.py | 11 ++ python/tvm/topi/nn/instance_norm.py | 92 +++++++++++--- src/relax/op/nn/nn.cc | 110 +++++++++++++++++ 8 files changed, 368 insertions(+), 15 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index f0f80ad8f4a0..01b40ec2c694 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -534,6 +534,22 @@ struct GroupNormAttrs : public tvm::AttrsNode { } }; // struct GroupNormAttrs +/*! \brief Attributes used in group_norm operator */ +struct InstanceNormAttrs : public tvm::AttrsNode { + Array axes; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(InstanceNormAttrs, "relax.attrs.InstanceNormAttrs") { + TVM_ATTR_FIELD(axes).describe("The axis along which the normalization is applied."); + TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).describe( + "Indicating if the beta offset will be added to the normalized tensor."); + TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); + } +}; // struct InstanceNormAttrs + /*! \brief Attributes used in rms_norm operator */ struct RMSNormAttrs : public tvm::AttrsNode { Array axes; diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index df532fd1ea04..8ad2e1c3bd43 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -25,6 +25,7 @@ import torch import tvm from tvm import relax +import tvm.topi from .base_fx_graph_translator import BaseFXGraphImporter @@ -99,7 +100,36 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: training=training, )[0] ) + def _instance_norm(self, node: fx.Node, training) -> relax.Var: + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + ignore_running_stats = ( + node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) + ) + track_running_stats = not ignore_running_stats + momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) + eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) + + if track_running_stats: + training = True + return self.block_builder.emit( + relax.op.nn.instance_norm( + data=x, + gamma=weight, + beta=bias, + axis=[0,1], # Always over channel + epsilon=eps, + ) + ) + # return self.block_builder.emit_te(tvm.topi.nn.instance_norm,x,weight,bias,[0,1],eps) def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: # This method is called for batch_norm in training mode # TODO does not have correctness! @@ -113,6 +143,11 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: training = False return self._batch_norm(node, training) + def _instance_norm_no_training(self, node: fx.Node) -> relax.Var: + # This method is called for batch_norm in eval mode + training = False + return self._instance_norm(node, training) + def _group_norm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_groups = node.args[1] @@ -387,6 +422,7 @@ def create_convert_map( "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "batch_norm.default": self._batch_norm_legit_no_training, + "instance_norm.default": self._instance_norm_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 08ecda275c3e..090b78d94397 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -51,4 +51,5 @@ silu, softmax, softplus, + instance_norm ) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index e234e8ad7b18..b9237f2bb91e 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -2071,3 +2071,117 @@ def attention_var_len( causal_mask, window_size, ) # type: ignore + + +def instance_norm( + data: Expr, + gamma: Expr, + beta: Expr, + axis: int, + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, + +) -> Expr: + r""" + Batch normalization layer (Ioffe and Szegedy, 2014). + + Normalizes the input at each batch, i.e. applies a transformation + that maintains the mean activation close to 0 and the activation + standard deviation close to 1. + + .. math:: + + data\_mean[i] = mean(data[:,i,:,...]) \\ + data\_var[i] = var(data[:,i,:,...]) + + Both *mean* and *var* returns a scalar by treating the input as a vector. + + Then compute the normalized output, which has the same shape as input, as following: + + .. math:: + + out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} + * gamma[i] + beta[i] + + Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` + have shape *(k,)*. + + Besides the inputs and the outputs, this operator accepts two auxiliary + states, ``moving_mean`` and ``moving_var``, which are *k*-length + vectors. They are global statistics for the whole dataset, which are updated by + + .. code:: python + + moving_mean = moving_mean * momentum + data_mean * (1 - momentum) + moving_var = moving_var * momentum + data_var * (1 - momentum) + + The parameter ``axis`` specifies which axis of the input shape denotes + the 'channel' (separately normalized groups). The default is 1. + Specifying -1 sets the channel axis to be the last item in the input shape. + + .. note:: + + This operator has two modes: + + - Training mode. + - Use the mean and var computed from THIS batch to normalize. + - Update and then return the running mean and running var. + + - Inference mode. + - Use the running_mean and running_var parameters to normalize. + - Do not update the running mean and running var. Just return the original value. + + In the legalization stage, this operator will be legalized to the training mode by default. + + You can use tvm.relax.transform.DecomposeOpsForInference to decompose the operator, so it + executes the inference mode computation. Similarly, use + tvm.relax.transform.DecomposeOpsForTraining to execute the training mode computation. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + moving_mean : relax.Expr + Running mean of input. + + moving_var : relax.Expr + Running variance of input. + + axis : int + The axis along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + momentum : float + The value used for the moving_mean and moving_var update. + + training : bool + A boolean value to indicate whether training or in eval mode. By default. + relax instance_norm is training mode. To transform it to inference mode, + can use DecomposeOpsForInference. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.instance_norm( # type: ignore + data, gamma, beta, axis, epsilon, center, scale, + ) \ No newline at end of file diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index fe527e38e8a8..c2032af30933 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -93,6 +93,9 @@ class BatchNormAttrs(Attrs): class LayerNormAttrs(Attrs): """Attributes used in layer_norm operator""" +@tvm._ffi.register_object("relax.attrs.InstanceNormAttrs") +class InstanceNormAttrs(Attrs): + """Attributes used in instance_norm operator""" @tvm._ffi.register_object("relax.attrs.DropoutAttrs") class DropoutAttrs(Attrs): diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index f18ad6097f06..6bf2a976b8d8 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -633,6 +633,17 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: call.attrs.epsilon, ) +@register_legalize("relax.nn.instance_norm") +def _nn_instance_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.instance_norm, + data=call.args[0], + gamma=call.args[1], + beta=call.args[2], + axis=call.attrs.axes, + epsilon=call.attrs.epsilon, + ) + @register_legalize("relax.nn.rms_norm") def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr: diff --git a/python/tvm/topi/nn/instance_norm.py b/python/tvm/topi/nn/instance_norm.py index d119b57bfdee..386664223487 100644 --- a/python/tvm/topi/nn/instance_norm.py +++ b/python/tvm/topi/nn/instance_norm.py @@ -16,32 +16,94 @@ # under the License. """Instance normalization operator.""" from .. import cpp +from tvm import te +from tvm import topi +from functools import reduce +from typing import Union,List +# def instance_norm(data, gamma, beta, axis, epsilon=1e-5): +# """Instance normalization operator. -def instance_norm(data, gamma, beta, axis, epsilon=1e-5): - """Instance normalization operator. +# Parameters +# ---------- +# data : tvm.te.Tensor +# N-D with shape (d_0, d_1, ..., d_{N-1}) + +# gamma: tvm.te.Tensor +# K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k + +# beta: tvm.te.Tensor +# Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k + +# axis : list of int +# Axis over the normalization applied (the axis along which the mean and variance are +# computed) + +# epsilon : float +# The epsilon value to avoid division by zero. + +# Returns +# ------- +# result : tvm.te.Tensor +# N-D with shape (d_0, d_1, ..., d_{N-1}) +# """ +# return cpp.nn.instance_norm(data, gamma, beta, axis, epsilon) + +def instance_norm( + data: te.Tensor, + gamma: te.Tensor, + beta: te.Tensor, + axis: Union[int, List[int]] = [0, 1], + epsilon: float = 1e-5, +) -> te.Tensor: + """Instance normalization over spatial dimensions. + + Normalizes each instance in a batch independently per channel, + typically used in style transfer and vision models. Parameters ---------- - data : tvm.te.Tensor - N-D with shape (d_0, d_1, ..., d_{N-1}) + data : te.Tensor + Input tensor with shape [N, C, H, W]. - gamma: tvm.te.Tensor - K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k + gamma : te.Tensor + Scale tensor of shape [C]. - beta: tvm.te.Tensor - Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k + beta : te.Tensor + Offset tensor of shape [C]. - axis : list of int - Axis over the normalization applied (the axis along which the mean and variance are - computed) + axis : int or list of int, default=[0, 1] + Axes to preserve (typically N and C). Reduction happens over the rest. epsilon : float - The epsilon value to avoid division by zero. + Small value added to variance to avoid divide-by-zero. Returns ------- - result : tvm.te.Tensor - N-D with shape (d_0, d_1, ..., d_{N-1}) + out : te.Tensor + Instance-normalized tensor with same shape as input. """ - return cpp.nn.instance_norm(data, gamma, beta, axis, epsilon) + if isinstance(axis, int): + axis = [axis] + + shape = [1] * len(data.shape) + for ax in axis: + print(type(int(ax))) + shape[int(ax)] = data.shape[int(ax)] + + reduce_axes = [i for i in range(len(data.shape)) if i not in axis] + shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) + + mean = topi.sum(data, axis=reduce_axes) / shape_prod + mean_rs = topi.reshape(mean, shape) + + var = topi.sum(topi.power(data - mean_rs, 2), axis=reduce_axes) / shape_prod + var_rs = topi.reshape(var, shape) + + gamma_rs = topi.reshape(gamma, shape) + beta_rs = topi.reshape(beta, shape) + + normalized = (data - mean_rs) / topi.sqrt(var_rs + epsilon) + out = normalized * gamma_rs + beta_rs + + return out \ No newline at end of file diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 16b8f467ff0f..0f0b7bccc70b 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -21,6 +21,7 @@ #include #include +#include "tvm/relax/attrs/nn.h" namespace tvm { namespace relax { @@ -624,6 +625,115 @@ TVM_REGISTER_OP("relax.nn.group_norm") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.nn.instance_norm */ +TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); + +Expr instance_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, + bool scale) { + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.instance_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.instance_norm").set_body_typed(instance_norm); + +StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + + const auto* attrs = call->attrs.as(); + const TensorStructInfo& data_sinfo = input_sinfo[0]; + + // Check dtype: must be float/bfloat + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() && + !data_sinfo->dtype.is_bfloat()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "InstanceNorm requires the input data to have float or bfloat dtype. " + "However, the given data dtype is " + << data_sinfo->dtype); + } + + // Check gamma and beta shapes and dtypes. + if (input_sinfo.size() > 1 && !input_sinfo[1]->IsUnknownDtype() && + input_sinfo[1]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "InstanceNorm requires gamma to have the same dtype as data. " + "However, the given gamma dtype is " + << input_sinfo[1]->dtype); + } + + if (input_sinfo.size() > 2 && !input_sinfo[2]->IsUnknownDtype() && + input_sinfo[2]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "InstanceNorm requires beta to have the same dtype as data. " + "However, the given beta dtype is " + << input_sinfo[2]->dtype); + } + + if (input_sinfo.size() > 1 && !input_sinfo[1]->IsUnknownNdim() && input_sinfo[1]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "InstanceNorm requires gamma to have ndim=1. " + "However, the given gamma ndim is " + << input_sinfo[1]->ndim); + } + + if (input_sinfo.size() > 2 && !input_sinfo[2]->IsUnknownNdim() && input_sinfo[2]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "InstanceNorm requires beta to have ndim=1. " + "However, the given beta ndim is " + << input_sinfo[2]->ndim); + } + + //The most tricky part: check if the dimension of gamma/beta matches the axis shape of data + + return data_sinfo; +} + +InferLayoutOutput InferLayoutInstanceNorm(const Call& call, + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + std::vector initial_layouts; + for (size_t i = 0; i < 3; ++i) { + const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); + } + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + + LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); + + //InstanceNorm typically normalize across spatial dimensions, but keep channel dim untouched + //So just keeping the layout as original. Handling sub layouts are out of scope and should be handled by decompose/fusion methods + + ObjectPtr new_attrs = make_object(*attrs); + //const auto* input_sinfo = GetStructInfoAs(call->args[0]); + //int ndim = input_sinfo->ndim; <- no need to normalize axes since it will most likely normalize across width/height axis + //std::vector new_axis; + //for (const auto& axis : attrs->axes) { + // new_axis.push_back(FindAxis(layout->layout, (axis->value + ndim) % ndim)); + //} + //new_attrs->axes = std::move(new_axis); <- NO NEED to normalize as mentioned above + return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, + Attrs(new_attrs)); +} + +TVM_REGISTER_OP("relax.nn.instance_norm") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr("FInferStructInfo", InferStructInfoInstanceNorm) + .set_attr("FRelaxInferLayout", InferLayoutInstanceNorm) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); /* relax.nn.rms_norm */ TVM_REGISTER_NODE_TYPE(RMSNormAttrs); From 1a0decd68c622a6688f9ec85939c1db73481b8f0 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 24 Apr 2025 10:14:58 +0530 Subject: [PATCH 02/13] add support for instance norm --- include/tvm/relax/attrs/nn.h | 4 +-- .../torch/exported_program_translator.py | 15 ++-------- python/tvm/relax/op/nn/__init__.py | 2 +- python/tvm/relax/op/nn/nn.py | 4 +-- python/tvm/topi/nn/instance_norm.py | 29 ------------------- 5 files changed, 7 insertions(+), 47 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 01b40ec2c694..53547a182a55 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -534,7 +534,7 @@ struct GroupNormAttrs : public tvm::AttrsNode { } }; // struct GroupNormAttrs -/*! \brief Attributes used in group_norm operator */ +/*! \brief Attributes used in instance_norm operator */ struct InstanceNormAttrs : public tvm::AttrsNode { Array axes; double epsilon; @@ -542,7 +542,7 @@ struct InstanceNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(InstanceNormAttrs, "relax.attrs.InstanceNormAttrs") { - TVM_ATTR_FIELD(axes).describe("The axis along which the normalization is applied."); + TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied."); TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); TVM_ATTR_FIELD(center).describe( "Indicating if the beta offset will be added to the normalized tensor."); diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 8ad2e1c3bd43..83d36132ef40 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -100,7 +100,7 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: training=training, )[0] ) - def _instance_norm(self, node: fx.Node, training) -> relax.Var: + def _instance_norm(self, node: fx.Node) -> relax.Var: import numpy as np x = self.env[node.args[0]] @@ -108,18 +108,8 @@ def _instance_norm(self, node: fx.Node, training) -> relax.Var: dtype = x.struct_info.dtype weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) - running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) - running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) - ignore_running_stats = ( - node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) - ) - track_running_stats = not ignore_running_stats - momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) - if track_running_stats: - training = True - return self.block_builder.emit( relax.op.nn.instance_norm( data=x, @@ -145,8 +135,7 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: def _instance_norm_no_training(self, node: fx.Node) -> relax.Var: # This method is called for batch_norm in eval mode - training = False - return self._instance_norm(node, training) + return self._instance_norm(node) def _group_norm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 090b78d94397..0b90d0cca831 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -35,6 +35,7 @@ gelu, gelu_tanh, group_norm, + instance_norm, layer_norm, leakyrelu, log_softmax, @@ -51,5 +52,4 @@ silu, softmax, softplus, - instance_norm ) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index b9237f2bb91e..94cab248a021 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -2084,9 +2084,9 @@ def instance_norm( ) -> Expr: r""" - Batch normalization layer (Ioffe and Szegedy, 2014). + Instance normalization layer - Normalizes the input at each batch, i.e. applies a transformation + Normalizes the input at each instance, i.e. applies a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1. diff --git a/python/tvm/topi/nn/instance_norm.py b/python/tvm/topi/nn/instance_norm.py index 386664223487..eae4b50e4051 100644 --- a/python/tvm/topi/nn/instance_norm.py +++ b/python/tvm/topi/nn/instance_norm.py @@ -15,40 +15,11 @@ # specific language governing permissions and limitations # under the License. """Instance normalization operator.""" -from .. import cpp from tvm import te from tvm import topi from functools import reduce from typing import Union,List -# def instance_norm(data, gamma, beta, axis, epsilon=1e-5): -# """Instance normalization operator. - -# Parameters -# ---------- -# data : tvm.te.Tensor -# N-D with shape (d_0, d_1, ..., d_{N-1}) - -# gamma: tvm.te.Tensor -# K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k - -# beta: tvm.te.Tensor -# Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k - -# axis : list of int -# Axis over the normalization applied (the axis along which the mean and variance are -# computed) - -# epsilon : float -# The epsilon value to avoid division by zero. - -# Returns -# ------- -# result : tvm.te.Tensor -# N-D with shape (d_0, d_1, ..., d_{N-1}) -# """ -# return cpp.nn.instance_norm(data, gamma, beta, axis, epsilon) - def instance_norm( data: te.Tensor, gamma: te.Tensor, From c6fdeab731f064b31892fb585ef917cab08cd215 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 24 Apr 2025 10:25:55 +0530 Subject: [PATCH 03/13] remove unwanted comments --- .../torch/exported_program_translator.py | 1 - python/tvm/relax/op/nn/nn.py | 104 +++++------------- src/relax/op/nn/nn.cc | 10 -- 3 files changed, 29 insertions(+), 86 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 83d36132ef40..8eaedb0b7d39 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -25,7 +25,6 @@ import torch import tvm from tvm import relax -import tvm.topi from .base_fx_graph_translator import BaseFXGraphImporter diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 94cab248a021..72571d524bd4 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -2077,111 +2077,65 @@ def instance_norm( data: Expr, gamma: Expr, beta: Expr, - axis: int, + axis: List[int], epsilon: float = 1e-5, center: bool = True, scale: bool = True, - ) -> Expr: r""" - Instance normalization layer - - Normalizes the input at each instance, i.e. applies a transformation - that maintains the mean activation close to 0 and the activation - standard deviation close to 1. + Applies Instance Normalization to the input tensor. - .. math:: + Instance Normalization is a technique commonly used in deep learning, + particularly in tasks such as style transfer, where the normalization is + applied per-instance (i.e., per-sample in a batch) rather than across a batch. + It normalizes the input tensor for each sample independently, along the + specified channel axis. - data\_mean[i] = mean(data[:,i,:,...]) \\ - data\_var[i] = var(data[:,i,:,...]) - - Both *mean* and *var* returns a scalar by treating the input as a vector. - - Then compute the normalized output, which has the same shape as input, as following: + Mathematically, for each channel :math:`i` of the input, the normalized output is computed as: .. math:: - out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} - * gamma[i] + beta[i] + \mu_i &= \text{mean}(x[:, i, ...]) \\ + \sigma_i^2 &= \text{var}(x[:, i, ...]) \\ + \hat{x}[:, i, ...] &= \frac{x[:, i, ...] - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} - Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` - have shape *(k,)*. + \text{output}[:, i, ...] = \gamma[i] \cdot \hat{x}[:, i, ...] + \beta[i] - Besides the inputs and the outputs, this operator accepts two auxiliary - states, ``moving_mean`` and ``moving_var``, which are *k*-length - vectors. They are global statistics for the whole dataset, which are updated by - - .. code:: python - - moving_mean = moving_mean * momentum + data_mean * (1 - momentum) - moving_var = moving_var * momentum + data_var * (1 - momentum) - - The parameter ``axis`` specifies which axis of the input shape denotes - the 'channel' (separately normalized groups). The default is 1. - Specifying -1 sets the channel axis to be the last item in the input shape. - - .. note:: - - This operator has two modes: - - - Training mode. - - Use the mean and var computed from THIS batch to normalize. - - Update and then return the running mean and running var. - - - Inference mode. - - Use the running_mean and running_var parameters to normalize. - - Do not update the running mean and running var. Just return the original value. - - In the legalization stage, this operator will be legalized to the training mode by default. - - You can use tvm.relax.transform.DecomposeOpsForInference to decompose the operator, so it - executes the inference mode computation. Similarly, use - tvm.relax.transform.DecomposeOpsForTraining to execute the training mode computation. + where :math:`\gamma` and :math:`\beta` are learnable parameters of shape equal to the + number of channels. Parameters ---------- data : relax.Expr - The input data to the operator. + The input tensor to be normalized. gamma : relax.Expr - The gamma scale factor. + The scale tensor (gamma) applied after normalization. Must have shape matching the number of channels. beta : relax.Expr - The beta offset factor. - - moving_mean : relax.Expr - Running mean of input. - - moving_var : relax.Expr - Running variance of input. + The offset tensor (beta) applied after normalization. Must have shape matching the number of channels. - axis : int - The axis along which the normalization is applied. - - epsilon : float - Small float added to variance to avoid dividing by zero. + axis : List[int], default = [0, 1] + The axes to retain during normalization. Typically `[0, 1]` corresponding to batch (N) and channel (C) + dimensions. Normalization is applied over the remaining spatial axes. - center : bool - Indicating if the beta offset will be added to the normalized tensor. - - scale : bool - Indicating if the gamma scale will be multiplied. + epsilon : float, optional + A small constant added to the variance to avoid division by zero. Default is 1e-5. - momentum : float - The value used for the moving_mean and moving_var update. + center : bool, optional + If True, add `beta` to the normalized tensor. Default is True. - training : bool - A boolean value to indicate whether training or in eval mode. By default. - relax instance_norm is training mode. To transform it to inference mode, - can use DecomposeOpsForInference. + scale : bool, optional + If True, multiply the normalized tensor by `gamma`. Default is True. Returns ------- result : relax.Expr - The computed result. + The tensor after applying instance normalization. """ + if isinstance(axis, int): axis = [axis] return _ffi_api.instance_norm( # type: ignore data, gamma, beta, axis, epsilon, center, scale, - ) \ No newline at end of file + ) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 0f0b7bccc70b..331326244ee4 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -709,17 +709,7 @@ InferLayoutOutput InferLayoutInstanceNorm(const Call& call, LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - //InstanceNorm typically normalize across spatial dimensions, but keep channel dim untouched - //So just keeping the layout as original. Handling sub layouts are out of scope and should be handled by decompose/fusion methods - ObjectPtr new_attrs = make_object(*attrs); - //const auto* input_sinfo = GetStructInfoAs(call->args[0]); - //int ndim = input_sinfo->ndim; <- no need to normalize axes since it will most likely normalize across width/height axis - //std::vector new_axis; - //for (const auto& axis : attrs->axes) { - // new_axis.push_back(FindAxis(layout->layout, (axis->value + ndim) % ndim)); - //} - //new_attrs->axes = std::move(new_axis); <- NO NEED to normalize as mentioned above return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, Attrs(new_attrs)); } From 38a3c7a627b5e6b44b9b24ad4079049c2cb07c9e Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 24 Apr 2025 10:29:07 +0530 Subject: [PATCH 04/13] remove redundant comments --- python/tvm/relax/frontend/torch/exported_program_translator.py | 2 +- src/relax/op/nn/nn.cc | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 8eaedb0b7d39..c98d545995de 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -118,7 +118,7 @@ def _instance_norm(self, node: fx.Node) -> relax.Var: epsilon=eps, ) ) - # return self.block_builder.emit_te(tvm.topi.nn.instance_norm,x,weight,bias,[0,1],eps) + def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: # This method is called for batch_norm in training mode # TODO does not have correctness! diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 331326244ee4..b719805d8c0e 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -688,8 +688,6 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx << input_sinfo[2]->ndim); } - //The most tricky part: check if the dimension of gamma/beta matches the axis shape of data - return data_sinfo; } From 055c9705e87fe387fb0319e0caed16001510b3af Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 24 Apr 2025 10:51:46 +0530 Subject: [PATCH 05/13] remove unused declaration --- src/relax/op/nn/nn.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b719805d8c0e..01d528cc7bb6 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -645,7 +645,6 @@ TVM_REGISTER_GLOBAL("relax.op.nn.instance_norm").set_body_typed(instance_norm); StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); - const auto* attrs = call->attrs.as(); const TensorStructInfo& data_sinfo = input_sinfo[0]; // Check dtype: must be float/bfloat From 295e4c4a081728be2b568f53e0d7a42b2656c4d3 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Fri, 2 May 2025 12:55:40 +0530 Subject: [PATCH 06/13] Made changes in sematic logic --- include/tvm/relax/attrs/nn.h | 5 +- include/tvm/topi/nn/instance_norm.h | 87 ++++++++++++- .../torch/exported_program_translator.py | 48 +++---- .../tvm/relax/frontend/torch/fx_translator.py | 33 +++++ python/tvm/relax/op/nn/nn.py | 123 ++++++++---------- python/tvm/relax/op/op_attrs.py | 2 + python/tvm/relax/transform/legalize_ops/nn.py | 2 + python/tvm/topi/nn/instance_norm.py | 65 +++------ src/relax/op/nn/nn.cc | 94 ++++++------- src/relax/op/nn/nn.h | 3 + src/topi/nn.cc | 2 +- .../test_frontend_from_exported_program.py | 45 +++++++ tests/python/relax/test_frontend_from_fx.py | 47 +++++++ 13 files changed, 364 insertions(+), 192 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 53547a182a55..3ea9f5b1e021 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -536,13 +536,16 @@ struct GroupNormAttrs : public tvm::AttrsNode { /*! \brief Attributes used in instance_norm operator */ struct InstanceNormAttrs : public tvm::AttrsNode { + int channel_axis; Array axes; double epsilon; bool center; bool scale; TVM_DECLARE_ATTRS(InstanceNormAttrs, "relax.attrs.InstanceNormAttrs") { - TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied."); + TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel."); + TVM_ATTR_FIELD(axes).describe( + "The axes that along which the normalization is applied."); TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); TVM_ATTR_FIELD(center).describe( "Indicating if the beta offset will be added to the normalized tensor."); diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index 28b1a819a8ae..9d8d7a81eb08 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -25,7 +25,6 @@ #define TVM_TOPI_NN_INSTANCE_NORM_H_ #include -#include #include #include @@ -43,6 +42,7 @@ using namespace tvm::te; * d_{axis_k} == r_k * \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where * d_{axis_k} == r_k + * \param channel_axis The axis of the channel dimension * \param axis The axis to normalize over (the axis along which mean and variance are * computed). * \param epsilon The epsilon value to avoid division by zero. @@ -50,10 +50,91 @@ using namespace tvm::te; * \param tag The tag to mark the operation. * \return The normalized tensor, with the same shape as data. */ -inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, +inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,int channel_axis, const Array& axis, double epsilon, std::string name = "T_instance_norm", std::string tag = kInjective) { - return layer_norm(data, gamma, beta, axis, epsilon, name, tag); + const auto& data_type = data->dtype; + const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; + const auto& beta_type = beta.defined() ? beta->dtype : data_type; + ICHECK(data_type == gamma_type && data_type == beta_type) + << "instance_norm: data, gamma and beta must have the same type"; + ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + << "instance_norm: only support float32 and float16 for now"; + bool is_float16 = data_type == DataType::Float(16); + // sum x and x^2 + auto ndim = data->shape.size(); + ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + auto real_axis = GetRealAxis(static_cast(ndim), axis); + auto reduce_axes = MakeReduceAxes(real_axis, data); + auto target_shape = + MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true); + auto func = MakeTupleSumReducer(); + + auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, + &data](const Array& indices) { + Array eval_range; + int arg_counter = 0; + int red_counter = 0; + + for (size_t i = 0; i < ndim; ++i) { + if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { + // real_axis contains i + eval_range.push_back(reduce_axes[red_counter]); + red_counter++; + } else { + eval_range.push_back(indices[arg_counter]); + arg_counter++; + } + } + auto square = [is_float16](const PrimExpr& x) { + if (is_float16) { + return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x); + } + return x * x; + }; + if (is_float16) { + return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))}, + reduce_axes, nullptr); + } else { + return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); + } + }; + + auto temp_x_x2 = + tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce); + + auto temp_x = temp_x_x2[0]; + auto temp_x2 = temp_x_x2[1]; + + auto reduce_extent = make_const(data->dtype, 1); + for (int i : real_axis) { + reduce_extent *= data->shape[i]; + } + auto instance_norm_func = [&](const Array& indices) { + Array reduce_indices, non_reduce_indices; + + for (int i = 0, n = static_cast(indices.size()); i < n; ++i) { + if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { + reduce_indices.push_back(indices[i]); + } else { + non_reduce_indices.push_back(indices[i]); + } + } + Var channel; + channel = indices[channel_axis]; + auto mean = temp_x(non_reduce_indices) / reduce_extent; + auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; + auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon)); + if (is_float16) { + instance_norm = Cast(DataType::Float(16), instance_norm); + } + instance_norm = topi::multiply(instance_norm, gamma(channel)); + if (beta.defined()) { + instance_norm = topi::add(instance_norm, beta(channel)); + } + return instance_norm; + }; + return tvm::te::compute(data->shape, instance_norm_func, name, tag); } } // namespace nn diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c98d545995de..e211398ca967 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -99,25 +99,6 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: training=training, )[0] ) - def _instance_norm(self, node: fx.Node) -> relax.Var: - import numpy as np - - x = self.env[node.args[0]] - channel = int(self.shape_of(x)[1]) - dtype = x.struct_info.dtype - weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) - bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) - eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) - - return self.block_builder.emit( - relax.op.nn.instance_norm( - data=x, - gamma=weight, - beta=bias, - axis=[0,1], # Always over channel - epsilon=eps, - ) - ) def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: # This method is called for batch_norm in training mode @@ -132,10 +113,6 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: training = False return self._batch_norm(node, training) - def _instance_norm_no_training(self, node: fx.Node) -> relax.Var: - # This method is called for batch_norm in eval mode - return self._instance_norm(node) - def _group_norm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_groups = node.args[1] @@ -284,6 +261,29 @@ def _zeros(self, node: fx.Node) -> relax.Var: ) return self.block_builder.emit(relax.op.zeros(size, dtype)) + def _instance_norm(self, node: fx.Node): + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + gamma = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + beta = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + eps = node.args[4] if node.args[4] else 1e-05 + channel_axis = 1 + dim = len(self.shape_of(x)) + + return self.block_builder.emit( + relax.op.nn.instance_norm( + x, + gamma, + beta, + channel_axis=channel_axis, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + ########## Others ########## def create_convert_map( @@ -410,7 +410,6 @@ def create_convert_map( "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "batch_norm.default": self._batch_norm_legit_no_training, - "instance_norm.default": self._instance_norm_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, @@ -428,6 +427,7 @@ def create_convert_map( self.env[node.args[1]], self.env[node.args[0]] ), "group_norm.default": self._group_norm, + "instance_norm.default": self._instance_norm, "layer_norm.default": self._layer_norm, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5f65f86a4303..d36081a7f688 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -222,6 +222,36 @@ def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + def _instance_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + + if module.affine: + weight = self.params[module.weight] + bias = self.params[module.bias] + else: + import numpy as np + + dtype = x.struct_info.dtype + channel = int(self.shape_of(x)[1]) + weight = relax.const(np.ones(channel), dtype=dtype) + bias = relax.const(np.zeros(channel), dtype=dtype) + + eps = module.eps + channel_axis = 1 + dim = len(self.shape_of(x)) + + return self.block_builder.emit( + relax.op.nn.instance_norm( + x, + weight, + bias, + channel_axis=channel_axis, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -652,6 +682,9 @@ def create_convert_map( nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, nn.AvgPool2d: self._avg_pool2d_module, nn.BatchNorm2d: self._batch_norm_2d_module, + nn.InstanceNorm1d: self._instance_norm, + nn.InstanceNorm2d: self._instance_norm, + nn.InstanceNorm3d: self._instance_norm, nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d_module, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 72571d524bd4..bf4b02c963ef 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1725,6 +1725,61 @@ def group_norm( ) +def instance_norm( + data: Expr, + gamma: Expr, + beta: Expr, + channel_axis: int, + axes: List[int], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Instance normalization + + Parameters + ---------- + data : relax.Expr + Input to which instance_norm will be applied. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied. + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.instance_norm( # type: ignore + data, + gamma, + beta, + channel_axis, + axes, + epsilon, + center, + scale, + ) + + def rms_norm( data: Expr, weight: Expr, @@ -2071,71 +2126,3 @@ def attention_var_len( causal_mask, window_size, ) # type: ignore - - -def instance_norm( - data: Expr, - gamma: Expr, - beta: Expr, - axis: List[int], - epsilon: float = 1e-5, - center: bool = True, - scale: bool = True, -) -> Expr: - r""" - Applies Instance Normalization to the input tensor. - - Instance Normalization is a technique commonly used in deep learning, - particularly in tasks such as style transfer, where the normalization is - applied per-instance (i.e., per-sample in a batch) rather than across a batch. - It normalizes the input tensor for each sample independently, along the - specified channel axis. - - Mathematically, for each channel :math:`i` of the input, the normalized output is computed as: - - .. math:: - - \mu_i &= \text{mean}(x[:, i, ...]) \\ - \sigma_i^2 &= \text{var}(x[:, i, ...]) \\ - \hat{x}[:, i, ...] &= \frac{x[:, i, ...] - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} - - \text{output}[:, i, ...] = \gamma[i] \cdot \hat{x}[:, i, ...] + \beta[i] - - where :math:`\gamma` and :math:`\beta` are learnable parameters of shape equal to the - number of channels. - - Parameters - ---------- - data : relax.Expr - The input tensor to be normalized. - - gamma : relax.Expr - The scale tensor (gamma) applied after normalization. Must have shape matching the number of channels. - - beta : relax.Expr - The offset tensor (beta) applied after normalization. Must have shape matching the number of channels. - - axis : List[int], default = [0, 1] - The axes to retain during normalization. Typically `[0, 1]` corresponding to batch (N) and channel (C) - dimensions. Normalization is applied over the remaining spatial axes. - - epsilon : float, optional - A small constant added to the variance to avoid division by zero. Default is 1e-5. - - center : bool, optional - If True, add `beta` to the normalized tensor. Default is True. - - scale : bool, optional - If True, multiply the normalized tensor by `gamma`. Default is True. - - Returns - ------- - result : relax.Expr - The tensor after applying instance normalization. - """ - - if isinstance(axis, int): - axis = [axis] - return _ffi_api.instance_norm( # type: ignore - data, gamma, beta, axis, epsilon, center, scale, - ) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index c2032af30933..24eece70a941 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -93,10 +93,12 @@ class BatchNormAttrs(Attrs): class LayerNormAttrs(Attrs): """Attributes used in layer_norm operator""" + @tvm._ffi.register_object("relax.attrs.InstanceNormAttrs") class InstanceNormAttrs(Attrs): """Attributes used in instance_norm operator""" + @tvm._ffi.register_object("relax.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for dropout operator""" diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 6bf2a976b8d8..ed9802fc9e63 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -633,6 +633,7 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: call.attrs.epsilon, ) + @register_legalize("relax.nn.instance_norm") def _nn_instance_norm(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( @@ -640,6 +641,7 @@ def _nn_instance_norm(bb: BlockBuilder, call: Call) -> Expr: data=call.args[0], gamma=call.args[1], beta=call.args[2], + channel_axis=call.attrs.channel_axis, axis=call.attrs.axes, epsilon=call.attrs.epsilon, ) diff --git a/python/tvm/topi/nn/instance_norm.py b/python/tvm/topi/nn/instance_norm.py index eae4b50e4051..a64cd2d80cb4 100644 --- a/python/tvm/topi/nn/instance_norm.py +++ b/python/tvm/topi/nn/instance_norm.py @@ -15,66 +15,33 @@ # specific language governing permissions and limitations # under the License. """Instance normalization operator.""" -from tvm import te -from tvm import topi -from functools import reduce -from typing import Union,List +from .. import cpp -def instance_norm( - data: te.Tensor, - gamma: te.Tensor, - beta: te.Tensor, - axis: Union[int, List[int]] = [0, 1], - epsilon: float = 1e-5, -) -> te.Tensor: - """Instance normalization over spatial dimensions. - Normalizes each instance in a batch independently per channel, - typically used in style transfer and vision models. +def instance_norm(data, gamma, beta, channel_axis, axis, epsilon=1e-5): + """Instance normalization operator. Parameters ---------- - data : te.Tensor - Input tensor with shape [N, C, H, W]. + data : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) - gamma : te.Tensor - Scale tensor of shape [C]. + gamma: tvm.te.Tensor + K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k - beta : te.Tensor - Offset tensor of shape [C]. + beta: tvm.te.Tensor + Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k - axis : int or list of int, default=[0, 1] - Axes to preserve (typically N and C). Reduction happens over the rest. + axis : list of int + Axis over the normalization applied (the axis along which the mean and variance are + computed) epsilon : float - Small value added to variance to avoid divide-by-zero. + The epsilon value to avoid division by zero. Returns ------- - out : te.Tensor - Instance-normalized tensor with same shape as input. + result : tvm.te.Tensor + N-D with shape (d_0, d_1, ..., d_{N-1}) """ - if isinstance(axis, int): - axis = [axis] - - shape = [1] * len(data.shape) - for ax in axis: - print(type(int(ax))) - shape[int(ax)] = data.shape[int(ax)] - - reduce_axes = [i for i in range(len(data.shape)) if i not in axis] - shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) - - mean = topi.sum(data, axis=reduce_axes) / shape_prod - mean_rs = topi.reshape(mean, shape) - - var = topi.sum(topi.power(data - mean_rs, 2), axis=reduce_axes) / shape_prod - var_rs = topi.reshape(var, shape) - - gamma_rs = topi.reshape(gamma, shape) - beta_rs = topi.reshape(beta, shape) - - normalized = (data - mean_rs) / topi.sqrt(var_rs + epsilon) - out = normalized * gamma_rs + beta_rs - - return out \ No newline at end of file + return cpp.nn.instance_norm(data, gamma, beta, channel_axis, axis, epsilon) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 01d528cc7bb6..17405343a529 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -21,7 +21,6 @@ #include #include -#include "tvm/relax/attrs/nn.h" namespace tvm { namespace relax { @@ -416,7 +415,6 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call, } const auto* attrs = call->attrs.as(); ICHECK(attrs) << "Invalid Call"; - LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); // While dealing with sub layouts, its adviced to deal with batchnorm @@ -628,9 +626,10 @@ TVM_REGISTER_OP("relax.nn.group_norm") /* relax.nn.instance_norm */ TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); -Expr instance_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, +Expr instance_norm(Expr data, Expr gamma, Expr beta,int channel_axis, Array axes, double epsilon, bool center, bool scale) { ObjectPtr attrs = make_object(); + attrs->channel_axis = std::move(channel_axis); attrs->axes = std::move(axes); attrs->epsilon = epsilon; attrs->center = center; @@ -643,51 +642,49 @@ Expr instance_norm(Expr data, Expr gamma, Expr beta, Array axes, double TVM_REGISTER_GLOBAL("relax.op.nn.instance_norm").set_body_typed(instance_norm); StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + ICHECK(attrs) << "Invalid Call"; + TensorStructInfo data_sinfo = input_sinfo[0]; - const TensorStructInfo& data_sinfo = input_sinfo[0]; - - // Check dtype: must be float/bfloat - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() && - !data_sinfo->dtype.is_bfloat()) { - ctx->ReportFatal(Diagnostic::Error(call) - << "InstanceNorm requires the input data to have float or bfloat dtype. " - "However, the given data dtype is " - << data_sinfo->dtype); - } - - // Check gamma and beta shapes and dtypes. - if (input_sinfo.size() > 1 && !input_sinfo[1]->IsUnknownDtype() && - input_sinfo[1]->dtype != data_sinfo->dtype) { - ctx->ReportFatal(Diagnostic::Error(call) - << "InstanceNorm requires gamma to have the same dtype as data. " - "However, the given gamma dtype is " - << input_sinfo[1]->dtype); - } - - if (input_sinfo.size() > 2 && !input_sinfo[2]->IsUnknownDtype() && - input_sinfo[2]->dtype != data_sinfo->dtype) { - ctx->ReportFatal(Diagnostic::Error(call) - << "InstanceNorm requires beta to have the same dtype as data. " - "However, the given beta dtype is " - << input_sinfo[2]->dtype); - } - - if (input_sinfo.size() > 1 && !input_sinfo[1]->IsUnknownNdim() && input_sinfo[1]->ndim != 1) { - ctx->ReportFatal(Diagnostic::Error(call) - << "InstanceNorm requires gamma to have ndim=1. " - "However, the given gamma ndim is " - << input_sinfo[1]->ndim); - } - - if (input_sinfo.size() > 2 && !input_sinfo[2]->IsUnknownNdim() && input_sinfo[2]->ndim != 1) { - ctx->ReportFatal(Diagnostic::Error(call) - << "InstanceNorm requires beta to have ndim=1. " - "However, the given beta ndim is " - << input_sinfo[2]->ndim); + int channel_axis = -1; + if (!data_sinfo->IsUnknownNdim()) { + channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis); + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + // channel_axis must not be in axes. + if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " expects that channel_axis must not be in axes, but got channel_axis: " + << channel_axis << ", axes: " << attrs->axes); + } } - - return data_sinfo; + const auto* data_shape = data_sinfo->shape.as(); + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + for (int i = 1; i < static_cast(op->arguments.size()); ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have the same dtype, but got " + << input_sinfo[i]->dtype << " and " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have ndim=1, but got " + << input_sinfo[i]->ndim); + } + const auto* shape = input_sinfo[i]->shape.as(); + if (shape != nullptr && data_shape != nullptr) { + PrimExpr channel_size = data_shape->values[channel_axis]; + PrimExpr input_size = shape->values[0]; + if (analyzer->CanProve(channel_size != input_size)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of input " << i + << " must be equal to the size of channel_axis, but got " << input_size + << " and " << channel_size); + } + } + } + return data_sinfo; } InferLayoutOutput InferLayoutInstanceNorm(const Call& call, @@ -705,8 +702,13 @@ InferLayoutOutput InferLayoutInstanceNorm(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); - ObjectPtr new_attrs = make_object(*attrs); + std::vector new_axes; + for (const auto& axis : attrs->axes) { + new_axes.push_back(FindAxis(layout->layout, (axis->value))); + } + new_attrs->axes = std::move(new_axes); + new_attrs->channel_axis = FindAxis(layout->layout, attrs->channel_axis); return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, Attrs(new_attrs)); } diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 018741430199..0510859010e6 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -90,6 +90,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, Array axes, double epsilon, bool center, bool scale); +/*! \brief Compute instance normalization. */ +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, double epsilon, bool center, bool scale); + /*! \brief Compute root mean square normalization. */ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 09859e331807..f588a59efe4c 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -179,7 +179,7 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body([](TVMArgs args, TVMRetValue* /* Ops from nn/instance_norm.h */ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast(args[4])); + *rv = nn::instance_norm(args[0], args[1], args[2],args[3], args[4], static_cast(args[5])); }); /* Ops from nn/rms_norm.h */ diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f0bb33964ef2..72322e445f1d 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2180,6 +2180,51 @@ def main( verify_model(model, example_args, binding, expected1) +def test_instancenorm2d(): + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class InstanceNorm2d(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.InstanceNorm2d(3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.instance_norm( + input_1, + w1, + w2, + channel_axis=1, + axes=[2, 3], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = InstanceNorm2d() + binding = { + "w1": torch.ones(3).detach().numpy(), + "w2": torch.zeros(3).detach().numpy(), + } + verify_model(model, example_args, binding, expected1) + + def test_layernorm(): class LayerNorm(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 490a2309aa37..51f453a99c88 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1683,6 +1683,53 @@ def main( verify_model(model, input_info, binding, expected1) +def test_instancenorm2d(): + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class InstanceNorm2d(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.InstanceNorm2d(3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.instance_norm( + input_1, + w1, + w2, + channel_axis=1, + axes=[2, 3], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = InstanceNorm2d() + binding = { + "w1": torch.ones(3).detach().numpy(), + "w2": torch.zeros(3).detach().numpy(), + } + verify_model(model, input_info, binding, expected1) + + operator_binary_1 = [ (operator.add, R.add), (operator.sub, R.subtract), From f0baec556193bbf0bd26a6b1c7038c5a708a104f Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 6 May 2025 16:08:58 +0530 Subject: [PATCH 07/13] fix lint issues --- include/tvm/topi/nn/instance_norm.h | 4 ++-- src/relax/op/nn/nn.cc | 4 ++-- src/relax/op/nn/nn.h | 3 ++- src/topi/nn.cc | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index 9d8d7a81eb08..d400721215ec 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -50,8 +50,8 @@ using namespace tvm::te; * \param tag The tag to mark the operation. * \return The normalized tensor, with the same shape as data. */ -inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,int channel_axis, - const Array& axis, double epsilon, +inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta, + int channel_axis, const Array& axis, double epsilon, std::string name = "T_instance_norm", std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 17405343a529..245c5b3ef6ae 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -626,8 +626,8 @@ TVM_REGISTER_OP("relax.nn.group_norm") /* relax.nn.instance_norm */ TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); -Expr instance_norm(Expr data, Expr gamma, Expr beta,int channel_axis, Array axes, double epsilon, bool center, - bool scale) { +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, + Array axes, double epsilon, bool center, bool scale) { ObjectPtr attrs = make_object(); attrs->channel_axis = std::move(channel_axis); attrs->axes = std::move(axes); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 0510859010e6..ab14bc2e5086 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -91,7 +91,8 @@ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_ax Array axes, double epsilon, bool center, bool scale); /*! \brief Compute instance normalization. */ -Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, double epsilon, bool center, bool scale); +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, + Array axes, double epsilon, bool center, bool scale); /*! \brief Compute root mean square normalization. */ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index f588a59efe4c..3bac1d5566d9 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -179,7 +179,8 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body([](TVMArgs args, TVMRetValue* /* Ops from nn/instance_norm.h */ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::instance_norm(args[0], args[1], args[2],args[3], args[4], static_cast(args[5])); + *rv = nn::instance_norm(args[0], args[1], args[2], args[3], + args[4], static_cast(args[5])); }); /* Ops from nn/rms_norm.h */ From 29fda763d39d1e96f22bb5dfc3b594a223c396dc Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 6 May 2025 16:44:24 +0530 Subject: [PATCH 08/13] fix whitespace issues --- include/tvm/relax/attrs/nn.h | 3 +-- src/relax/op/nn/nn.cc | 46 ++++++++++++++++++------------------ src/relax/op/nn/nn.h | 4 ++-- src/topi/nn.cc | 4 ++-- 4 files changed, 28 insertions(+), 29 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 3ea9f5b1e021..b5027067fa6d 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -544,8 +544,7 @@ struct InstanceNormAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(InstanceNormAttrs, "relax.attrs.InstanceNormAttrs") { TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel."); - TVM_ATTR_FIELD(axes).describe( - "The axes that along which the normalization is applied."); + TVM_ATTR_FIELD(axes).describe("The axes that along which the normalization is applied."); TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero"); TVM_ATTR_FIELD(center).describe( "Indicating if the beta offset will be added to the normalized tensor."); diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 245c5b3ef6ae..097ddcba28a5 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -626,8 +626,8 @@ TVM_REGISTER_OP("relax.nn.group_norm") /* relax.nn.instance_norm */ TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); -Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, - Array axes, double epsilon, bool center, bool scale) { +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, + double epsilon, bool center, bool scale) { ObjectPtr attrs = make_object(); attrs->channel_axis = std::move(channel_axis); attrs->axes = std::move(axes); @@ -663,33 +663,33 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx const auto* data_shape = data_sinfo->shape.as(); arith::Analyzer* analyzer = ctx->GetAnalyzer(); for (int i = 1; i < static_cast(op->arguments.size()); ++i) { - if (input_sinfo[i]->dtype != data_sinfo->dtype) { - ctx->ReportFatal(Diagnostic::Error(call) - << op << " expects that all inputs must have the same dtype, but got " - << input_sinfo[i]->dtype << " and " << data_sinfo->dtype); - } else if (input_sinfo[i]->ndim != 1) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have the same dtype, but got " + << input_sinfo[i]->dtype << " and " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have ndim=1, but got " + << input_sinfo[i]->ndim); + } + const auto* shape = input_sinfo[i]->shape.as(); + if (shape != nullptr && data_shape != nullptr) { + PrimExpr channel_size = data_shape->values[channel_axis]; + PrimExpr input_size = shape->values[0]; + if (analyzer->CanProve(channel_size != input_size)) { ctx->ReportFatal(Diagnostic::Error(call) - << op << " expects that all inputs must have ndim=1, but got " - << input_sinfo[i]->ndim); - } - const auto* shape = input_sinfo[i]->shape.as(); - if (shape != nullptr && data_shape != nullptr) { - PrimExpr channel_size = data_shape->values[channel_axis]; - PrimExpr input_size = shape->values[0]; - if (analyzer->CanProve(channel_size != input_size)) { - ctx->ReportFatal(Diagnostic::Error(call) - << op << " expects that the size of input " << i - << " must be equal to the size of channel_axis, but got " << input_size - << " and " << channel_size); - } + << op << " expects that the size of input " << i + << " must be equal to the size of channel_axis, but got " << input_size + << " and " << channel_size); } } - return data_sinfo; + } + return data_sinfo; } InferLayoutOutput InferLayoutInstanceNorm(const Call& call, - const Map>& desired_layouts, - const VarLayoutMap& var_layout_map) { + const Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index ab14bc2e5086..39f8c2d73800 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -91,8 +91,8 @@ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_ax Array axes, double epsilon, bool center, bool scale); /*! \brief Compute instance normalization. */ -Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, - Array axes, double epsilon, bool center, bool scale); +Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array axes, + double epsilon, bool center, bool scale); /*! \brief Compute root mean square normalization. */ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 3bac1d5566d9..f846e95f95f8 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -179,8 +179,8 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body([](TVMArgs args, TVMRetValue* /* Ops from nn/instance_norm.h */ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = nn::instance_norm(args[0], args[1], args[2], args[3], - args[4], static_cast(args[5])); + *rv = + nn::instance_norm(args[0], args[1], args[2], args[3], args[4], static_cast(args[5])); }); /* Ops from nn/rms_norm.h */ From f6ce3dab8cc5b095f5ababab88f62074b7cb1074 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Wed, 7 May 2025 10:35:34 +0530 Subject: [PATCH 09/13] Fix whitespace issue --- src/topi/nn.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/topi/nn.cc b/src/topi/nn.cc index e286b8b17975..81d8fbb6b378 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -206,7 +206,7 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body_packed([](TVMArgs args, TVMRe TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body_packed([](TVMArgs args, TVMRetValue* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), - args[4].cast>(),args[5].cast()); + args[4].cast>(), args[5].cast()); }); /* Ops from nn/rms_norm.h */ From fa40782902905888c5d59a3ac97b87ff36698d54 Mon Sep 17 00:00:00 2001 From: kavin-mcw Date: Thu, 8 May 2025 07:55:23 +0530 Subject: [PATCH 10/13] ffi change in nn.cc --- src/topi/nn.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/topi/nn.cc b/src/topi/nn.cc index 997d56adb04a..b8768de631d2 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -211,7 +211,7 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body_packed([](ffi::PackedArgs arg }); /* Ops from nn/instance_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body_packed([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast>(), args[5].cast()); From b6247b84d068729d0e1dcbf77690524b21110ec9 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Thu, 8 May 2025 09:12:34 +0530 Subject: [PATCH 11/13] Fix lint issue --- src/topi/nn.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/topi/nn.cc b/src/topi/nn.cc index b8768de631d2..b9eeef74d778 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -211,11 +211,12 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body_packed([](ffi::PackedArgs arg }); /* Ops from nn/instance_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::instance_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast(), - args[4].cast>(), args[5].cast()); -}); +TVM_REGISTER_GLOBAL("topi.nn.instance_norm") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::instance_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast(), + args[4].cast>(), args[5].cast()); + }); /* Ops from nn/rms_norm.h */ TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { From e422853d4280b93c2bdabbbd1ad2c98c42975801 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Wed, 4 Jun 2025 11:47:11 +0530 Subject: [PATCH 12/13] merge main --- .github/actions/setup/action.yml | 15 +- 3rdparty/cutlass_fpA_intB_gemm | 2 +- .../app/src/main/jni/tvm_runtime.h | 4 +- apps/cpp_rpc/rpc_env.cc | 4 +- apps/cpp_rpc/rpc_env.h | 2 +- apps/cpp_rpc/rpc_server.cc | 4 +- apps/cpp_rpc/rpc_server.h | 2 +- apps/hexagon_launcher/launcher_core.cc | 3 +- apps/hexagon_launcher/launcher_core.h | 2 +- apps/ios_rpc/tvmrpc/RPCServer.mm | 3 +- apps/ios_rpc/tvmrpc/TVMRuntime.mm | 8 +- ci/jenkins/docker-images.ini | 14 +- ci/jenkins/generated/cpu_jenkinsfile.groovy | 7 +- .../templates/cpu_jenkinsfile.groovy.j2 | 4 - ci/jenkins/unity_jenkinsfile.groovy | 4 +- cmake/modules/CUDA.cmake | 1 + conda/build-environment.yaml | 2 - docker/Dockerfile.ci_cpu | 4 + docker/Dockerfile.ci_gpu | 3 + docker/README.md | 6 - .../install/ubuntu_install_nnef.sh | 13 +- docker/python/ci-constraints.txt | 1 + docs/arch/device_target_interactions.rst | 10 +- docs/arch/pass_infra.rst | 4 +- docs/arch/runtime.rst | 10 +- ffi/include/tvm/ffi/any.h | 147 +- ffi/include/tvm/ffi/base_details.h | 6 +- ffi/include/tvm/ffi/c_api.h | 18 +- ffi/include/tvm/ffi/cast.h | 5 + ffi/include/tvm/ffi/container/array.h | 42 +- .../tvm/ffi/container/container_details.h | 18 + ffi/include/tvm/ffi/container/map.h | 43 +- ffi/include/tvm/ffi/container/tuple.h | 42 +- ffi/include/tvm/ffi/container/variant.h | 118 +- ffi/include/tvm/ffi/dtype.h | 10 +- ffi/include/tvm/ffi/error.h | 10 +- ffi/include/tvm/ffi/function.h | 61 +- ffi/include/tvm/ffi/function_details.h | 2 +- ffi/include/tvm/ffi/memory.h | 4 + ffi/include/tvm/ffi/optional.h | 3 + ffi/include/tvm/ffi/rvalue_ref.h | 9 +- ffi/include/tvm/ffi/string.h | 14 +- ffi/include/tvm/ffi/type_traits.h | 152 +- ffi/scripts/benchmark_dlpack.py | 345 ++++ ffi/src/ffi/dtype.cc | 4 +- ffi/src/ffi/ndarray.cc | 2 +- ffi/tests/cpp/test_any.cc | 34 + ffi/tests/cpp/test_dtype.cc | 4 +- ffi/tests/cpp/test_map.cc | 2 +- ffi/tests/cpp/test_string.cc | 31 +- ffi/tests/cpp/test_tuple.cc | 28 + ffi/tests/cpp/test_variant.cc | 28 +- golang/Makefile | 81 - golang/README.md | 126 -- golang/sample/Makefile | 34 - golang/sample/complex.go | 189 -- golang/sample/pack_func_closure_arg.go | 75 - golang/sample/pack_func_closure_return.go | 75 - golang/sample/pack_func_convert.go | 62 - golang/sample/pack_func_handle_arg.go | 78 - golang/sample/pack_func_register.go | 81 - golang/sample/simple.go | 90 - golang/src/array_test.go | 614 ------- golang/src/bytearray.go | 90 - golang/src/bytearray_test.go | 50 - golang/src/device.go | 93 - golang/src/error.go | 49 - golang/src/error_test.go | 45 - golang/src/function.go | 383 ----- golang/src/function_test.go | 349 ---- golang/src/gotvm.cc | 207 --- golang/src/gotvm.h | 60 - golang/src/gotvm_test.go | 48 - golang/src/module.go | 139 -- golang/src/module_test.go | 110 -- golang/src/ndarray.go | 347 ---- golang/src/tvm_runtime_pack.cc | 71 - golang/src/type.go | 90 - golang/src/utils.go | 42 - golang/src/value.go | 378 ---- golang/src/value_test.go | 255 --- include/tvm/arith/int_set.h | 6 +- include/tvm/ir/analysis.h | 2 +- include/tvm/ir/attrs.h | 14 +- include/tvm/ir/env_func.h | 6 +- include/tvm/ir/expr.h | 4 +- include/tvm/ir/function.h | 6 +- include/tvm/ir/instrument.h | 2 +- include/tvm/ir/module.h | 6 +- include/tvm/ir/op.h | 2 +- include/tvm/ir/source_map.h | 3 +- include/tvm/ir/transform.h | 12 +- include/tvm/ir/type.h | 2 +- include/tvm/meta_schedule/arg_info.h | 6 +- include/tvm/meta_schedule/builder.h | 12 +- include/tvm/meta_schedule/cost_model.h | 6 +- include/tvm/meta_schedule/database.h | 18 +- include/tvm/meta_schedule/extracted_task.h | 4 +- include/tvm/meta_schedule/feature_extractor.h | 6 +- include/tvm/meta_schedule/measure_callback.h | 6 +- include/tvm/meta_schedule/measure_candidate.h | 2 +- include/tvm/meta_schedule/mutator.h | 4 +- include/tvm/meta_schedule/postproc.h | 2 +- include/tvm/meta_schedule/profiler.h | 8 +- include/tvm/meta_schedule/runner.h | 8 +- include/tvm/meta_schedule/schedule_rule.h | 59 +- include/tvm/meta_schedule/search_strategy.h | 6 +- include/tvm/meta_schedule/space_generator.h | 4 +- include/tvm/meta_schedule/task_scheduler.h | 12 +- include/tvm/meta_schedule/tune_context.h | 10 +- include/tvm/node/attr_registry_map.h | 2 +- include/tvm/node/node.h | 5 +- include/tvm/node/object_path.h | 8 +- include/tvm/node/reflection.h | 22 +- include/tvm/node/script_printer.h | 11 +- include/tvm/node/serialization.h | 2 +- include/tvm/node/structural_equal.h | 36 +- include/tvm/relax/analysis.h | 2 +- include/tvm/relax/attrs/manipulate.h | 9 + include/tvm/relax/block_builder.h | 2 +- include/tvm/relax/dataflow_matcher.h | 10 +- include/tvm/relax/dataflow_pattern.h | 10 +- include/tvm/relax/distributed/global_info.h | 6 +- include/tvm/relax/exec_builder.h | 2 +- include/tvm/relax/expr.h | 8 +- include/tvm/relax/expr_functor.h | 6 +- include/tvm/relax/nested_msg.h | 10 +- include/tvm/relax/struct_info.h | 12 +- include/tvm/relax/transform.h | 16 +- include/tvm/relax/type.h | 2 +- .../gotvm.go => include/tvm/runtime/base.h | 51 +- include/tvm/runtime/builtin_fp16.h | 2 +- include/tvm/runtime/c_backend_api.h | 34 +- include/tvm/runtime/c_runtime_api.h | 732 -------- include/tvm/runtime/container/array.h | 41 - include/tvm/runtime/container/base.h | 278 --- include/tvm/runtime/container/map.h | 40 - include/tvm/runtime/container/optional.h | 39 - include/tvm/runtime/container/shape_tuple.h | 53 - include/tvm/runtime/container/string.h | 41 - include/tvm/runtime/container/variant.h | 39 - include/tvm/runtime/contrib/papi.h | 4 +- include/tvm/runtime/data_type.h | 18 +- include/tvm/runtime/device_api.h | 42 +- include/tvm/runtime/disco/builtin.h | 2 +- include/tvm/runtime/disco/cuda_ipc_memory.h | 2 +- include/tvm/runtime/disco/disco_worker.h | 2 +- include/tvm/runtime/disco/session.h | 9 +- include/tvm/runtime/{memory.h => int_tuple.h} | 19 +- include/tvm/runtime/logging.h | 2 +- include/tvm/runtime/memory/memory_manager.h | 14 +- include/tvm/runtime/module.h | 85 +- include/tvm/runtime/ndarray.h | 66 +- include/tvm/runtime/nvtx.h | 2 +- include/tvm/runtime/object.h | 15 +- include/tvm/runtime/packed_func.h | 417 +---- include/tvm/runtime/profiling.h | 13 +- include/tvm/runtime/registry.h | 102 -- include/tvm/runtime/relax_vm/executable.h | 4 +- .../runtime/relax_vm/ndarray_cache_support.h | 6 +- include/tvm/runtime/relax_vm/vm.h | 2 +- include/tvm/runtime/serializer.h | 2 +- include/tvm/script/ir_builder/base.h | 10 +- include/tvm/script/ir_builder/relax/frame.h | 2 +- include/tvm/script/ir_builder/relax/ir.h | 2 +- include/tvm/script/ir_builder/tir/ir.h | 39 +- include/tvm/script/printer/doc.h | 16 +- include/tvm/script/printer/ir_docsifier.h | 2 +- .../tvm/script/printer/ir_docsifier_functor.h | 2 +- include/tvm/support/parallel_for.h | 2 +- include/tvm/support/span.h | 109 -- include/tvm/target/codegen.h | 2 +- include/tvm/target/target.h | 4 +- include/tvm/tir/analysis.h | 4 +- include/tvm/tir/buffer.h | 10 +- include/tvm/tir/builtin.h | 20 +- include/tvm/tir/expr.h | 10 +- include/tvm/tir/function.h | 4 +- include/tvm/tir/index_map.h | 6 +- include/tvm/tir/op_attr_types.h | 4 +- include/tvm/tir/schedule/schedule.h | 29 +- include/tvm/tir/schedule/trace.h | 2 +- include/tvm/tir/stmt.h | 16 +- include/tvm/tir/stmt_functor.h | 14 +- include/tvm/tir/transform.h | 4 +- include/tvm/topi/nn/pooling.h | 10 +- include/tvm/topi/utils.h | 4 +- jvm/README.md | 2 +- .../src/main/java/org/apache/tvm/Base.java | 35 +- .../src/main/java/org/apache/tvm/Device.java | 31 +- .../main/java/org/apache/tvm/Function.java | 165 +- .../src/main/java/org/apache/tvm/LibInfo.java | 57 +- .../src/main/java/org/apache/tvm/Module.java | 43 +- .../src/main/java/org/apache/tvm/NDArray.java | 42 +- .../main/java/org/apache/tvm/NDArrayBase.java | 51 +- .../tvm/{ArgTypeCode.java => TVMObject.java} | 27 +- .../main/java/org/apache/tvm/TVMValue.java | 4 +- .../java/org/apache/tvm/TVMValueBytes.java | 1 - .../java/org/apache/tvm/TVMValueDouble.java | 1 - .../java/org/apache/tvm/TVMValueHandle.java | 1 - .../java/org/apache/tvm/TVMValueLong.java | 1 - .../java/org/apache/tvm/TVMValueNull.java | 1 - .../java/org/apache/tvm/TVMValueString.java | 1 - .../main/java/org/apache/tvm/TypeIndex.java | 44 + .../main/java/org/apache/tvm/rpc/Client.java | 3 + .../java/org/apache/tvm/rpc/RPCSession.java | 7 +- .../java/org/apache/tvm/FunctionTest.java | 2 + .../test/java/org/apache/tvm/ModuleTest.java | 3 - .../test/java/org/apache/tvm/rpc/RPCTest.java | 2 + .../src/test/scripts/prepare_test_libs.py | 83 + jvm/native/linux-x86_64/pom.xml | 2 + jvm/native/osx-x86_64/pom.xml | 2 + jvm/native/src/main/native/jni_helper_func.h | 111 +- .../native/org_apache_tvm_native_c_api.cc | 442 ++--- python/setup.py | 16 +- python/tvm/__init__.py | 5 +- python/tvm/_ffi/__init__.py | 31 - python/tvm/_ffi/_pyversion.py | 26 - python/tvm/_ffi/registry.py | 29 - python/tvm/arith/_ffi_api.py | 4 +- python/tvm/arith/analyzer.py | 6 +- python/tvm/arith/int_set.py | 6 +- python/tvm/arith/int_solver.py | 8 +- python/tvm/arith/iter_affine_map.py | 10 +- python/tvm/{_ffi => }/base.py | 17 +- python/tvm/contrib/cc.py | 2 +- python/tvm/contrib/clang.py | 2 +- python/tvm/contrib/coreml_runtime.py | 4 +- python/tvm/contrib/cudnn.py | 4 +- python/tvm/contrib/cutlass/_ffi_api.py | 4 +- python/tvm/contrib/cutlass/build.py | 2 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 6 +- python/tvm/contrib/emcc.py | 4 +- python/tvm/contrib/hexagon/build.py | 2 +- python/tvm/contrib/hexagon/tools.py | 2 +- python/tvm/contrib/miopen.py | 4 +- python/tvm/contrib/mrvl.py | 22 +- python/tvm/contrib/msc/core/_ffi_api.py | 4 +- python/tvm/contrib/msc/core/ir/graph.py | 12 +- .../msc/framework/tensorflow/_ffi_api.py | 4 +- .../msc/framework/tensorrt/_ffi_api.py | 4 +- .../contrib/msc/framework/torch/_ffi_api.py | 4 +- .../tvm/contrib/msc/framework/tvm/_ffi_api.py | 4 +- python/tvm/contrib/msc/plugin/_ffi_api.py | 4 +- python/tvm/contrib/msc/plugin/op/_ffi_api.py | 4 +- python/tvm/contrib/ndk.py | 4 +- python/tvm/contrib/nnpack.py | 4 +- python/tvm/contrib/nvcc.py | 16 +- python/tvm/contrib/pickle_memoize.py | 3 +- python/tvm/contrib/random.py | 4 +- python/tvm/contrib/rocm.py | 10 +- python/tvm/contrib/spirv.py | 2 +- python/tvm/contrib/tar.py | 2 +- python/tvm/contrib/tflite_runtime.py | 4 +- python/tvm/contrib/thrust.py | 2 +- python/tvm/contrib/tvmjs.py | 2 +- python/tvm/contrib/xcode.py | 2 +- python/tvm/dlight/analysis/common_analysis.py | 2 +- python/tvm/dlight/gpu/general_reduction.py | 17 + python/tvm/driver/_ffi_api.py | 4 +- python/tvm/exec/disco_worker.py | 2 +- python/tvm/ffi/convert.py | 5 + python/tvm/ffi/cython/base.pxi | 2 +- python/tvm/ffi/cython/dtype.pxi | 2 +- python/tvm/ffi/cython/function.pxi | 16 + python/tvm/ffi/cython/ndarray.pxi | 2 + python/tvm/generic.py | 19 - python/tvm/ir/_ffi_analysis_api.py | 4 +- python/tvm/ir/_ffi_api.py | 4 +- python/tvm/ir/_ffi_instrument_api.py | 4 +- python/tvm/ir/_ffi_transform_api.py | 4 +- python/tvm/ir/attrs.py | 6 +- python/tvm/ir/base.py | 4 +- python/tvm/ir/diagnostics/__init__.py | 8 +- python/tvm/ir/diagnostics/_ffi_api.py | 4 +- python/tvm/ir/expr.py | 6 +- python/tvm/ir/instrument.py | 6 +- python/tvm/ir/module.py | 11 +- python/tvm/ir/op.py | 4 +- python/tvm/ir/supply.py | 4 +- python/tvm/ir/transform.py | 12 +- python/tvm/ir/type.py | 10 +- python/tvm/ir/type_relation.py | 6 +- python/tvm/{_ffi => }/libinfo.py | 6 +- python/tvm/meta_schedule/_ffi_api.py | 2 +- python/tvm/meta_schedule/arg_info.py | 2 +- python/tvm/meta_schedule/builder/builder.py | 2 +- .../meta_schedule/builder/local_builder.py | 2 +- .../meta_schedule/cost_model/cost_model.py | 2 +- .../tvm/meta_schedule/cost_model/mlp_model.py | 2 +- .../tvm/meta_schedule/cost_model/xgb_model.py | 18 +- python/tvm/meta_schedule/database/database.py | 2 +- .../meta_schedule/database/json_database.py | 2 +- .../meta_schedule/database/memory_database.py | 2 +- .../database/ordered_union_database.py | 2 +- .../database/schedule_fn_database.py | 2 +- .../meta_schedule/database/union_database.py | 2 +- python/tvm/meta_schedule/extracted_task.py | 2 +- .../feature_extractor/feature_extractor.py | 2 +- .../feature_extractor/per_store_feature.py | 2 +- .../measure_callback/add_to_database.py | 2 +- .../measure_callback/measure_callback.py | 2 +- .../measure_callback/remove_build_artifact.py | 2 +- .../measure_callback/update_cost_model.py | 2 +- .../mutator/mutate_compute_location.py | 2 +- .../meta_schedule/mutator/mutate_parallel.py | 2 +- .../mutator/mutate_thread_binding.py | 2 +- .../meta_schedule/mutator/mutate_tile_size.py | 2 +- .../meta_schedule/mutator/mutate_unroll.py | 2 +- python/tvm/meta_schedule/mutator/mutator.py | 4 +- .../disallow_async_strided_mem_copy.py | 2 +- .../postproc/disallow_dynamic_loop.py | 2 +- python/tvm/meta_schedule/postproc/postproc.py | 2 +- .../postproc/rewrite_cooperative_fetch.py | 2 +- .../meta_schedule/postproc/rewrite_layout.py | 2 +- .../rewrite_parallel_vectorize_unroll.py | 2 +- .../postproc/rewrite_reduction_block.py | 2 +- .../postproc/rewrite_tensorize.py | 2 +- .../postproc/rewrite_unbound_block.py | 2 +- .../meta_schedule/postproc/verify_gpu_code.py | 2 +- .../postproc/verify_vtcm_limit.py | 2 +- python/tvm/meta_schedule/profiler.py | 2 +- python/tvm/meta_schedule/relax_integration.py | 2 +- python/tvm/meta_schedule/runner/runner.py | 2 +- .../schedule_rule/add_rfactor.py | 2 +- .../schedule_rule/apply_custom_rule.py | 2 +- .../meta_schedule/schedule_rule/auto_bind.py | 2 +- .../schedule_rule/auto_inline.py | 2 +- .../schedule_rule/cross_thread_reduction.py | 2 +- .../schedule_rule/multi_level_tiling.py | 2 +- .../parallel_vectorize_unroll.py | 2 +- .../schedule_rule/random_compute_location.py | 2 +- .../schedule_rule/schedule_rule.py | 2 +- .../search_strategy/evolutionary_search.py | 2 +- .../search_strategy/replay_func.py | 2 +- .../search_strategy/replay_trace.py | 2 +- .../search_strategy/search_strategy.py | 2 +- .../space_generator/post_order_apply.py | 2 +- .../space_generator/schedule_fn.py | 2 +- .../space_generator/space_generator.py | 2 +- .../space_generator/space_generator_union.py | 2 +- .../task_scheduler/gradient_based.py | 2 +- .../task_scheduler/round_robin.py | 2 +- .../task_scheduler/task_scheduler.py | 2 +- .../testing/validate_database.py | 8 +- python/tvm/meta_schedule/tir_integration.py | 2 +- python/tvm/meta_schedule/tune_context.py | 2 +- python/tvm/meta_schedule/utils.py | 2 +- python/tvm/relax/_ffi_api.py | 4 +- python/tvm/relax/analysis/_ffi_api.py | 4 +- python/tvm/relax/backend/_ffi_api.py | 4 +- python/tvm/relax/backend/metal/coreml.py | 6 +- python/tvm/relax/binding_rewrite.py | 4 +- python/tvm/relax/block_builder.py | 2 +- python/tvm/relax/distributed/_ffi_api.py | 4 +- python/tvm/relax/distributed/global_info.py | 2 +- python/tvm/relax/distributed/struct_info.py | 6 +- .../relax/distributed/transform/_ffi_api.py | 4 +- python/tvm/relax/dpl/_ffi.py | 4 +- python/tvm/relax/dpl/pattern.py | 2 +- python/tvm/relax/dpl/rewrite.py | 2 +- python/tvm/relax/exec_builder.py | 2 +- python/tvm/relax/expr.py | 45 +- python/tvm/relax/expr_functor.py | 6 +- python/tvm/relax/frontend/nn/extern.py | 6 +- python/tvm/relax/frontend/nn/op.py | 24 +- .../tvm/relax/frontend/onnx/onnx_frontend.py | 21 +- .../torch/base_fx_graph_translator.py | 289 +++- .../torch/exported_program_translator.py | 47 +- .../tvm/relax/frontend/torch/fx_translator.py | 122 +- python/tvm/relax/op/__init__.py | 3 +- python/tvm/relax/op/_ffi_api.py | 4 +- python/tvm/relax/op/_op_gradient.py | 4 +- python/tvm/relax/op/builtin/_ffi_api.py | 4 +- python/tvm/relax/op/ccl/_ffi_api.py | 4 +- python/tvm/relax/op/distributed/_ffi_api.py | 4 +- python/tvm/relax/op/grad/_ffi_api.py | 4 +- python/tvm/relax/op/image/_ffi_api.py | 4 +- python/tvm/relax/op/image/image.py | 2 +- python/tvm/relax/op/linear_algebra.py | 27 + python/tvm/relax/op/manipulate.py | 38 + python/tvm/relax/op/memory/_ffi_api.py | 4 +- python/tvm/relax/op/nn/__init__.py | 1 + python/tvm/relax/op/nn/_ffi_api.py | 4 +- python/tvm/relax/op/nn/nn.py | 25 +- python/tvm/relax/op/op_attrs.py | 74 +- python/tvm/relax/op/vm/_ffi_api.py | 4 +- python/tvm/relax/struct_info.py | 14 +- python/tvm/relax/testing/transform.py | 2 +- python/tvm/relax/training/_ffi_api.py | 4 +- python/tvm/relax/training/utils.py | 2 +- python/tvm/relax/transform/_ffi_api.py | 4 +- .../transform/legalize_ops/linear_algebra.py | 19 + .../transform/legalize_ops/manipulate.py | 14 + python/tvm/relax/transform/transform.py | 8 +- .../relax/transform/tuning_api/_ffi_api.py | 4 +- .../relax/transform/tuning_api/database.py | 2 +- .../transform/tuning_api/default_functions.py | 2 +- .../relax/transform/tuning_api/primitives.py | 2 +- python/tvm/relax/ty.py | 10 +- python/tvm/rpc/_ffi_api.py | 4 +- python/tvm/rpc/base.py | 2 +- python/tvm/rpc/client.py | 6 +- python/tvm/rpc/minrpc.py | 2 +- python/tvm/rpc/proxy.py | 2 +- python/tvm/rpc/server.py | 12 +- python/tvm/rpc/tracker.py | 2 +- python/tvm/runtime/_ffi_api.py | 8 +- python/tvm/runtime/_ffi_node_api.py | 10 +- python/tvm/runtime/disco/_ffi_api.py | 2 +- python/tvm/runtime/disco/process_pool.py | 2 +- python/tvm/runtime/disco/session.py | 4 +- python/tvm/runtime/module.py | 4 +- python/tvm/runtime/ndarray.py | 2 +- python/tvm/runtime/object_path.py | 20 +- python/tvm/runtime/profiling/__init__.py | 2 +- python/tvm/runtime/profiling/_ffi_api.py | 4 +- python/tvm/runtime/relax_vm.py | 8 +- python/tvm/runtime/script_printer.py | 2 +- python/tvm/runtime/support.py | 4 +- python/tvm/script/_ffi_api.py | 4 +- python/tvm/script/ir_builder/_ffi_api.py | 4 +- python/tvm/script/ir_builder/base.py | 2 +- python/tvm/script/ir_builder/ir/_ffi_api.py | 4 +- python/tvm/script/ir_builder/ir/frame.py | 2 +- .../tvm/script/ir_builder/relax/_ffi_api.py | 4 +- .../ir_builder/relax/distributed/_ffi_api.py | 4 +- .../script/ir_builder/relax/distributed/ir.py | 5 +- python/tvm/script/ir_builder/relax/frame.py | 2 +- python/tvm/script/ir_builder/relax/ir.py | 4 + python/tvm/script/ir_builder/tir/_ffi_api.py | 4 +- python/tvm/script/ir_builder/tir/frame.py | 2 +- python/tvm/script/parser/core/parser.py | 2 +- python/tvm/script/printer/_ffi_api.py | 4 +- python/tvm/script/printer/doc.py | 2 +- python/tvm/support.py | 4 +- python/tvm/target/_ffi_api.py | 4 +- python/tvm/target/datatype.py | 4 +- python/tvm/target/detect_target.py | 2 +- python/tvm/target/target.py | 8 +- python/tvm/target/virtual_device.py | 2 +- python/tvm/target/x86.py | 2 +- python/tvm/te/_ffi_api.py | 4 +- python/tvm/te/operation.py | 7 +- python/tvm/te/tensor.py | 14 +- python/tvm/testing/_ffi_api.py | 4 +- python/tvm/testing/popen_pool.py | 6 +- python/tvm/testing/utils.py | 2 +- python/tvm/tir/_ffi_api.py | 4 +- python/tvm/tir/analysis/_ffi_api.py | 4 +- python/tvm/tir/block_dependence_info.py | 2 +- python/tvm/tir/block_scope.py | 2 +- python/tvm/tir/buffer.py | 9 +- python/tvm/tir/data_layout.py | 6 +- python/tvm/tir/expr.py | 72 +- python/tvm/tir/function.py | 8 +- python/tvm/tir/ir_builder.py | 5 +- python/tvm/tir/op.py | 4 +- python/tvm/tir/schedule/_ffi_api.py | 4 +- python/tvm/tir/schedule/analysis.py | 6 +- python/tvm/tir/schedule/instruction.py | 2 +- python/tvm/tir/schedule/schedule.py | 2 +- python/tvm/tir/schedule/state.py | 2 +- python/tvm/tir/schedule/trace.py | 4 +- python/tvm/tir/schedule/transform.py | 2 +- python/tvm/tir/stmt.py | 42 +- python/tvm/tir/tensor_intrin/cuda.py | 2 +- python/tvm/tir/transform/_ffi_api.py | 4 +- python/tvm/tir/transform/function_pass.py | 4 +- python/tvm/topi/__init__.py | 3 +- python/tvm/topi/cpp/cuda.py | 4 +- python/tvm/topi/cpp/generic.py | 4 +- python/tvm/topi/cpp/impl.py | 4 +- python/tvm/topi/cpp/nn.py | 4 +- python/tvm/topi/cpp/rocm.py | 4 +- python/tvm/topi/cpp/utils.py | 4 +- python/tvm/topi/cpp/vision/__init__.py | 4 +- python/tvm/topi/cpp/vision/yolo.py | 4 +- python/tvm/topi/cpp/x86.py | 4 +- python/tvm/topi/generic_op_impl.py | 2 +- python/tvm/topi/image/resize.py | 6 +- python/tvm/topi/math.py | 9 +- python/tvm/topi/nn/conv2d.py | 5 + python/tvm/topi/slice_scatter.py | 74 + python/tvm/topi/transform.py | 3 +- python/tvm/topi/utils.py | 12 +- rust/.gitignore | 4 - rust/.rustfmt.toml | 31 - rust/tvm-macros/Cargo.toml | 37 - rust/tvm-macros/README.md | 20 - rust/tvm-macros/src/external.rs | 198 --- rust/tvm-macros/src/import_module.rs | 133 -- rust/tvm-macros/src/lib.rs | 44 - rust/tvm-macros/src/object.rs | 212 --- rust/tvm-macros/src/util.rs | 48 - rust/tvm-rt/.gitignore | 7 - rust/tvm-rt/Cargo.toml | 95 -- rust/tvm-rt/README.md | 60 - rust/tvm-rt/src/device.rs | 97 -- rust/tvm-rt/src/errors.rs | 97 -- rust/tvm-rt/src/function.rs | 354 ---- rust/tvm-rt/src/lib.rs | 155 -- rust/tvm-rt/src/module.rs | 131 -- rust/tvm-rt/src/ndarray.rs | 515 ------ rust/tvm-rt/src/object/mod.rs | 110 -- rust/tvm-rt/src/object/object_ptr.rs | 555 ------ rust/tvm-rt/src/string.rs | 142 -- rust/tvm-rt/src/to_function.rs | 337 ---- rust/tvm-sys/Cargo.toml | 81 - rust/tvm-sys/README.md | 28 - rust/tvm-sys/build.rs | 274 --- rust/tvm-sys/src/array.rs | 63 - rust/tvm-sys/src/byte_array.rs | 152 -- rust/tvm-sys/src/datatype.rs | 214 --- rust/tvm-sys/src/device.rs | 294 ---- rust/tvm-sys/src/errors.rs | 46 - rust/tvm-sys/src/lib.rs | 72 - rust/tvm-sys/src/packed_func.rs | 400 ----- rust/tvm-sys/src/value.rs | 95 -- src/arith/analyzer.cc | 15 +- src/arith/bound_deducer.cc | 4 +- src/arith/canonical_simplify.cc | 8 +- src/arith/const_fold.h | 44 +- src/arith/const_int_bound.cc | 13 +- src/arith/detect_common_subexpr.cc | 2 +- src/arith/detect_linear_equation.cc | 6 +- src/arith/domain_touched.cc | 12 +- src/arith/int_constraints.cc | 13 +- src/arith/int_set.cc | 38 +- src/arith/iter_affine_map.cc | 38 +- src/arith/modular_set.cc | 4 +- src/arith/narrow_predicate_expression.cc | 5 +- src/arith/presburger_set.cc | 4 +- src/arith/rewrite_simplify.cc | 18 +- src/arith/scalable_expression.cc | 33 +- src/arith/scalable_expression.h | 15 +- src/arith/solve_linear_equation.cc | 4 +- src/arith/solve_linear_inequality.cc | 8 +- src/contrib/msc/core/codegen/code_stack.cc | 10 +- src/contrib/msc/core/codegen/code_stack.h | 8 +- src/contrib/msc/core/ir/graph.cc | 103 +- src/contrib/msc/core/ir/graph_builder.cc | 26 +- src/contrib/msc/core/ir/graph_builder.h | 2 +- src/contrib/msc/core/ir/plugin.cc | 8 +- .../msc/core/printer/msc_base_printer.cc | 2 +- src/contrib/msc/core/printer/print_utils.cc | 8 +- src/contrib/msc/core/printer/print_utils.h | 8 +- .../msc/core/printer/prototxt_printer.cc | 10 +- .../msc/core/transform/bind_named_params.cc | 2 +- src/contrib/msc/core/transform/bind_shape.cc | 2 +- src/contrib/msc/core/transform/fuse_tuple.cc | 23 +- .../msc/core/transform/inline_params.cc | 4 +- .../msc/core/transform/rewrite_utils.cc | 2 +- .../msc/core/transform/set_byoc_attrs.cc | 6 +- .../msc/core/transform/set_expr_layout.cc | 2 +- .../msc/core/transform/set_expr_name.cc | 10 +- src/contrib/msc/core/utils.cc | 16 +- .../msc/framework/tensorflow/codegen.cc | 2 +- .../msc/framework/tensorflow/tf_v1_opcode.cc | 2 +- src/contrib/msc/framework/tensorrt/codegen.cc | 42 +- .../msc/framework/tensorrt/tensorrt_opcode.cc | 18 +- .../framework/tensorrt/transform_tensorrt.cc | 4 +- src/contrib/msc/framework/torch/codegen.cc | 2 +- src/contrib/msc/framework/tvm/codegen.cc | 2 +- src/contrib/msc/plugin/tensorrt_codegen.cc | 10 +- src/contrib/msc/plugin/torch_codegen.cc | 4 +- src/contrib/msc/plugin/tvm_codegen.cc | 8 +- src/ir/analysis.cc | 2 +- src/ir/apply_pass_to_function.cc | 4 +- src/ir/attr_functor.h | 4 +- src/ir/attrs.cc | 6 +- src/ir/diagnostic.cc | 22 +- src/ir/env_func.cc | 8 +- src/ir/expr.cc | 14 +- src/ir/function.cc | 14 +- src/ir/global_info.cc | 9 +- src/ir/global_var_supply.cc | 17 +- src/ir/instrument.cc | 8 +- src/ir/module.cc | 50 +- src/ir/name_supply.cc | 11 +- src/ir/op.cc | 28 +- src/ir/replace_global_vars.cc | 6 +- src/ir/source_map.cc | 23 +- src/ir/transform.cc | 53 +- src/ir/type.cc | 15 +- src/meta_schedule/arg_info.cc | 25 +- src/meta_schedule/builder/builder.cc | 8 +- src/meta_schedule/cost_model/cost_model.cc | 11 +- src/meta_schedule/database/database.cc | 65 +- src/meta_schedule/database/database_utils.cc | 8 +- src/meta_schedule/database/json_database.cc | 5 +- src/meta_schedule/database/memory_database.cc | 2 +- .../database/ordered_union_database.cc | 4 +- .../database/schedule_fn_database.cc | 10 +- src/meta_schedule/database/union_database.cc | 5 +- src/meta_schedule/extracted_task.cc | 2 +- .../feature_extractor/feature_extractor.cc | 4 +- .../feature_extractor/per_store_feature.cc | 5 +- .../measure_callback/add_to_database.cc | 2 +- .../measure_callback/measure_callback.cc | 6 +- .../measure_callback/remove_build_artifact.cc | 2 +- .../measure_callback/update_cost_model.cc | 2 +- .../mutator/mutate_compute_location.cc | 4 +- src/meta_schedule/mutator/mutate_parallel.cc | 7 +- .../mutator/mutate_thread_binding.cc | 4 +- src/meta_schedule/mutator/mutate_tile_size.cc | 7 +- src/meta_schedule/mutator/mutate_unroll.cc | 6 +- src/meta_schedule/mutator/mutator.cc | 17 +- .../disallow_async_strided_mem_copy.cc | 6 +- .../postproc/disallow_dynamic_loop.cc | 2 +- src/meta_schedule/postproc/postproc.cc | 16 +- .../postproc/rewrite_cooperative_fetch.cc | 22 +- src/meta_schedule/postproc/rewrite_layout.cc | 10 +- .../rewrite_parallel_vectorize_unroll.cc | 12 +- .../postproc/rewrite_reduction_block.cc | 4 +- .../postproc/rewrite_tensorize.cc | 2 +- .../postproc/rewrite_unbound_block.cc | 2 +- src/meta_schedule/postproc/verify_gpu_code.cc | 7 +- .../postproc/verify_vtcm_limit.cc | 2 +- src/meta_schedule/profiler.cc | 16 +- src/meta_schedule/runner/runner.cc | 15 +- src/meta_schedule/schedule/cpu/winograd.cc | 8 +- .../schedule/cuda/thread_bind.cc | 4 +- src/meta_schedule/schedule/cuda/winograd.cc | 12 +- .../schedule_rule/add_rfactor.cc | 2 +- .../schedule_rule/apply_custom_rule.cc | 4 +- src/meta_schedule/schedule_rule/auto_bind.cc | 3 +- .../schedule_rule/auto_inline.cc | 4 +- .../schedule_rule/cross_thread_reduction.cc | 6 +- .../schedule_rule/multi_level_tiling.cc | 2 +- .../multi_level_tiling_tensor_core.cc | 10 +- .../multi_level_tiling_wide_vector.cc | 6 +- .../multi_level_tiling_with_intrin.cc | 4 +- .../parallel_vectorize_unroll.cc | 2 +- .../schedule_rule/random_compute_location.cc | 2 +- .../schedule_rule/schedule_rule.cc | 70 +- .../search_strategy/evolutionary_search.cc | 15 +- .../search_strategy/replay_func.cc | 12 +- .../search_strategy/replay_trace.cc | 6 +- .../search_strategy/search_strategy.cc | 16 +- .../space_generator/post_order_apply.cc | 2 +- .../space_generator/schedule_fn.cc | 2 +- .../space_generator/space_generator.cc | 8 +- .../space_generator/space_generator_union.cc | 2 +- .../task_scheduler/gradient_based.cc | 2 +- .../task_scheduler/round_robin.cc | 2 +- .../task_scheduler/task_scheduler.cc | 27 +- src/meta_schedule/trace_apply.cc | 2 +- src/meta_schedule/tune_context.cc | 8 +- src/meta_schedule/utils.h | 16 +- src/node/attr_registry.h | 2 +- src/node/container_printing.cc | 14 +- src/node/object_path.cc | 26 +- src/node/reflection.cc | 12 +- src/node/repr_printer.cc | 6 +- src/node/script_printer.cc | 12 +- src/node/serialization.cc | 36 +- src/node/structural_equal.cc | 16 +- src/node/structural_hash.cc | 64 +- src/relax/analysis/analysis.cc | 12 +- .../analysis/computable_at_compile_time.cc | 2 +- src/relax/analysis/detect_recursion.cc | 2 +- src/relax/analysis/layout_transformation.cc | 2 +- src/relax/analysis/struct_info_analysis.cc | 28 +- src/relax/analysis/tir_op_pattern_kind.cc | 4 +- src/relax/analysis/udchain.cc | 7 +- src/relax/analysis/var2value.cc | 14 +- src/relax/analysis/well_formed.cc | 6 +- src/relax/backend/contrib/clml/codegen.cc | 7 +- .../backend/contrib/codegen_c/codegen_c.h | 16 +- .../contrib/codegen_json/codegen_json.h | 6 +- src/relax/backend/contrib/cublas/codegen.cc | 2 +- src/relax/backend/contrib/cudnn/codegen.cc | 2 +- src/relax/backend/contrib/cutlass/codegen.cc | 10 +- src/relax/backend/contrib/dnnl/codegen.cc | 2 +- src/relax/backend/contrib/hipblas/codegen.cc | 4 +- src/relax/backend/contrib/nnapi/codegen.cc | 4 +- src/relax/backend/contrib/tensorrt/codegen.cc | 7 +- src/relax/backend/contrib/utils.cc | 2 +- src/relax/backend/pattern_registry.cc | 11 +- src/relax/backend/pattern_registry.h | 4 +- src/relax/backend/task_extraction.cc | 2 +- src/relax/backend/vm/codegen_vm.cc | 6 +- src/relax/backend/vm/codegen_vm_tir.cc | 12 +- src/relax/backend/vm/exec_builder.cc | 44 +- src/relax/backend/vm/lower_runtime_builtin.cc | 2 +- src/relax/backend/vm/vm_shape_lower.cc | 6 +- src/relax/distributed/global_info.cc | 8 +- src/relax/distributed/struct_info.cc | 10 +- .../transform/legalize_redistribute.cc | 2 +- .../distributed/transform/lower_distir.cc | 6 +- .../lower_global_view_to_local_view.cc | 2 +- .../transform/propagate_sharding.cc | 4 +- src/relax/distributed/transform/utils.h | 2 +- src/relax/ir/binding_rewrite.cc | 18 +- src/relax/ir/block_builder.cc | 50 +- src/relax/ir/dataflow_block_rewriter.cc | 12 +- src/relax/ir/dataflow_expr_rewriter.cc | 32 +- src/relax/ir/dataflow_matcher.cc | 8 +- src/relax/ir/dataflow_matcher.h | 2 +- src/relax/ir/dataflow_pattern.cc | 78 +- src/relax/ir/dataflow_rewriter.h | 6 +- src/relax/ir/emit_te.cc | 4 +- src/relax/ir/expr.cc | 76 +- src/relax/ir/expr_functor.cc | 4 +- src/relax/ir/py_expr_functor.cc | 50 +- src/relax/ir/struct_info.cc | 33 +- src/relax/ir/transform.cc | 6 +- src/relax/ir/type.cc | 12 +- src/relax/op/ccl/ccl.cc | 9 +- src/relax/op/distributed/distributed.cc | 8 +- src/relax/op/image/resize.cc | 2 +- src/relax/op/memory/view.cc | 15 +- src/relax/op/nn/attention.cc | 8 +- src/relax/op/nn/convolution.cc | 10 +- src/relax/op/nn/nn.cc | 30 +- src/relax/op/nn/pooling.cc | 20 +- src/relax/op/op.cc | 76 +- src/relax/op/op_common.cc | 2 +- src/relax/op/op_common.h | 12 +- src/relax/op/tensor/binary.cc | 4 +- src/relax/op/tensor/binary.h | 2 +- src/relax/op/tensor/create.cc | 24 +- src/relax/op/tensor/create.h | 2 +- src/relax/op/tensor/datatype.cc | 4 +- src/relax/op/tensor/grad.cc | 14 +- src/relax/op/tensor/index.cc | 40 +- src/relax/op/tensor/index.h | 6 +- src/relax/op/tensor/linear_algebra.cc | 42 +- src/relax/op/tensor/linear_algebra.h | 8 + src/relax/op/tensor/manipulate.cc | 223 ++- src/relax/op/tensor/manipulate.h | 22 +- src/relax/op/tensor/qdq.cc | 4 +- src/relax/op/tensor/sampling.cc | 3 +- src/relax/op/tensor/search.cc | 4 +- src/relax/op/tensor/set.cc | 12 +- src/relax/op/tensor/sorting.cc | 6 +- src/relax/op/tensor/statistical.cc | 4 +- src/relax/op/tensor/statistical.h | 20 +- src/relax/op/tensor/ternary.cc | 2 +- src/relax/op/tensor/unary.cc | 2 +- src/relax/testing/transform.cc | 2 +- src/relax/training/utils.cc | 4 +- src/relax/training/utils.h | 2 +- src/relax/transform/adjust_matmul_order.cc | 12 +- src/relax/transform/allocate_workspace.cc | 2 +- src/relax/transform/alter_op_impl.cc | 2 +- .../transform/annotate_tir_op_pattern.cc | 3 +- .../attach_attr_layout_free_buffers.cc | 2 +- src/relax/transform/attach_global_symbol.cc | 2 +- src/relax/transform/bind_params.cc | 4 +- src/relax/transform/bind_symbolic_vars.cc | 4 +- src/relax/transform/bundle_model_params.cc | 2 +- src/relax/transform/call_tir_rewrite.cc | 2 +- src/relax/transform/canonicalize_bindings.cc | 21 +- .../transform/combine_parallel_matmul.cc | 7 +- src/relax/transform/compute_prim_value.cc | 2 +- src/relax/transform/convert_dataflow.cc | 4 +- src/relax/transform/convert_layout.cc | 8 +- src/relax/transform/dataflow_inplace.cc | 10 +- src/relax/transform/dead_code_elimination.cc | 6 +- src/relax/transform/decompose_ops.cc | 4 +- .../transform/eliminate_common_subexpr.cc | 4 +- src/relax/transform/expand_matmul_of_sum.cc | 2 +- src/relax/transform/expand_tuple_arguments.cc | 7 +- src/relax/transform/few_shot_tuning.cc | 14 +- src/relax/transform/fold_constant.cc | 27 +- src/relax/transform/fuse_ops.cc | 26 +- src/relax/transform/fuse_tir.cc | 6 +- src/relax/transform/gradient.cc | 4 +- src/relax/transform/infer_amp_utils.h | 5 +- src/relax/transform/inline_functions.cc | 6 +- src/relax/transform/kill_after_last_use.cc | 2 +- src/relax/transform/lambda_lift.cc | 8 +- src/relax/transform/lazy_transform_params.cc | 6 +- src/relax/transform/legalize_ops.cc | 4 +- src/relax/transform/lift_transform_params.cc | 5 +- src/relax/transform/lower_alloc_tensor.cc | 2 +- .../transform/merge_composite_functions.cc | 6 +- src/relax/transform/meta_schedule.cc | 14 +- src/relax/transform/normalize.cc | 8 +- src/relax/transform/realize_vdevice.cc | 8 +- src/relax/transform/remove_purity_checking.cc | 3 +- src/relax/transform/remove_unused_outputs.cc | 4 +- .../transform/remove_unused_parameters.cc | 2 +- .../reorder_permute_dims_after_concat.cc | 2 +- .../transform/reorder_take_after_matmul.cc | 2 +- src/relax/transform/rewrite_cuda_graph.cc | 10 +- .../transform/rewrite_dataflow_reshape.cc | 2 +- src/relax/transform/run_codegen.cc | 3 +- .../transform/split_call_tir_by_pattern.cc | 15 +- .../transform/split_layout_rewrite_preproc.cc | 6 +- .../transform/static_plan_block_memory.cc | 14 +- src/relax/transform/to_mixed_precision.cc | 10 +- src/relax/transform/to_non_dataflow.cc | 2 +- src/relax/transform/topological_sort.cc | 6 +- src/relax/transform/tuning_api/database.cc | 32 +- src/relax/transform/tuning_api/primitives.cc | 55 +- .../transform/update_param_struct_info.cc | 3 +- src/relax/transform/update_vdevice.cc | 2 +- src/relax/transform/utils.h | 4 +- src/relax/utils.cc | 2 +- src/runtime/builtin_fp16.cc | 8 +- src/runtime/c_runtime_api.cc | 805 --------- src/runtime/const_loader_module.cc | 14 +- src/runtime/container.cc | 112 -- src/runtime/contrib/amx/amx_config.cc | 7 +- .../contrib/arm_compute_lib/acl_allocator.h | 2 +- .../contrib/arm_compute_lib/acl_runtime.cc | 6 +- src/runtime/contrib/bnns/bnns_json_runtime.cc | 6 +- src/runtime/contrib/cblas/cblas.cc | 8 +- src/runtime/contrib/cblas/dnnl_blas.cc | 4 +- src/runtime/contrib/cblas/gemm_common.h | 2 +- src/runtime/contrib/cblas/mkl.cc | 10 +- src/runtime/contrib/clml/clml_runtime.cc | 6 +- src/runtime/contrib/clml/clml_runtime.h | 2 +- src/runtime/contrib/coreml/coreml_runtime.h | 2 +- src/runtime/contrib/coreml/coreml_runtime.mm | 7 +- src/runtime/contrib/cublas/cublas.cc | 8 +- .../contrib/cublas/cublas_json_runtime.cc | 6 +- src/runtime/contrib/cublas/cublas_utils.cc | 2 +- src/runtime/contrib/cudnn/conv_backward.cc | 10 +- src/runtime/contrib/cudnn/conv_forward.cc | 10 +- .../contrib/cudnn/cudnn_frontend/attention.cc | 2 +- .../contrib/cudnn/cudnn_frontend/attention.h | 2 +- .../contrib/cudnn/cudnn_json_runtime.cc | 6 +- src/runtime/contrib/cudnn/cudnn_utils.cc | 6 +- src/runtime/contrib/cudnn/softmax.cc | 6 +- src/runtime/contrib/curand/curand.cc | 6 +- .../contrib/curand/helper_cuda_kernels.h | 2 +- .../contrib/cutlass/fp16_group_gemm.cu | 6 +- .../cutlass/fp8_blockwise_scaled_gemm.cu | 8 +- src/runtime/contrib/cutlass/fp8_gemm.cu | 10 +- src/runtime/contrib/cutlass/fp8_group_gemm.cu | 10 +- .../contrib/cutlass/weight_preprocess.cc | 5 +- src/runtime/contrib/dnnl/dnnl.cc | 2 +- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 6 +- src/runtime/contrib/dnnl/dnnl_kernel.h | 4 +- .../contrib/edgetpu/edgetpu_runtime.cc | 4 +- src/runtime/contrib/hipblas/hipblas.cc | 14 +- .../contrib/hipblas/hipblas_json_runtime.cc | 13 +- src/runtime/contrib/hipblas/hipblas_utils.cc | 2 +- src/runtime/contrib/json/json_node.h | 2 +- src/runtime/contrib/json/json_runtime.h | 2 +- src/runtime/contrib/miopen/conv_forward.cc | 6 +- src/runtime/contrib/miopen/miopen_utils.cc | 2 +- src/runtime/contrib/miopen/softmax.cc | 6 +- src/runtime/contrib/mps/conv.mm | 6 +- src/runtime/contrib/mps/gemm.mm | 2 +- src/runtime/contrib/mps/mps_utils.h | 2 +- src/runtime/contrib/mrvl/mrvl_hw_runtime.cc | 6 +- src/runtime/contrib/mrvl/mrvl_runtime.cc | 6 +- .../contrib/mrvl/mrvl_sw_runtime_lib.cc | 2 +- .../contrib/mrvl/mrvl_sw_runtime_lib.h | 2 +- src/runtime/contrib/msc/tensorrt_runtime.cc | 7 +- src/runtime/contrib/mscclpp/allreduce.cu | 4 +- src/runtime/contrib/nnapi/nnapi_ops.cc | 2 +- src/runtime/contrib/nnapi/nnapi_runtime.cc | 6 +- src/runtime/contrib/nvshmem/init.cc | 17 +- src/runtime/contrib/nvshmem/kv_transfer.cu | 6 +- .../contrib/nvshmem/memory_allocator.cc | 11 +- src/runtime/contrib/papi/papi.cc | 2 +- src/runtime/contrib/random/random.cc | 12 +- src/runtime/contrib/rocblas/rocblas.cc | 6 +- src/runtime/contrib/sort/sort.cc | 21 +- .../contrib/tensorrt/tensorrt_runtime.cc | 6 +- src/runtime/contrib/tflite/tflite_runtime.cc | 6 +- src/runtime/contrib/tflite/tflite_runtime.h | 3 +- src/runtime/contrib/thrust/thrust.cu | 197 +-- src/runtime/contrib/vllm/attention_kernels.cu | 10 +- src/runtime/contrib/vllm/cache_alloc.cc | 4 +- src/runtime/contrib/vllm/cache_kernels.cu | 10 +- src/runtime/cpu_device_api.cc | 4 +- src/runtime/cuda/cuda_common.h | 2 +- src/runtime/cuda/cuda_device_api.cc | 21 +- src/runtime/cuda/cuda_module.cc | 8 +- src/runtime/cuda/l2_cache_flush.cc | 13 +- src/runtime/debug_compile.cc | 20 +- src/runtime/device_api.cc | 272 +++ src/runtime/disco/bcast_session.cc | 9 +- src/runtime/disco/bcast_session.h | 2 +- src/runtime/disco/builtin.cc | 63 +- src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc | 11 +- .../disco/cuda_ipc/custom_allreduce.cc | 4 +- src/runtime/disco/disco_worker.cc | 17 +- src/runtime/disco/disco_worker_thread.h | 2 +- .../disco/distributed/socket_session.cc | 8 +- src/runtime/disco/loader.cc | 48 +- src/runtime/disco/message_queue.h | 17 +- src/runtime/disco/nccl/nccl.cc | 43 +- src/runtime/disco/nccl/nccl_context.h | 4 +- src/runtime/disco/process_session.cc | 17 +- src/runtime/disco/protocol.h | 47 +- src/runtime/disco/session.cc | 25 +- src/runtime/disco/threaded_session.cc | 20 +- src/runtime/disco/utils.h | 2 +- src/runtime/dso_library.cc | 7 +- src/runtime/file_utils.cc | 17 +- src/runtime/file_utils.h | 4 +- src/runtime/hexagon/hexagon_buffer.h | 4 +- src/runtime/hexagon/hexagon_common.cc | 6 +- src/runtime/hexagon/hexagon_common.h | 3 +- src/runtime/hexagon/hexagon_device_api.cc | 31 +- src/runtime/hexagon/hexagon_module.cc | 2 +- src/runtime/hexagon/hexagon_thread_manager.h | 4 +- src/runtime/hexagon/hexagon_vtcm_pool.h | 4 +- src/runtime/hexagon/ops/conv2d.h | 2 +- src/runtime/hexagon/ops/conv2d_fp16_hvx.cc | 30 +- src/runtime/hexagon/ops/conv2d_quant_hvx.cc | 49 +- src/runtime/hexagon/rpc/android/session.cc | 4 +- src/runtime/hexagon/rpc/hexagon/rpc_server.cc | 9 +- .../hexagon/rpc/simulator/rpc_server.cc | 8 +- src/runtime/hexagon/rpc/simulator/session.cc | 5 +- src/runtime/library_module.cc | 2 +- src/runtime/library_module.h | 2 +- src/runtime/logging.cc | 2 +- src/runtime/memory/memory_manager.cc | 12 +- src/runtime/memory/naive_allocator.h | 2 +- src/runtime/memory/pooled_allocator.h | 2 +- src/runtime/meta_data.h | 4 +- src/runtime/metal/metal_common.h | 4 +- src/runtime/metal/metal_device_api.mm | 8 +- src/runtime/metal/metal_module.h | 2 +- src/runtime/metal/metal_module.mm | 16 +- src/runtime/minrpc/minrpc_interfaces.h | 93 - src/runtime/minrpc/minrpc_logger.cc | 291 ---- src/runtime/minrpc/minrpc_logger.h | 296 ---- src/runtime/minrpc/minrpc_server.h | 721 +------- src/runtime/minrpc/minrpc_server_logging.h | 170 -- .../posix_popen_server/posix_popen_server.cc | 3 - src/runtime/minrpc/rpc_reference.h | 207 +-- src/runtime/module.cc | 29 +- src/runtime/ndarray.cc | 109 +- src/runtime/object.cc | 92 - src/runtime/object_internal.h | 96 -- src/runtime/opencl/opencl_common.h | 14 +- src/runtime/opencl/opencl_device_api.cc | 25 +- src/runtime/opencl/opencl_module.cc | 10 +- src/runtime/opencl/opencl_module.h | 2 +- src/runtime/opencl/opencl_module_spirv.cc | 2 +- src/runtime/pack_args.h | 13 +- src/runtime/packed_func.cc | 32 - src/runtime/profiling.cc | 125 +- src/runtime/regex.cc | 11 +- src/runtime/registry.cc | 266 --- src/runtime/relax_vm/attn_backend.h | 5 +- src/runtime/relax_vm/attn_utils.h | 12 +- src/runtime/relax_vm/builtin.cc | 188 +- .../relax_vm/cuda/cuda_graph_builtin.cc | 19 +- src/runtime/relax_vm/executable.cc | 54 +- src/runtime/relax_vm/hexagon/builtin.cc | 7 +- src/runtime/relax_vm/kv_state.cc | 64 +- src/runtime/relax_vm/kv_state.h | 10 +- src/runtime/relax_vm/lm_support.cc | 50 +- src/runtime/relax_vm/ndarray_cache_support.cc | 42 +- src/runtime/relax_vm/paged_kv_cache.cc | 49 +- src/runtime/relax_vm/rnn_state.cc | 16 +- src/runtime/relax_vm/vm.cc | 41 +- src/runtime/rocm/rocm_common.h | 2 +- src/runtime/rocm/rocm_device_api.cc | 17 +- src/runtime/rocm/rocm_module.cc | 10 +- src/runtime/rpc/rpc_channel.cc | 4 +- src/runtime/rpc/rpc_channel.h | 2 +- src/runtime/rpc/rpc_channel_logger.h | 186 -- src/runtime/rpc/rpc_device_api.cc | 4 +- src/runtime/rpc/rpc_endpoint.cc | 81 +- src/runtime/rpc/rpc_endpoint.h | 3 +- src/runtime/rpc/rpc_event_impl.cc | 4 +- src/runtime/rpc/rpc_local_session.cc | 32 +- src/runtime/rpc/rpc_local_session.h | 2 +- src/runtime/rpc/rpc_module.cc | 56 +- src/runtime/rpc/rpc_pipe_impl.cc | 17 +- src/runtime/rpc/rpc_server_env.cc | 8 +- src/runtime/rpc/rpc_session.cc | 8 +- src/runtime/rpc/rpc_session.h | 4 +- src/runtime/rpc/rpc_socket_impl.cc | 11 +- src/runtime/runtime_base.h | 67 - src/runtime/spirv/spirv_shader.h | 5 +- src/runtime/static_library.cc | 9 +- src/runtime/static_library.h | 2 +- src/runtime/system_library.cc | 19 +- src/runtime/thread_pool.cc | 11 +- src/runtime/thread_storage_scope.h | 2 +- src/runtime/threading_backend.cc | 7 +- src/runtime/vulkan/vulkan_common.h | 5 +- src/runtime/vulkan/vulkan_device_api.cc | 11 +- src/runtime/vulkan/vulkan_module.cc | 6 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 2 +- src/script/ir_builder/base.cc | 30 +- src/script/ir_builder/ir/frame.cc | 2 +- src/script/ir_builder/ir/ir.cc | 20 +- src/script/ir_builder/relax/distributed.cc | 2 +- src/script/ir_builder/relax/ir.cc | 39 +- src/script/ir_builder/tir/ir.cc | 206 +-- src/script/printer/doc.cc | 82 +- .../printer/doc_printer/python_doc_printer.cc | 4 +- src/script/printer/ir/distributed.cc | 19 +- src/script/printer/ir/ir.cc | 2 +- src/script/printer/ir/misc.cc | 2 +- src/script/printer/ir_docsifier.cc | 18 +- src/script/printer/legacy_repr.cc | 12 +- src/script/printer/relax/binding.cc | 4 +- src/script/printer/relax/call.cc | 10 +- src/script/printer/relax/expr.cc | 6 +- src/script/printer/relax/function.cc | 5 +- src/script/printer/relax/region.cc | 2 +- src/script/printer/relax/tir.cc | 4 +- src/script/printer/relax/type.cc | 2 +- src/script/printer/relax/utils.h | 10 +- src/script/printer/tir/block.cc | 17 +- src/script/printer/tir/buffer.cc | 20 +- src/script/printer/tir/expr.cc | 4 +- src/script/printer/tir/for_loop.cc | 8 +- src/script/printer/tir/function.cc | 12 +- src/script/printer/tir/stmt.cc | 34 +- src/script/printer/tir/utils.h | 4 +- src/script/printer/utils.h | 7 +- src/support/array.h | 14 +- src/support/ffi_testing.cc | 103 +- src/support/libinfo.cc | 4 +- src/support/socket.h | 3 +- src/support/utils.h | 2 +- src/target/build_common.h | 2 +- src/target/codegen.cc | 16 +- src/target/datatype/myfloat/myfloat.cc | 2 +- src/target/datatype/posit/posit-wrapper.cc | 2 +- src/target/datatype/registry.cc | 11 +- src/target/datatype/registry.h | 5 +- src/target/intrin_rule.h | 2 +- src/target/llvm/codegen_aarch64.cc | 8 +- src/target/llvm/codegen_amdgpu.cc | 12 +- src/target/llvm/codegen_arm.cc | 4 +- src/target/llvm/codegen_cpu.cc | 8 +- src/target/llvm/codegen_hexagon.cc | 6 +- src/target/llvm/codegen_llvm.cc | 13 +- src/target/llvm/codegen_llvm.h | 2 +- src/target/llvm/codegen_nvptx.cc | 6 +- src/target/llvm/codegen_x86_64.cc | 4 +- src/target/llvm/intrin_rule_llvm.cc | 82 + src/target/llvm/intrin_rule_llvm.h | 2 +- src/target/llvm/intrin_rule_nvptx.cc | 2 +- src/target/llvm/intrin_rule_rocm.cc | 2 +- src/target/llvm/llvm_instance.cc | 63 +- src/target/llvm/llvm_instance.h | 6 +- src/target/llvm/llvm_module.cc | 70 +- src/target/llvm/llvm_module.h | 2 +- src/target/opt/build_cuda_on.cc | 2 +- src/target/source/codegen_c.cc | 17 +- src/target/source/codegen_c_host.cc | 16 +- src/target/source/codegen_cuda.cc | 2 +- src/target/source/codegen_metal.cc | 2 +- src/target/source/codegen_opencl.cc | 4 +- src/target/source/codegen_webgpu.cc | 2 +- src/target/source/source_module.cc | 9 +- src/target/spirv/build_vulkan.cc | 2 +- src/target/spirv/intrin_rule_spirv.cc | 38 +- src/target/spirv/ir_builder.h | 2 +- src/target/tag.cc | 8 +- src/target/target.cc | 70 +- src/target/target_info.cc | 2 +- src/target/target_kind.cc | 15 +- src/target/virtual_device.cc | 2 +- src/te/operation/compute_op.cc | 4 +- src/te/operation/create_primfunc.cc | 35 +- src/te/operation/create_primfunc.h | 2 +- src/te/operation/extern_op.cc | 4 +- src/te/operation/graph.cc | 6 +- src/te/operation/placeholder_op.cc | 6 +- src/te/operation/scan_op.cc | 4 +- src/te/tensor.cc | 14 +- .../analysis/block_access_region_detector.cc | 5 +- .../analysis/buffer_access_lca_detector.cc | 6 +- .../analysis/calculate_allocated_memory.cc | 8 +- src/tir/analysis/control_flow_graph.cc | 16 +- src/tir/analysis/control_flow_graph.h | 2 +- src/tir/analysis/deep_equal.cc | 6 +- src/tir/analysis/estimate_flops.cc | 31 +- src/tir/analysis/identify_memcpy.cc | 4 +- src/tir/analysis/is_pure_function.cc | 2 +- src/tir/analysis/oob_checker.cc | 2 +- src/tir/analysis/stmt_finding.cc | 4 +- src/tir/analysis/var_use_def_analysis.cc | 2 +- src/tir/analysis/verify_gpu_code.cc | 6 +- src/tir/analysis/verify_memory.cc | 6 +- src/tir/analysis/verify_ssa.cc | 6 +- src/tir/analysis/verify_well_formed.cc | 4 +- src/tir/ir/block_dependence_info.cc | 8 +- src/tir/ir/block_scope.cc | 19 +- src/tir/ir/buffer.cc | 17 +- src/tir/ir/data_layout.cc | 34 +- src/tir/ir/data_type_rewriter.cc | 6 +- src/tir/ir/expr.cc | 144 +- src/tir/ir/function.cc | 14 +- src/tir/ir/functor_common.h | 2 +- src/tir/ir/index_map.cc | 25 +- src/tir/ir/script/script_complete.cc | 4 +- src/tir/ir/script/script_complete.h | 2 +- src/tir/ir/specialize.cc | 4 +- src/tir/ir/stmt.cc | 52 +- src/tir/ir/stmt_functor.cc | 14 +- src/tir/ir/transform.cc | 4 +- src/tir/ir/utils.cc | 68 - src/tir/ir/utils.h | 47 - src/tir/op/builtin.cc | 2 +- src/tir/op/op.cc | 76 +- src/tir/schedule/analysis.h | 12 +- src/tir/schedule/analysis/analysis.cc | 51 +- src/tir/schedule/analysis/layout.cc | 4 +- src/tir/schedule/analysis/verify.cc | 6 +- src/tir/schedule/concrete_schedule.cc | 16 +- src/tir/schedule/concrete_schedule.h | 10 +- src/tir/schedule/instruction.cc | 6 +- src/tir/schedule/instruction_traits.h | 14 +- src/tir/schedule/primitive/block_annotate.cc | 4 +- .../schedule/primitive/blockize_tensorize.cc | 10 +- src/tir/schedule/primitive/cache_index.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 10 +- src/tir/schedule/primitive/compute_at.cc | 2 +- .../schedule/primitive/decompose_padding.cc | 2 +- src/tir/schedule/primitive/for_kind.cc | 8 +- .../primitive/layout_transformation.cc | 14 +- .../schedule/primitive/loop_transformation.cc | 16 +- src/tir/schedule/primitive/pad_einsum.cc | 8 +- src/tir/schedule/primitive/read_write_at.cc | 2 +- src/tir/schedule/primitive/reduction.cc | 8 +- src/tir/schedule/primitive/rolling_buffer.cc | 6 +- src/tir/schedule/schedule.cc | 136 +- src/tir/schedule/state.cc | 16 +- src/tir/schedule/trace.cc | 58 +- src/tir/schedule/traced_schedule.cc | 2 +- src/tir/schedule/traced_schedule.h | 9 +- src/tir/schedule/transform.cc | 16 +- src/tir/schedule/transform.h | 2 +- src/tir/schedule/utils.h | 14 +- src/tir/transforms/annotate_device_regions.cc | 5 +- src/tir/transforms/bind_params.cc | 2 +- src/tir/transforms/bound_checker.cc | 4 +- src/tir/transforms/combine_context_call.cc | 4 +- src/tir/transforms/common_subexpr_elim.cc | 6 +- .../transforms/common_subexpr_elim_tools.cc | 4 +- .../transforms/common_subexpr_elim_tools.h | 2 +- src/tir/transforms/compact_buffer_region.cc | 4 +- .../transforms/convert_blocks_to_opaque.cc | 3 +- .../transforms/convert_for_loops_serial.cc | 2 +- src/tir/transforms/decorate_device_scope.cc | 4 +- src/tir/transforms/default_gpu_schedule.cc | 12 +- src/tir/transforms/extract_constants.cc | 4 +- src/tir/transforms/flatten_buffer.cc | 2 +- .../transforms/force_narrow_index_to_i32.cc | 2 +- src/tir/transforms/hoist_expression.cc | 8 +- src/tir/transforms/inject_double_buffer.cc | 4 +- src/tir/transforms/inject_permuted_layout.cc | 6 +- src/tir/transforms/inject_ptx_async_copy.cc | 2 +- src/tir/transforms/inject_ptx_ldg32.cc | 4 +- src/tir/transforms/inject_rolling_buffer.cc | 4 +- .../transforms/inject_software_pipeline.cc | 11 +- src/tir/transforms/inject_virtual_thread.cc | 6 +- .../transforms/inline_private_functions.cc | 17 +- src/tir/transforms/ir_utils.cc | 8 +- src/tir/transforms/lift_thread_binding.cc | 2 +- src/tir/transforms/loop_partition.cc | 4 +- src/tir/transforms/lower_async_dma.cc | 2 +- .../lower_cross_thread_reduction.cc | 14 +- src/tir/transforms/lower_custom_datatypes.cc | 4 +- .../transforms/lower_device_kernel_launch.cc | 8 +- .../lower_device_storage_access_info.cc | 4 +- src/tir/transforms/lower_init_block.cc | 4 +- src/tir/transforms/lower_intrin.cc | 6 +- src/tir/transforms/lower_match_buffer.cc | 2 +- src/tir/transforms/lower_opaque_block.cc | 8 +- src/tir/transforms/lower_thread_allreduce.cc | 16 +- src/tir/transforms/lower_tvm_builtin.cc | 13 +- src/tir/transforms/lower_vtcm_alloc.cc | 2 +- src/tir/transforms/lower_warp_memory.cc | 4 +- src/tir/transforms/make_packed_api.cc | 15 +- src/tir/transforms/make_unpacked_api.cc | 6 +- .../manifest_shared_memory_local_stage.cc | 2 +- src/tir/transforms/memhammer_coalesce.cc | 2 +- .../memhammer_intermediate_stage.cc | 2 +- .../transforms/memhammer_lower_auto_copy.cc | 4 +- src/tir/transforms/memhammer_rewrite_rule.h | 2 +- .../memhammer_tensorcore_rewrite.cc | 10 +- .../merge_shared_memory_allocations.cc | 4 +- src/tir/transforms/narrow_datatype.cc | 4 +- .../plan_update_buffer_allocation_location.cc | 4 +- src/tir/transforms/primfunc_utils.cc | 6 +- src/tir/transforms/profile_instrumentation.cc | 2 +- .../reduce_branching_through_overcompute.cc | 2 +- src/tir/transforms/remap_thread_axis.cc | 8 +- src/tir/transforms/remove_assume.cc | 4 +- src/tir/transforms/remove_no_op.cc | 4 +- src/tir/transforms/remove_store_undef.cc | 4 +- .../remove_weight_layout_rewrite_block.cc | 2 +- src/tir/transforms/renew_defs.cc | 4 +- .../transforms/renormalize_split_pattern.cc | 4 +- src/tir/transforms/rewrite_unsafe_select.cc | 4 +- src/tir/transforms/simplify.cc | 10 +- src/tir/transforms/skip_assert.cc | 4 +- src/tir/transforms/split_host_device.cc | 4 +- src/tir/transforms/storage_rewrite.cc | 6 +- .../transforms/tensorcore_infer_fragment.cc | 4 +- src/tir/transforms/thread_storage_sync.cc | 4 +- .../transforms/transform_mma_buffer_layout.cc | 2 +- src/tir/transforms/unify_thread_binding.cc | 4 +- src/tir/transforms/unroll_loop.cc | 4 +- .../transforms/unsupported_dtype_legalize.cc | 13 +- .../using_assume_to_reduce_branches.cc | 4 +- src/tir/transforms/vectorize_loop.cc | 12 +- src/topi/broadcast.cc | 36 +- src/topi/einsum.cc | 2 +- src/topi/elemwise.cc | 76 +- src/topi/nn.cc | 117 +- src/topi/reduction.cc | 26 +- src/topi/transform.cc | 172 +- src/topi/utils.cc | 9 +- src/topi/vision.cc | 10 +- .../hexagon/hexagon_buffer_tests.cc | 3 +- .../hexagon/hexagon_device_api_tests.cc | 1 + .../hexagon/hexagon_user_dma_tests.cc | 1 + .../hexagon/hexagon_vtcm_pool_tests.cc | 1 + tests/cpp-runtime/hexagon/run_all_tests.cc | 5 +- tests/cpp-runtime/hexagon/run_unit_tests.cc | 5 +- .../cpp-runtime/opencl/aa_opencl_qcom_extn.cc | 2 +- .../cpp-runtime/opencl/clml_memory_planner.cc | 2 +- .../opencl/opencl_compile_to_bin.cc | 12 +- tests/cpp-runtime/opencl/opencl_nativeptr.cc | 2 +- tests/cpp-runtime/opencl/texture_copy_test.cc | 26 +- tests/cpp/ir_functor_test.cc | 2 +- tests/cpp/llvm_codegen_registry_test.cc | 3 +- tests/cpp/nested_msg_test.cc | 51 +- tests/cpp/object_protocol_test.cc | 3 +- .../runtime/memory/memory_manager_tests.cc | 10 +- tests/cpp/tir_scalable_datatype.cc | 4 +- tests/lint/rust_format.sh | 35 - .../arith/test_arith_canonical_simplify.py | 1 - .../arith/test_arith_rewrite_simplify.py | 28 +- tests/python/arith/test_arith_simplify.py | 15 +- .../codegen/test_gpu_codegen_allreduce.py | 2 +- .../codegen/test_target_codegen_aarch64.py | 163 +- .../codegen/test_target_codegen_llvm_vla.py | 149 ++ .../codegen/test_target_codegen_vulkan.py | 55 + .../python/contrib/test_hexagon/README_RPC.md | 4 +- .../python/contrib/test_hexagon/test_vtcm.py | 2 +- tests/python/disco/test_loader.py | 2 +- .../python/dlight/test_gpu_low_batch_gemv.py | 8 +- tests/python/dlight/test_gpu_matmul.py | 4 +- tests/python/ffi/test_ndarray.py | 27 + .../test_meta_schedule_builder.py | 4 +- .../test_meta_schedule_post_order_apply.py | 2 +- .../test_meta_schedule_runner.py | 2 +- .../test_meta_schedule_space_generator.py | 4 +- .../test_nnapi}/test_from_exported_to_cuda.py | 328 ++++ .../python/relax/frontend_nn_extern_module.cc | 6 +- tests/python/relax/test_binding_rewrite.py | 2 +- tests/python/relax/test_frontend_dynamo.py | 36 +- .../test_frontend_from_exported_program.py | 695 +++++++- tests/python/relax/test_frontend_from_fx.py | 636 ++++++- .../relax/test_frontend_nn_extern_module.py | 4 +- tests/python/relax/test_frontend_nn_op.py | 4 +- tests/python/relax/test_frontend_onnx.py | 27 +- tests/python/relax/test_op_distributed.py | 2 +- tests/python/relax/test_op_grad.py | 2 +- tests/python/relax/test_op_nn.py | 5 + tests/python/relax/test_op_nn_pooling.py | 1518 +++++++++++++++-- tests/python/relax/test_relax_operators.py | 2 +- .../relax/test_transform_convert_layout.py | 4 +- tests/python/relax/test_transform_gradient.py | 2 +- .../relax/test_transform_legalize_ops_grad.py | 3 +- tests/python/runtime/test_runtime_error.py | 2 +- tests/python/runtime/test_runtime_rpc.py | 29 +- tests/python/target/test_riscv_features.py | 22 +- tests/python/te/test_te_create_primfunc.py | 8 + tests/python/te/test_te_verify_compute.py | 12 +- .../testing/test_type_annotation_checker.py | 2 +- tests/python/tir-base/test_tir_base.py | 2 +- tests/python/tir-base/test_tir_intrin.py | 24 +- .../test_tir_schedule_split_fuse.py | 2 +- .../test_tir_transform_common_subexpr_elim.py | 2 +- ...est_tir_transform_inject_ptx_async_copy.py | 2 +- .../test_tir_transform_vectorize.py | 2 +- tests/scripts/task_java_unittest.sh | 23 +- tests/scripts/task_lint.sh | 3 - tests/scripts/task_python_docs.sh | 7 - version.py | 8 +- web/.eslintignore | 2 + web/apps/node/example.js | 2 +- web/emcc/tvmjs_support.cc | 46 +- web/emcc/wasm_runtime.cc | 62 +- web/emcc/webgpu_runtime.cc | 22 +- web/package.json | 4 + web/src/asyncify.ts | 9 + web/src/ctypes.ts | 296 ++-- web/src/environment.ts | 27 +- web/src/memory.ts | 76 +- web/src/rpc_server.ts | 15 +- web/src/runtime.ts | 813 +++------ web/tests/node/test_ndarray.js | 2 +- web/tests/node/test_object.js | 5 - web/tests/node/test_packed_func.js | 53 +- web/tests/python/webgpu_rpc_test.py | 1 - 1300 files changed, 12689 insertions(+), 22641 deletions(-) rename rust/Cargo.toml => docker/install/ubuntu_install_nnef.sh (91%) mode change 100644 => 100755 create mode 100644 ffi/scripts/benchmark_dlpack.py delete mode 100644 golang/Makefile delete mode 100644 golang/README.md delete mode 100644 golang/sample/Makefile delete mode 100644 golang/sample/complex.go delete mode 100644 golang/sample/pack_func_closure_arg.go delete mode 100644 golang/sample/pack_func_closure_return.go delete mode 100644 golang/sample/pack_func_convert.go delete mode 100644 golang/sample/pack_func_handle_arg.go delete mode 100644 golang/sample/pack_func_register.go delete mode 100644 golang/sample/simple.go delete mode 100644 golang/src/array_test.go delete mode 100644 golang/src/bytearray.go delete mode 100644 golang/src/bytearray_test.go delete mode 100644 golang/src/device.go delete mode 100644 golang/src/error.go delete mode 100644 golang/src/error_test.go delete mode 100644 golang/src/function.go delete mode 100644 golang/src/function_test.go delete mode 100644 golang/src/gotvm.cc delete mode 100644 golang/src/gotvm.h delete mode 100644 golang/src/gotvm_test.go delete mode 100644 golang/src/module.go delete mode 100644 golang/src/module_test.go delete mode 100644 golang/src/ndarray.go delete mode 100644 golang/src/tvm_runtime_pack.cc delete mode 100644 golang/src/type.go delete mode 100644 golang/src/utils.go delete mode 100644 golang/src/value.go delete mode 100644 golang/src/value_test.go rename golang/src/gotvm.go => include/tvm/runtime/base.h (50%) delete mode 100644 include/tvm/runtime/c_runtime_api.h delete mode 100644 include/tvm/runtime/container/array.h delete mode 100644 include/tvm/runtime/container/base.h delete mode 100644 include/tvm/runtime/container/map.h delete mode 100644 include/tvm/runtime/container/optional.h delete mode 100644 include/tvm/runtime/container/shape_tuple.h delete mode 100644 include/tvm/runtime/container/string.h delete mode 100644 include/tvm/runtime/container/variant.h rename include/tvm/runtime/{memory.h => int_tuple.h} (73%) delete mode 100644 include/tvm/runtime/registry.h delete mode 100644 include/tvm/support/span.h rename jvm/core/src/main/java/org/apache/tvm/{ArgTypeCode.java => TVMObject.java} (65%) create mode 100644 jvm/core/src/main/java/org/apache/tvm/TypeIndex.java create mode 100644 jvm/core/src/test/scripts/prepare_test_libs.py delete mode 100644 python/tvm/_ffi/__init__.py delete mode 100644 python/tvm/_ffi/_pyversion.py delete mode 100644 python/tvm/_ffi/registry.py rename python/tvm/{_ffi => }/base.py (86%) delete mode 100644 python/tvm/generic.py rename python/tvm/{_ffi => }/libinfo.py (98%) create mode 100644 python/tvm/topi/slice_scatter.py delete mode 100644 rust/.gitignore delete mode 100644 rust/.rustfmt.toml delete mode 100644 rust/tvm-macros/Cargo.toml delete mode 100644 rust/tvm-macros/README.md delete mode 100644 rust/tvm-macros/src/external.rs delete mode 100644 rust/tvm-macros/src/import_module.rs delete mode 100644 rust/tvm-macros/src/lib.rs delete mode 100644 rust/tvm-macros/src/object.rs delete mode 100644 rust/tvm-macros/src/util.rs delete mode 100644 rust/tvm-rt/.gitignore delete mode 100644 rust/tvm-rt/Cargo.toml delete mode 100644 rust/tvm-rt/README.md delete mode 100644 rust/tvm-rt/src/device.rs delete mode 100644 rust/tvm-rt/src/errors.rs delete mode 100644 rust/tvm-rt/src/function.rs delete mode 100644 rust/tvm-rt/src/lib.rs delete mode 100644 rust/tvm-rt/src/module.rs delete mode 100644 rust/tvm-rt/src/ndarray.rs delete mode 100644 rust/tvm-rt/src/object/mod.rs delete mode 100644 rust/tvm-rt/src/object/object_ptr.rs delete mode 100644 rust/tvm-rt/src/string.rs delete mode 100644 rust/tvm-rt/src/to_function.rs delete mode 100644 rust/tvm-sys/Cargo.toml delete mode 100644 rust/tvm-sys/README.md delete mode 100644 rust/tvm-sys/build.rs delete mode 100644 rust/tvm-sys/src/array.rs delete mode 100644 rust/tvm-sys/src/byte_array.rs delete mode 100644 rust/tvm-sys/src/datatype.rs delete mode 100644 rust/tvm-sys/src/device.rs delete mode 100644 rust/tvm-sys/src/errors.rs delete mode 100644 rust/tvm-sys/src/lib.rs delete mode 100644 rust/tvm-sys/src/packed_func.rs delete mode 100644 rust/tvm-sys/src/value.rs delete mode 100644 src/runtime/c_runtime_api.cc delete mode 100644 src/runtime/container.cc create mode 100644 src/runtime/device_api.cc delete mode 100644 src/runtime/minrpc/minrpc_interfaces.h delete mode 100644 src/runtime/minrpc/minrpc_logger.cc delete mode 100644 src/runtime/minrpc/minrpc_logger.h delete mode 100644 src/runtime/minrpc/minrpc_server_logging.h delete mode 100644 src/runtime/object.cc delete mode 100644 src/runtime/object_internal.h delete mode 100644 src/runtime/packed_func.cc delete mode 100644 src/runtime/registry.cc delete mode 100644 src/runtime/rpc/rpc_channel_logger.h delete mode 100644 src/runtime/runtime_base.h delete mode 100644 src/tir/ir/utils.cc delete mode 100644 src/tir/ir/utils.h delete mode 100755 tests/lint/rust_format.sh create mode 100644 tests/python/codegen/test_target_codegen_llvm_vla.py rename tests/python/{relax => nightly/test_nnapi}/test_from_exported_to_cuda.py (70%) diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index 6fd81c1d6903..cd7fd9197fae 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -3,7 +3,7 @@ runs: steps: - uses: actions/cache@v3 env: - CACHE_NUMBER: 1 + CACHE_NUMBER: 2 with: path: ~/conda_pkgs_dir key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('conda/build-environment.yaml') }} @@ -15,8 +15,7 @@ runs: channel-priority: strict environment-file: conda/build-environment.yaml auto-activate-base: false - conda-solver: classic - use-only-tar-bz2: true + miniforge-version: latest python-version: 3.9 condarc-file: conda/condarc - uses: conda-incubator/setup-miniconda@v3 @@ -26,14 +25,14 @@ runs: channel-priority: strict environment-file: conda/build-environment.yaml auto-activate-base: false - conda-solver: classic + miniforge-version: latest use-only-tar-bz2: true python-version: 3.9 condarc-file: conda/condarc - name: Conda info shell: pwsh run: | - conda info - conda list - conda info --envs - conda list --name base + mamba info + mamba list + mamba info --envs + mamba list --name base diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index bbccc75af117..3e07e778d78f 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit bbccc75af117473f6de81905bd3314775f41636e +Subproject commit 3e07e778d78f0fcd047533c1fdaed571a68a396f diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 26085bc366f4..5255d3f4b10a 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -42,9 +42,8 @@ #include "../ffi/src/ffi/object.cc" #include "../ffi/src/ffi/testing.cc" #include "../ffi/src/ffi/traceback.cc" -#include "../src/runtime/c_runtime_api.cc" -#include "../src/runtime/container.cc" #include "../src/runtime/cpu_device_api.cc" +#include "../src/runtime/device_api.cc" #include "../src/runtime/dso_library.cc" #include "../src/runtime/file_utils.cc" #include "../src/runtime/library_module.cc" @@ -53,7 +52,6 @@ #include "../src/runtime/minrpc/minrpc_logger.cc" #include "../src/runtime/module.cc" #include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" #include "../src/runtime/profiling.cc" #include "../src/runtime/registry.cc" #include "../src/runtime/rpc/rpc_channel.cc" diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 88ad99e47af2..e5a5154acbf2 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,7 +20,9 @@ * \file rpc_env.cc * \brief Server environment of the RPC. */ -#include +#include +#include +#include #include #ifndef _WIN32 diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index dbb0a62d2c5d..a5d3f6957c33 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -24,7 +24,7 @@ #ifndef TVM_APPS_CPP_RPC_ENV_H_ #define TVM_APPS_CPP_RPC_ENV_H_ -#include +#include #include diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index c4ee4d35450f..2f74dd309f42 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -21,7 +21,7 @@ * \file rpc_server.cc * \brief RPC Server implementation. */ -#include +#include #if defined(__linux__) || defined(__ANDROID__) || defined(__APPLE__) #include #include @@ -398,6 +398,6 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body_typed(RPCServerCreate); +TVM_FFI_REGISTER_GLOBAL("rpc.ServerCreate").set_body_typed(RPCServerCreate); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index e4565d095b2e..9bb61065c58a 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -26,7 +26,7 @@ #include -#include "tvm/runtime/c_runtime_api.h" +#include "tvm/runtime/base.h" namespace tvm { namespace runtime { diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index f4fc9fb365a8..56242082cca3 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -19,9 +19,8 @@ #include "launcher_core.h" +#include #include -#include -#include #include #include diff --git a/apps/hexagon_launcher/launcher_core.h b/apps/hexagon_launcher/launcher_core.h index ae9e4108cd57..5e62774607ba 100644 --- a/apps/hexagon_launcher/launcher_core.h +++ b/apps/hexagon_launcher/launcher_core.h @@ -22,10 +22,10 @@ #include #include +#include #include #include #include -#include #include #include diff --git a/apps/ios_rpc/tvmrpc/RPCServer.mm b/apps/ios_rpc/tvmrpc/RPCServer.mm index 3dc2fb0c192a..da689d4c7064 100644 --- a/apps/ios_rpc/tvmrpc/RPCServer.mm +++ b/apps/ios_rpc/tvmrpc/RPCServer.mm @@ -23,8 +23,7 @@ #import "RPCServer.h" -#include -#include +#include #include #include diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index 243e4819d025..8d0ae7368d8a 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -23,7 +23,7 @@ #import -#include +#include #include "RPCArgs.h" @@ -51,14 +51,14 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.workpath") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { static const std::string base_ = NSTemporaryDirectory().UTF8String; const auto path = args[0].cast(); *rv = base_ + "/" + path; }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.load_module") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto name = args[0].cast(); std::string fmt = GetFileFormat(name, ""); @@ -109,7 +109,7 @@ void Init(const std::string& name) { }; // Add UnsignedDSOLoader plugin in global registry -TVM_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_dylib_custom") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto n = make_object(); n->Init(args[0]); diff --git a/ci/jenkins/docker-images.ini b/ci/jenkins/docker-images.ini index c59d46539b32..bf2a11ab1bdc 100644 --- a/ci/jenkins/docker-images.ini +++ b/ci/jenkins/docker-images.ini @@ -17,10 +17,10 @@ # This data file is read during when Jenkins runs job to determine docker images. [jenkins] -ci_arm: tlcpack/ci-arm:20250428-080833-03eadc65 -ci_cpu: tlcpack/ci_cpu:20250428-080833-03eadc65 -ci_gpu: tlcpack/ci-gpu:20250428-080833-03eadc65 -ci_hexagon: tlcpack/ci-hexagon:20250428-080833-03eadc65 -ci_i386: tlcpack/ci-i386:20250428-080833-03eadc65 -ci_lint: tlcpack/ci-lint:20250428-080833-03eadc65 -ci_wasm: tlcpack/ci-wasm:20250428-080833-03eadc65 +ci_arm: tlcpack/ci-arm:20250513-063354-70aa3797 +ci_cpu: tlcpack/ci_cpu:20250513-063354-70aa3797 +ci_gpu: tlcpack/ci-gpu:20250513-063354-70aa3797 +ci_hexagon: tlcpack/ci-hexagon:20250513-063354-70aa3797 +ci_i386: tlcpack/ci-i386:20250513-063354-70aa3797 +ci_lint: tlcpack/ci-lint:20250513-063354-70aa3797 +ci_wasm: tlcpack/ci-wasm:20250513-063354-70aa3797 diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index 627bb85862f3..e93400f6d637 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-02-15T19:40:24.687837 +// Generated at 2025-05-09T10:31:17.078676 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -516,11 +516,6 @@ def run_build(node_type) { script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/cpu --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/libtvm_allvisible.so build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", label: 'Upload artifacts to S3', ) - - ci_setup(ci_cpu) - // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" - // TODO(@jroesch): need to resolve CI issue will turn back on in follow up patch - sh (script: "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh", label: 'Rust build and test') }) } } diff --git a/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 b/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 index c84b0c48a29f..50e47f9bbfaf 100644 --- a/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 +++ b/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 @@ -32,10 +32,6 @@ cmake_build(ci_cpu, 'build') make_cpp_tests(ci_cpu, 'build') {{ m.upload_artifacts(tag='cpu', filenames=tvm_lib + tvm_allvisible + cpptest) }} - ci_setup(ci_cpu) - // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" - // TODO(@jroesch): need to resolve CI issue will turn back on in follow up patch - sh (script: "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh", label: 'Rust build and test') {% endcall %} {% set test_method_names = [] %} diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy index 18c96d7817db..af3821919243 100755 --- a/ci/jenkins/unity_jenkinsfile.groovy +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -30,8 +30,8 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> -ci_gpu = 'tlcpack/ci-gpu:20250428-080833-03eadc65' -ci_cpu = 'tlcpack/ci-cpu:20250428-080833-03eadc65' +ci_gpu = 'tlcpack/ci-gpu:20250513-063354-70aa3797' +ci_cpu = 'tlcpack/ci-cpu:20250513-063354-70aa3797' // <--- End of regex-scanned config. // Parameters to allow overriding (in Jenkins UI), the images diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index f9dd4a890369..84261c6ea0ae 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -109,6 +109,7 @@ if(USE_CUDA) message(STATUS "Build with Thrust support") tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu) add_library(tvm_thrust_objs OBJECT ${CONTRIB_THRUST_SRC}) + target_link_libraries(tvm_thrust_objs PRIVATE tvm_ffi_header) target_compile_options(tvm_thrust_objs PRIVATE $<$:--expt-extended-lambda>) target_compile_definitions(tvm_thrust_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=) if (NOT USE_THRUST MATCHES ${IS_TRUE_PATTERN}) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index de4e6f4234d7..716b2198faeb 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -20,12 +20,10 @@ name: tvm-build # The conda channels to lookup the dependencies channels: - - anaconda - conda-forge # The packages to install to the environment dependencies: - - python=3.9 - conda < 24.9.0 - conda-build < 24.9.0 - git diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 7a95ebff8e5c..91ebf8746079 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -106,6 +106,10 @@ RUN bash /install/ubuntu_install_libxsmm.sh COPY install/ubuntu_install_onnx.sh /install/ubuntu_install_onnx.sh RUN bash /install/ubuntu_install_onnx.sh +# NNEF +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + # AArch64 Architecture Envelope Model (AEM) COPY install/ubuntu_install_aprofile_aem.sh /install RUN bash /install/ubuntu_install_aprofile_aem.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 332cb9b4e034..1295c679d778 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -100,6 +100,9 @@ RUN bash /install/ubuntu_install_tflite.sh COPY install/ubuntu_install_dgl.sh /install/ubuntu_install_dgl.sh RUN bash /install/ubuntu_install_dgl.sh +COPY install/ubuntu_install_nnef.sh /install/ubuntu_install_nnef.sh +RUN bash /install/ubuntu_install_nnef.sh + ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility COPY install/ubuntu_install_vulkan.sh /install/ubuntu_install_vulkan.sh RUN bash /install/ubuntu_install_vulkan.sh diff --git a/docker/README.md b/docker/README.md index ecc6e7948957..7d3fd22dc911 100644 --- a/docker/README.md +++ b/docker/README.md @@ -130,9 +130,3 @@ tasks. ```bash ./docker/ci_build.sh ci_gpu make -C docs html ``` - -- build golang test suite. - - ```bash - ./docker/build.sh ci_cpu tests/scripts/task_golang.sh - ``` diff --git a/rust/Cargo.toml b/docker/install/ubuntu_install_nnef.sh old mode 100644 new mode 100755 similarity index 91% rename from rust/Cargo.toml rename to docker/install/ubuntu_install_nnef.sh index 26b4398b427d..c7bc59bddffb --- a/rust/Cargo.toml +++ b/docker/install/ubuntu_install_nnef.sh @@ -1,3 +1,4 @@ +#!/bin/bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,9 +16,9 @@ # specific language governing permissions and limitations # under the License. -[workspace] -members = [ - "tvm-sys", - "tvm-macros", - "tvm-rt" -] +set -e +set -u +set -o pipefail + +pip3 install \ + nnef==1.0.8 diff --git a/docker/python/ci-constraints.txt b/docker/python/ci-constraints.txt index d28e81599775..70795265c6c8 100644 --- a/docker/python/ci-constraints.txt +++ b/docker/python/ci-constraints.txt @@ -28,3 +28,4 @@ tensorflow-gpu = "==2.7.2" tflite = "==2.4.0" torch = "==1.11.0" torchvision = "==0.12.0+cpu" +nnef = "==1.0.8" diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst index e39468f0bf78..09867f88fa36 100644 --- a/docs/arch/device_target_interactions.rst +++ b/docs/arch/device_target_interactions.rst @@ -153,18 +153,18 @@ then be registered with the following steps. #. Register the function to the tvm registry:: - TVM_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global); + TVM_FFI_REGISTER_GLOBAL("device_api.foo").set_body_typed(FooDeviceAPI::Global); -.. _c_runtime_api.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/c_runtime_api.h +.. _base.h: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h #. Add an entry for the new DeviceAPI to the ``TVMDeviceExtType`` enum - in `c_runtime_api.h`_. The value should be an unused value greater + in `base.h`_. The value should be an unused value greater than ``DLDeviceType::kDLExtDev``, but less than ``DeviceAPIManager::kMaxDeviceAPI``. #. Add a case in ``DeviceName`` in `device_api.h`_ to convert from the enum value to a string representation. This string representation - should match the name given to ``TVM_REGISTER_GLOBAL``. + should match the name given to ``TVM_FFI_REGISTER_GLOBAL``. #. Add entries to the ``DEVICE_TYPE_TO_NAME`` and ``DEVICE_NAME_TO_TYPE`` dictionaries of :py:class:`tvm.runtime.Device` for the new enum value. @@ -225,7 +225,7 @@ the same name as was used in the ``TVM_REGISTER_TARGET_KIND`` definition above. :: tvm::runtime::Module GeneratorFooCode(IRModule mod, Target target); - TVM_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode); + TVM_FFI_REGISTER_GLOBAL("target.build.foo").set_body_typed(GeneratorFooCode); The code generator takes two arguments. The first is the ``IRModule`` to compile, and the second is the ``Target`` that describes the device diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index 85e9f45a5fba..c54ba18b0add 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -376,7 +376,7 @@ Python when needed. return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } - TVM_REGISTER_GLOBAL("relax.transform.FoldConstant") + TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant") .set_body_typed(FoldConstant); } // namespace transform @@ -550,7 +550,7 @@ a certain scope. .. code:: python - @tvm._ffi.register_object("transform.PassContext") + @tvm.ffi.register_object("transform.PassContext") class PassContext(tvm.runtime.Object): def __enter__(self): _transform.EnterPassContext(self) diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index f797039ee386..f1642827cd48 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -52,7 +52,7 @@ The following code block provides an example in C++ .. code:: c - #include + #include void MyAdd(ffi::PackedArgs args, ffi::Any* rv) { // automatically convert arguments to desired type. @@ -80,7 +80,7 @@ The following example registers PackedFunc in C++ and calls from python. .. code:: c // register a global packed function in c++ - TVM_REGISTER_GLOBAL("myadd") + TVM_FFI_REGISTER_GLOBAL("myadd") .set_body_packed(MyAdd); .. code:: python @@ -110,7 +110,7 @@ we can pass functions from python (as PackedFunc) to C++. .. code:: c - TVM_REGISTER_GLOBAL("callhello") + TVM_FFI_REGISTER_GLOBAL("callhello") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { PackedFunc f = args[0]; f("hello world"); @@ -134,7 +134,7 @@ which allows us to embed the PackedFunc into any languages. Besides python, so f `java`_ and `javascript`_. This philosophy of embedded API is very like Lua, except that we don't have a new language but use C++. -.. _minimum C API: https://github.com/apache/tvm/blob/main/include/tvm/runtime/c_runtime_api.h +.. _minimum C API: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h .. _java: https://github.com/apache/tvm/tree/main/jvm .. _javascript: https://github.com/apache/tvm/tree/main/web @@ -282,7 +282,7 @@ Each argument in PackedFunc contains a union value `TVMValue`_ and a type code. This design allows the dynamically typed language to convert to the corresponding type directly, and statically typed language to do runtime type checking during conversion. -.. _TVMValue: https://github.com/apache/tvm/blob/main/include/tvm/runtime/c_runtime_api.h#L135 +.. _TVMValue: https://github.com/apache/tvm/blob/main/include/tvm/runtime/base.h#L135 The relevant files are diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 8274904037bc..5b56b003e8cc 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -99,14 +99,41 @@ class AnyView { return *this; } + /*! + * \brief Try to see if we can reinterpret the AnyView to as T object. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note This function won't try run type conversion (use try_cast for that purpose). + */ template ::convert_enabled>> TVM_FFI_INLINE std::optional as() const { - return TypeTraits::TryConvertFromAnyView(&data_); + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::CopyFromAnyViewAfterCheck(&data_); + } else { + return std::optional(std::nullopt); + } + } + /* + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. + */ + template >> + TVM_FFI_INLINE const T* as() const { + return this->as().value_or(nullptr); } + /** + * \brief Cast to a type T. + * + * \tparam T The type to cast to. + * \return The casted value, or throws an exception if the cast is not possible. + */ template ::convert_enabled>> TVM_FFI_INLINE T cast() const { - std::optional opt = TypeTraits::TryConvertFromAnyView(&data_); + std::optional opt = TypeTraits::TryCastFromAnyView(&data_); if (!opt.has_value()) { TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" @@ -115,16 +142,17 @@ class AnyView { return *std::move(opt); } - /* - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + /*! + * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. */ - template >> - TVM_FFI_INLINE const T* as() const { - return this->as().value_or(nullptr); + template ::convert_enabled>> + TVM_FFI_INLINE std::optional try_cast() const { + return TypeTraits::TryCastFromAnyView(&data_); } + // comparison with nullptr TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return data_.type_index == TypeIndex::kTVMFFINone; @@ -269,13 +297,45 @@ class Any { return *this; } + /** + * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note This function won't try to run type conversion (use try_cast for that purpose). + */ + template ::storage_enabled || std::is_same_v>> + TVM_FFI_INLINE std::optional as() && { + if constexpr (std::is_same_v) { + return std::move(*this); + } else { + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::MoveFromAnyAfterCheck(&data_); + } else { + return std::optional(std::nullopt); + } + } + } + + /** + * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note This function won't try to run type conversion (use try_cast for that purpose). + */ template ::convert_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional as() const { + TVM_FFI_INLINE std::optional as() const& { if constexpr (std::is_same_v) { return *this; } else { - return TypeTraits::TryConvertFromAnyView(&data_); + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::CopyFromAnyViewAfterCheck(&data_); + } else { + return std::optional(std::nullopt); + } } } @@ -286,13 +346,18 @@ class Any { * \return The requested pointer, returns nullptr if type mismatches. */ template >> - TVM_FFI_INLINE const T* as() const { + TVM_FFI_INLINE const T* as() const& { return this->as().value_or(nullptr); } + /** + * \brief Cast to a type T, throw an exception if the cast is not possible. + * + * \tparam T The type to cast to. + */ template ::convert_enabled>> TVM_FFI_INLINE T cast() const& { - std::optional opt = TypeTraits::TryConvertFromAnyView(&data_); + std::optional opt = TypeTraits::TryCastFromAnyView(&data_); if (!opt.has_value()) { TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" @@ -301,13 +366,18 @@ class Any { return *std::move(opt); } + /** + * \brief Cast to a type T, throw an exception if the cast is not possible. + * + * \tparam T The type to cast to. + */ template ::storage_enabled>> TVM_FFI_INLINE T cast() && { - if (TypeTraits::CheckAnyStorage(&data_)) { - return TypeTraits::MoveFromAnyStorageAfterCheck(&data_); + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::MoveFromAnyAfterCheck(&data_); } // slow path, try to do fallback convert - std::optional opt = TypeTraits::TryConvertFromAnyView(&data_); + std::optional opt = TypeTraits::TryCastFromAnyView(&data_); if (!opt.has_value()) { TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" @@ -316,6 +386,22 @@ class Any { return *std::move(opt); } + /** + * \brief Try to cast to a type T. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note use STL name since it to be more consistent with cast API. + */ + template ::convert_enabled || std::is_same_v>> + TVM_FFI_INLINE std::optional try_cast() const { + if constexpr (std::is_same_v) { + return *this; + } else { + return TypeTraits::TryCastFromAnyView(&data_); + } + } /* * \brief Check if the two Any are same type and value in shallow comparison. * \param other The other Any @@ -412,19 +498,28 @@ struct AnyUnsafe : public ObjectUnsafe { } template - static TVM_FFI_INLINE bool CheckAnyStorage(const Any& ref) { - return TypeTraits::CheckAnyStorage(&(ref.data_)); + static TVM_FFI_INLINE bool CheckAnyStrict(const Any& ref) { + return TypeTraits::CheckAnyStrict(&(ref.data_)); } template - static TVM_FFI_INLINE T CopyFromAnyStorageAfterCheck(const Any& ref) { + static TVM_FFI_INLINE T CopyFromAnyViewAfterCheck(const Any& ref) { if constexpr (!std::is_same_v) { - return TypeTraits::CopyFromAnyStorageAfterCheck(&(ref.data_)); + return TypeTraits::CopyFromAnyViewAfterCheck(&(ref.data_)); } else { return ref; } } + template + static TVM_FFI_INLINE T MoveFromAnyAfterCheck(Any&& ref) { + if constexpr (!std::is_same_v) { + return TypeTraits::MoveFromAnyAfterCheck(&(ref.data_)); + } else { + return std::move(ref); + } + } + static TVM_FFI_INLINE Object* ObjectPtrFromAnyAfterCheck(const Any& ref) { return reinterpret_cast(ref.data_.v_obj); } @@ -452,7 +547,7 @@ struct AnyHash { if (src.data_.type_index == TypeIndex::kTVMFFIStr || src.data_.type_index == TypeIndex::kTVMFFIBytes) { const BytesObjBase* src_str = - details::AnyUnsafe::CopyFromAnyStorageAfterCheck(src); + details::AnyUnsafe::CopyFromAnyViewAfterCheck(src); return details::StableHashBytes(src_str->data, src_str->size); } else { return src.data_.v_uint64; @@ -478,9 +573,9 @@ struct AnyEqual { if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { const BytesObjBase* lhs_str = - details::AnyUnsafe::CopyFromAnyStorageAfterCheck(lhs); + details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); const BytesObjBase* rhs_str = - details::AnyUnsafe::CopyFromAnyStorageAfterCheck(rhs); + details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0; } return false; @@ -488,5 +583,11 @@ struct AnyEqual { }; } // namespace ffi + +// Expose to the tvm namespace for usability +// Rationale: no ambiguity even in root +using tvm::ffi::Any; +using tvm::ffi::AnyView; + } // namespace tvm #endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h index 18cc3ecb726f..eeb892eff65e 100644 --- a/ffi/include/tvm/ffi/base_details.h +++ b/ffi/include/tvm/ffi/base_details.h @@ -123,9 +123,9 @@ * This macro is used to clear the padding parts for hash and equality check * in 32bit platform. */ -#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ - if constexpr (sizeof(result->v_obj) != sizeof(result->v_int64)) { \ - result->v_int64 = 0; \ +#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ + if constexpr (sizeof((result)->v_obj) != sizeof((result)->v_int64)) { \ + (result)->v_int64 = 0; \ } namespace tvm { diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 1d495d9c5e96..996eaa369b3b 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -27,9 +27,21 @@ #include #include +// Macros to do weak linking +#ifdef _MSC_VER +#define TVM_FFI_WEAK __declspec(selectany) +#else +#define TVM_FFI_WEAK __attribute__((weak)) +#endif + +// Defines two macros +// TVM_FFI_DLL: marks the function as a DLL export/import +// depending on whether TVM_FFI_EXPORTS is defined +// TVM_FFI_DLL_EXPORT: always marks the function as a DLL export #if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) #include #define TVM_FFI_DLL EMSCRIPTEN_KEEPALIVE +#define TVM_FFI_DLL_EXPORT EMSCRIPTEN_KEEPALIVE #endif #if !defined(TVM_FFI_DLL) && defined(_MSC_VER) #ifdef TVM_FFI_EXPORTS @@ -37,9 +49,11 @@ #else #define TVM_FFI_DLL __declspec(dllimport) #endif +#define TVM_FFI_DLL_EXPORT __declspec(dllexport) #endif #ifndef TVM_FFI_DLL #define TVM_FFI_DLL __attribute__((visibility("default"))) +#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default"))) #endif #ifdef __cplusplus @@ -579,8 +593,10 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* * \return 0 when success, nonzero when failure happens * \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree. The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. + + * \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. */ -TVM_FFI_DLL int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out); +TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out); //------------------------------------------------------------ // Section: Backend noexcept functions for internal use diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h index 99069fb13b3c..5995abbf1516 100644 --- a/ffi/include/tvm/ffi/cast.h +++ b/ffi/include/tvm/ffi/cast.h @@ -156,5 +156,10 @@ inline OptionalType Downcast(const std::optional& ref) { } } // namespace ffi + +// Expose to the tvm namespace +// Rationale: convinience and no ambiguity +using ffi::Downcast; +using ffi::GetRef; } // namespace tvm #endif // TVM_FFI_CAST_H_ diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h index 5922eacb10f5..30402d9ae68c 100644 --- a/ffi/include/tvm/ffi/container/array.h +++ b/ffi/include/tvm/ffi/container/array.h @@ -386,9 +386,7 @@ class Array : public ObjectRef { // iterators struct ValueConverter { using ResultType = T; - static T convert(const Any& n) { - return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(n); - } + static T convert(const Any& n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck(n); } }; using iterator = details::IterAdapter; @@ -427,7 +425,7 @@ class Array : public ObjectRef { if (i < 0 || i >= p->size_) { TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; } - return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*(p->begin() + i)); + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin() + i)); } /*! \return The size of the array */ @@ -451,7 +449,7 @@ class Array : public ObjectRef { if (p == nullptr || p->size_ == 0) { TVM_FFI_THROW(IndexError) << "cannot index a empty array"; } - return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*(p->begin())); + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin())); } /*! \return The last element of the array */ @@ -460,7 +458,7 @@ class Array : public ObjectRef { if (p == nullptr || p->size_ == 0) { TVM_FFI_THROW(IndexError) << "cannot index a empty array"; } - return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*(p->end() - 1)); + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->end() - 1)); } public: @@ -835,7 +833,7 @@ class Array : public ObjectRef { // no other shared copies of the array. auto arr = static_cast(data.get()); for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { - T value = details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*it); + T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it); // reset the original value to nullptr, to ensure unique ownership it->reset(); T mapped = fmap(std::move(value)); @@ -860,7 +858,7 @@ class Array : public ObjectRef { // `T`. bool all_identical = true; for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*it)); + U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); if (!(*it).same_as(mapped)) { // At least one mapped element is different than the // original. Therefore, prepare the output array, @@ -914,7 +912,7 @@ class Array : public ObjectRef { // so we can either start or resume the iteration from that point, // with no further checks on the result. for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*it)); + U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); output->SetItem(it - arr->begin(), std::move(mapped)); } @@ -952,7 +950,7 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public ObjectRefTypeTraitsBase> { static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray; - using ObjectRefTypeTraitsBase>::CopyFromAnyStorageAfterCheck; + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { if (src->type_index != TypeIndex::kTVMFFIArray) { @@ -962,10 +960,10 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { const ArrayObj* n = reinterpret_cast(src->v_obj); for (size_t i = 0; i < n->size(); i++) { const Any& any_v = (*n)[i]; - // CheckAnyStorage is cheaper than as - if (details::AnyUnsafe::CheckAnyStorage(any_v)) continue; + // CheckAnyStrict is cheaper than try_cast + if (details::AnyUnsafe::CheckAnyStrict(any_v)) continue; // try see if p is convertible to T - if (any_v.as()) continue; + if (any_v.try_cast()) continue; // now report the accurate mismatch information return "Array[index " + std::to_string(i) + ": " + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; @@ -975,7 +973,7 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { TVM_FFI_UNREACHABLE(); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { if (src->type_index != TypeIndex::kTVMFFIArray) return false; if constexpr (std::is_same_v) { return true; @@ -983,13 +981,13 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { const ArrayObj* n = reinterpret_cast(src->v_obj); for (size_t i = 0; i < n->size(); i++) { const Any& any_v = (*n)[i]; - if (!details::AnyUnsafe::CheckAnyStorage(any_v)) return false; + if (!details::AnyUnsafe::CheckAnyStrict(any_v)) return false; } return true; } } - static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional> TryCastFromAnyView(const TVMFFIAny* src) { // try to run conversion. if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; if constexpr (!std::is_same_v) { @@ -997,20 +995,20 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { bool storage_check = [&]() { for (size_t i = 0; i < n->size(); i++) { const Any& any_v = (*n)[i]; - if (!details::AnyUnsafe::CheckAnyStorage(any_v)) return false; + if (!details::AnyUnsafe::CheckAnyStrict(any_v)) return false; } return true; }(); // fast path, if storage check passes, we can return the array directly. if (storage_check) { - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } // slow path, try to run a conversion to Array Array result; result.reserve(n->size()); for (size_t i = 0; i < n->size(); i++) { const Any& any_v = (*n)[i]; - if (auto opt_v = any_v.as()) { + if (auto opt_v = any_v.try_cast()) { result.push_back(*std::move(opt_v)); } else { return std::nullopt; @@ -1018,7 +1016,7 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { } return result; } else { - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } } @@ -1031,5 +1029,9 @@ inline constexpr bool type_contains_v, Array> = type_contains_v typename std::enable_if::value, typename T::difference_type>::type inline @@ -284,6 +294,14 @@ inline constexpr bool storage_enabled_v = std::is_same_v || TypeTraits inline constexpr bool all_storage_enabled_v = (storage_enabled_v && ...); +/*! + * \brief Check if all T are compatible with Any. + * + * \tparam T The type to check. + * \return True if T is compatible with Any, false otherwise. + */ +template +inline constexpr bool all_object_ref_v = (std::is_base_of_v && ...); /** * \brief Check if Any storage of Derived can always be directly used as Base. * diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h index 7d805d61e9ee..a738c1e229f6 100644 --- a/ffi/include/tvm/ffi/container/map.h +++ b/ffi/include/tvm/ffi/container/map.h @@ -1354,7 +1354,7 @@ class Map : public ObjectRef { * \return the corresonding element. */ const V at(const K& key) const { - return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(GetMapObj()->at(key)); + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(GetMapObj()->at(key)); } /*! * \brief Read element from map. @@ -1396,13 +1396,13 @@ class Map : public ObjectRef { iterator end() const { return iterator(GetMapObj()->end()); } /*! \return find the key and returns the associated iterator */ iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); } - /*! \return The value associated with the key, NullOpt if not found */ + /*! \return The value associated with the key, std::nullopt if not found */ std::optional Get(const K& key) const { MapObj::iterator iter = GetMapObj()->find(key); if (iter == GetMapObj()->end()) { return std::nullopt; } - return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(iter->second); + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(iter->second); } void erase(const K& key) { CopyOnWrite()->erase(key); } @@ -1445,8 +1445,8 @@ class Map : public ObjectRef { /*! \brief De-reference iterators */ reference operator*() const { auto& kv = *itr; - return std::make_pair(details::AnyUnsafe::CopyFromAnyStorageAfterCheck(kv.first), - details::AnyUnsafe::CopyFromAnyStorageAfterCheck(kv.second)); + return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.first), + details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.second)); } /*! \brief Prefix self increment, e.g. ++iter */ iterator& operator++() { @@ -1513,7 +1513,7 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public ObjectRefTypeTraitsBase> { static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; - using ObjectRefTypeTraitsBase>::CopyFromAnyStorageAfterCheck; + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { if (src->type_index != TypeIndex::kTVMFFIMap) { @@ -1523,14 +1523,15 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { const MapObj* n = reinterpret_cast(src->v_obj); for (const auto& kv : *n) { if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStorage(kv.first) && !kv.first.as().has_value()) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.first) && + !kv.first.try_cast().has_value()) { return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + ", V]"; } } if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStorage(kv.second) && - !kv.second.as().has_value()) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.second) && + !kv.second.try_cast().has_value()) { return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + "]"; } @@ -1541,7 +1542,7 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { TVM_FFI_UNREACHABLE(); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { if (src->type_index != TypeIndex::kTVMFFIMap) return false; if constexpr (std::is_same_v && std::is_same_v) { return true; @@ -1549,44 +1550,44 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase> { const MapObj* n = reinterpret_cast(src->v_obj); for (const auto& kv : *n) { if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStorage(kv.first)) return false; + if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; } if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStorage(kv.second)) return false; + if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; } } return true; } } - static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional> TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index != TypeIndex::kTVMFFIMap) return std::nullopt; if constexpr (!std::is_same_v || !std::is_same_v) { const MapObj* n = reinterpret_cast(src->v_obj); bool storage_check = [&]() { for (const auto& kv : *n) { if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStorage(kv.first)) return false; + if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; } if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStorage(kv.second)) return false; + if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; } } return true; }(); // fast path, if storage check passes, we can return the array directly. - if (storage_check) return CopyFromAnyStorageAfterCheck(src); + if (storage_check) return CopyFromAnyViewAfterCheck(src); // slow path, we need to create a new map and convert to the target type. Map ret; for (const auto& kv : *n) { - auto k = kv.first.as(); - auto v = kv.second.as(); + auto k = kv.first.try_cast(); + auto v = kv.second.try_cast(); if (!k.has_value() || !v.has_value()) return std::nullopt; ret.Set(*std::move(k), *std::move(v)); } return ret; } else { - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } } @@ -1602,5 +1603,9 @@ inline constexpr bool type_contains_v, Map> = } // namespace details } // namespace ffi + +// Expose to the tvm namespace +// Rationale: convinience and no ambiguity +using ffi::Map; } // namespace tvm #endif // TVM_FFI_CONTAINER_MAP_H_ diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 63c36467fe3d..10303f0aecab 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -55,10 +55,13 @@ class Tuple : public ObjectRef { template && ...), int>> Tuple(Tuple&& other) : ObjectRef(std::move(other)) {} - template - explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) { - static_assert(sizeof...(Types) == sizeof...(UTypes), "Tuple size mismatch"); - } + + template , Tuple> && + ...))>> + explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { data_ = other.data_; @@ -98,7 +101,7 @@ class Tuple : public ObjectRef { static_assert(I < sizeof...(Types), "Tuple index out of bounds"); using ReturnType = std::tuple_element_t>; const Any* ptr = GetArrayObj()->begin() + I; - return details::AnyUnsafe::CopyFromAnyStorageAfterCheck(*ptr); + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*ptr); } /*! @@ -179,7 +182,7 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public ObjectRefTypeTraitsBase> { - using ObjectRefTypeTraitsBase>::CopyFromAnyStorageAfterCheck; + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { if (src->type_index != TypeIndex::kTVMFFIArray) { @@ -196,7 +199,7 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase) { const Any& any_v = arr[I]; - if (!details::AnyUnsafe::CheckAnyStorage(any_v) && !(any_v.as().has_value())) { + if (!details::AnyUnsafe::CheckAnyStrict(any_v) && !(any_v.try_cast().has_value())) { // now report the accurate mismatch information return "Array[index " + std::to_string(I) + ": " + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; @@ -209,39 +212,38 @@ struct TypeTraits> : public ObjectRefTypeTraitsBasetype_index != TypeIndex::kTVMFFIArray) return false; const ArrayObj* n = reinterpret_cast(src->v_obj); if (n->size() != sizeof...(Types)) return false; const TVMFFIAny* ffi_any_arr = reinterpret_cast(n->begin()); - return CheckAnyStorageHelper<0, Types...>(ffi_any_arr); + return CheckAnyStrictHelper<0, Types...>(ffi_any_arr); } template - static TVM_FFI_INLINE bool CheckAnyStorageHelper(const TVMFFIAny* src_arr) { + static TVM_FFI_INLINE bool CheckAnyStrictHelper(const TVMFFIAny* src_arr) { if constexpr (!std::is_same_v) { - if (!TypeTraits::CheckAnyStorage(src_arr + I)) { + if (!TypeTraits::CheckAnyStrict(src_arr + I)) { return false; } } if constexpr (sizeof...(Rest) > 0) { - return CheckAnyStorageHelper(src_arr); + return CheckAnyStrictHelper(src_arr); } return true; } - static TVM_FFI_INLINE std::optional> TryConvertFromAnyView( - const TVMFFIAny* src // + static TVM_FFI_INLINE std::optional> TryCastFromAnyView(const TVMFFIAny* src // ) { if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; const ArrayObj* n = reinterpret_cast(src->v_obj); if (n->size() != sizeof...(Types)) return std::nullopt; // fast path, storage is already in the right type - if (CheckAnyStorage(src)) { - return CopyFromAnyStorageAfterCheck(src); + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); } // slow path, try to convert to each type to match the tuple storage need. - Array arr = TypeTraits>::CopyFromAnyStorageAfterCheck(src); + Array arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); Any* ptr = arr.CopyOnWrite()->MutableBegin(); if (TryConvertElements<0, Types...>(ptr)) { return Tuple(details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); @@ -252,7 +254,7 @@ struct TypeTraits> : public ObjectRefTypeTraitsBase static TVM_FFI_INLINE bool TryConvertElements(Any* arr) { if constexpr (!std::is_same_v) { - if (auto opt_convert = arr[I].as()) { + if (auto opt_convert = arr[I].try_cast()) { arr[I] = *std::move(opt_convert); } else { return false; @@ -275,5 +277,9 @@ inline constexpr bool type_contains_v, Tuple> = (type_contains } // namespace details } // namespace ffi + +// Expose to the tvm namespace +// Rationale: convinience and no ambiguity +using ffi::Tuple; } // namespace tvm #endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index c8b58ba49e39..373d0aaa70b8 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -34,15 +34,73 @@ namespace tvm { namespace ffi { +namespace details { +/*! + * \brief Base class for Variant. + * + * \tparam all_storage_object Whether all types are derived from ObjectRef. + */ +template +class VariantBase { + public: + TVM_FFI_INLINE bool same_as(const VariantBase& other) const { + return data_.same_as(other.data_); + } + + protected: + template + explicit VariantBase(T other) : data_(std::move(other)) {} + + TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } + + TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } + + TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } + + Any data_; +}; + +// Specialization for all object ref case, backed by ObjectRef. +template <> +class VariantBase : public ObjectRef { + protected: + template + explicit VariantBase(const T& other) : ObjectRef(other) {} + template + explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {} + explicit VariantBase(ObjectPtr ptr) : ObjectRef(ptr) {} + explicit VariantBase(Any other) + : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} + + TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } + + TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } + + TVM_FFI_INLINE AnyView ToAnyView() const { + TVMFFIAny any_data; + if (data_ == nullptr) { + any_data.type_index = TypeIndex::kTVMFFINone; + any_data.v_int64 = 0; + } else { + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); + any_data.type_index = data_->type_index(); + any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); + } + return AnyView::CopyFromTVMFFIAny(any_data); + } +}; +} // namespace details /*! * \brief A typed variant container. * - * A Variant is backed by Any container, with strong checks during construction. + * When all values are ObjectRef, Variant is backed by ObjectRef, + * otherwise it is backed by Any. */ template -class Variant { +class Variant : public details::VariantBase> { public: + using TParent = details::VariantBase>; static_assert(details::all_storage_enabled_v, "All types used in Variant<...> must be compatible with Any"); /* @@ -54,31 +112,30 @@ class Variant { template using enable_if_variant_contains_t = std::enable_if_t>; - Variant(const Variant& other) : data_(other.data_) {} - Variant(Variant&& other) : data_(std::move(other.data_)) {} + Variant(const Variant& other) : TParent(other.data_) {} + Variant(Variant&& other) : TParent(std::move(other.data_)) {} TVM_FFI_INLINE Variant& operator=(const Variant& other) { - data_ = other.data_; + this->SetData(other.data_); return *this; } TVM_FFI_INLINE Variant& operator=(Variant&& other) { - data_ = std::move(other.data_); + this->SetData(std::move(other.data_)); return *this; } template > - Variant(T other) : data_(std::move(other)) {} // NOLINT(*) + Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) template > TVM_FFI_INLINE Variant& operator=(T other) { - data_ = std::move(other); - return *this; + return operator=(Variant(std::move(other))); } template > TVM_FFI_INLINE std::optional as() const { - return data_.as(); + return this->TParent::ToAnyView().template as(); } /* @@ -89,29 +146,27 @@ class Variant { */ template >> TVM_FFI_INLINE const T* as() const { - return data_.as().value_or(nullptr); + return this->TParent::ToAnyView().template as().value_or(nullptr); } template > TVM_FFI_INLINE T get() const& { - return data_.template cast(); + return this->TParent::ToAnyView().template cast(); } template > TVM_FFI_INLINE T get() && { - return std::move(data_).template cast(); + return std::move(*this).TParent::MoveToAny().template cast(); } - TVM_FFI_INLINE std::string GetTypeKey() const { return data_.GetTypeKey(); } + TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } private: friend struct TypeTraits>; friend struct ObjectPtrHash; friend struct ObjectPtrEqual; // constructor from any - explicit Variant(Any data) : data_(std::move(data)) {} - // internal data is backed by Any - Any data_; + explicit Variant(Any data) : TParent(std::move(data)) {} /*! * \brief Get the object pointer from the variant * \note This function is only available if all types used in Variant<...> are derived from @@ -122,8 +177,11 @@ class Variant { static_assert(all_object_v, "All types used in Variant<...> must be derived from ObjectRef " "to enable ObjectPtrHash/ObjectPtrEqual"); - return details::AnyUnsafe::ObjectPtrFromAnyAfterCheck(data_); + return this->data_.get(); } + // rexpose to friend class + using TParent::MoveToAny; + using TParent::ToAnyView; }; template @@ -132,33 +190,33 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public TypeTraitsBase { static TVM_FFI_INLINE void CopyToAnyView(const Variant& src, TVMFFIAny* result) { - *result = AnyView(src.data_).CopyToTVMFFIAny(); + *result = src.ToAnyView().CopyToTVMFFIAny(); } static TVM_FFI_INLINE void MoveToAny(Variant src, TVMFFIAny* result) { - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src.data_)); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); } static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { return TypeTraitsBase::GetMismatchTypeInfo(src); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { - return (TypeTraits::CheckAnyStorage(src) || ...); + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { + return (TypeTraits::CheckAnyStrict(src) || ...); } - static TVM_FFI_INLINE Variant CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE Variant CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); } - static TVM_FFI_INLINE Variant MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE Variant MoveFromAnyAfterCheck(TVMFFIAny* src) { return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(*src))); } - static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional> TryCastFromAnyView(const TVMFFIAny* src) { // fast path, storage is already in the right type - if (CheckAnyStorage(src)) { - return CopyFromAnyStorageAfterCheck(src); + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); } // More expensive path, try to convert to each type, in order of declaration return TryVariantTypes(src); @@ -166,7 +224,7 @@ struct TypeTraits> : public TypeTraitsBase { template static TVM_FFI_INLINE std::optional> TryVariantTypes(const TVMFFIAny* src) { - if (auto opt_convert = TypeTraits::TryConvertFromAnyView(src)) { + if (auto opt_convert = TypeTraits::TryCastFromAnyView(src)) { return Variant(*std::move(opt_convert)); } if constexpr (sizeof...(Rest) > 0) { @@ -194,5 +252,9 @@ template inline constexpr bool type_contains_v, T> = (type_contains_v || ...); } // namespace details } // namespace ffi + +// Expose to the tvm namespace +// Rationale: convinience and no ambiguity +using ffi::Variant; } // namespace tvm #endif // TVM_FFI_CONTAINER_VARIANT_H_ diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h index 99eb227ee1af..954c77fdfa17 100644 --- a/ffi/include/tvm/ffi/dtype.h +++ b/ffi/include/tvm/ffi/dtype.h @@ -121,7 +121,7 @@ inline DLDataType StringToDLDataType(const String& str) { inline String DLDataTypeToString(DLDataType dtype) { TVMFFIObjectHandle out; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(dtype, &out)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out)); return String(details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(out))); } @@ -140,20 +140,20 @@ struct TypeTraits : public TypeTraitsBase { result->v_dtype = src; } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { return src->type_index == TypeIndex::kTVMFFIDataType; } - static TVM_FFI_INLINE DLDataType CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return src->v_dtype; } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIDataType) { return src->v_dtype; } // enable string to dtype auto conversion - if (auto opt_str = TypeTraits::TryConvertFromAnyView(src)) { + if (auto opt_str = TypeTraits::TryCastFromAnyView(src)) { return StringToDLDataType(*opt_str); } return std::nullopt; diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h index 239a0e500b73..de754bd6ea77 100644 --- a/ffi/include/tvm/ffi/error.h +++ b/ffi/include/tvm/ffi/error.h @@ -51,6 +51,10 @@ #define TVM_FFI_BACKTRACE_ON_SEGFAULT 1 #endif +#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0 +#endif + namespace tvm { namespace ffi { @@ -212,8 +216,10 @@ class ErrorBuilder { * * \endcode */ -#define TVM_FFI_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, false).stream() +#define TVM_FFI_THROW(ErrorKind) \ + ::tvm::ffi::details::ErrorBuilder(#ErrorKind, TVM_FFI_TRACEBACK_HERE, \ + TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \ + .stream() /*! * \brief Explicitly log error in stderr and then throw the error. diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index bdcca7b73ead..128c67830e84 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -735,17 +735,17 @@ struct TypeTraits> : public TypeTraitsBase { TypeTraits::MoveToAny(std::move(src.packed()), result); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { return src->type_index == TypeIndex::kTVMFFIFunction; } - static TVM_FFI_INLINE TypedFunction CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { - return TypedFunction(TypeTraits::CopyFromAnyStorageAfterCheck(src)); + static TVM_FFI_INLINE TypedFunction CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return TypedFunction(TypeTraits::CopyFromAnyViewAfterCheck(src)); } - static TVM_FFI_INLINE std::optional> TryConvertFromAnyView( + static TVM_FFI_INLINE std::optional> TryCastFromAnyView( const TVMFFIAny* src) { - std::optional opt = TypeTraits::TryConvertFromAnyView(src); + std::optional opt = TypeTraits::TryCastFromAnyView(src); if (opt.has_value()) { return TypedFunction(*std::move(opt)); } else { @@ -787,7 +787,7 @@ class Function::Registry { * .set_body_typed(multiply); // will have type int(int, int) * * // will have type int(int, int) - * TVM_REGISTER_GLOBAL("sub") + * TVM_FFI_REGISTER_GLOBAL("sub") * .set_body_typed([](int a, int b) -> int { return a - b; }); * * \endcode @@ -909,6 +909,55 @@ inline int32_t TypeKeyToIndex(std::string_view type_key) { */ #define TVM_FFI_REGISTER_GLOBAL(OpName) \ TVM_FFI_STR_CONCAT(TVM_FFI_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::ffi::Function::Registry(OpName) + +/*! + * \brief Export typed function as a SafeCallType symbol. + * + * \param ExportName The symbol name to be exported. + * \param Function The typed function. + * \note ExportName and Function must be different, + * see code examples below. + * + * \sa ffi::TypedFunction + * + * \code + * + * int AddOne_(int x) { + * return x + 1; + * } + * + * // Expose the function as "AddOne" + * TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); + * + * // Expose the function as "SubOne" + * TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { + * return x - 1; + * }); + * + * // The following code will cause compilation error. + * // Because the same Function and ExportName + * // TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_); + * + * // The following code is OK, assuming the macro + * // is in a different namespace from xyz + * // TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_); + * + * \endcode + */ +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_FFI_DLL_EXPORT int ExportName(void* self, TVMFFIAny* args, int32_t num_args, \ + TVMFFIAny* result) { \ + TVM_FFI_SAFE_CALL_BEGIN(); \ + using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ + static std::string name = #ExportName; \ + ::tvm::ffi::details::unpack_call( \ + std::make_index_sequence{}, &name, Function, \ + reinterpret_cast(args), num_args, \ + reinterpret_cast<::tvm::ffi::Any*>(result)); \ + TVM_FFI_SAFE_CALL_END(); \ + } \ + } } // namespace ffi } // namespace tvm #endif // TVM_FFI_FUNCTION_H_ diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index 3e7f9be140c7..6391c4ebba4a 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -137,7 +137,7 @@ class ArgValueWithContext { } else if constexpr (std::is_same_v) { return Any(args_[arg_index_]); } else { - std::optional opt = args_[arg_index_].as(); + std::optional opt = args_[arg_index_].try_cast(); if (!opt.has_value()) { TVMFFIAny any_data = args_[arg_index_].CopyToTVMFFIAny(); TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" << arg_index_ diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index d3af2dd49077..eb317d2bbd72 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -205,5 +205,9 @@ inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&.. } } // namespace ffi + +// Export the make_object function +// rationale: ease of use, and no ambiguity +using ffi::make_object; } // namespace tvm #endif // TVM_FFI_MEMORY_H_ diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h index 7b3f69ef9919..9c73d6a5705d 100644 --- a/ffi/include/tvm/ffi/optional.h +++ b/ffi/include/tvm/ffi/optional.h @@ -295,5 +295,8 @@ class Optional>> : public Object } }; } // namespace ffi + +// Expose to the tvm namespace +using ffi::Optional; } // namespace tvm #endif // TVM_FFI_OPTIONAL_H_ diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h index b300436feee6..3b939d2df515 100644 --- a/ffi/include/tvm/ffi/rvalue_ref.h +++ b/ffi/include/tvm/ffi/rvalue_ref.h @@ -114,8 +114,7 @@ struct TypeTraits> : public TypeTraitsBase { } } - static TVM_FFI_INLINE std::optional> TryConvertFromAnyView( - const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional> TryCastFromAnyView(const TVMFFIAny* src) { // first try rvalue conversion if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); @@ -123,10 +122,10 @@ struct TypeTraits> : public TypeTraitsBase { tmp_any.type_index = rvalue_ref->get()->type_index(); tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); // fast path, storage type matches, direct move the rvalue ref - if (TypeTraits::CheckAnyStorage(&tmp_any)) { + if (TypeTraits::CheckAnyStrict(&tmp_any)) { return RValueRef(TObjRef(std::move(*rvalue_ref))); } - if (std::optional opt = TypeTraits::TryConvertFromAnyView(&tmp_any)) { + if (std::optional opt = TypeTraits::TryCastFromAnyView(&tmp_any)) { // object type does not match up, we need to try to convert the object // in this case we do not move the original rvalue ref since conversion creates a copy return RValueRef(*std::move(opt)); @@ -134,7 +133,7 @@ struct TypeTraits> : public TypeTraitsBase { return std::nullopt; } // try lvalue conversion - if (std::optional opt = TypeTraits::TryConvertFromAnyView(src)) { + if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { return RValueRef(*std::move(opt)); } else { return std::nullopt; diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index 1c22f10892ac..c3eceff90590 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -210,7 +210,7 @@ class Bytes : public ObjectRef { */ class String : public ObjectRef { public: - String(nullptr_t) = delete; // NOLINT(*) + String(std::nullptr_t) = delete; // NOLINT(*) /*! * \brief constructor from char [N] @@ -430,8 +430,8 @@ struct TypeTraits : public TypeTraitsBase { // when we need to move to any, convert to owned object first ObjectRefTypeTraitsBase::MoveToAny(String(src), result); } - // Do not allow const char* in a container, so we do not need CheckAnyStorage - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + // Do not allow const char* in a container, so we do not need CheckAnyStrict + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIRawStr) { return static_cast(src->v_c_str); } @@ -458,8 +458,7 @@ struct TypeTraits : public TypeTraitsBase { ObjectRefTypeTraitsBase::MoveToAny(Bytes(*src), result); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView( - const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { return static_cast(src->v_ptr); } @@ -641,6 +640,11 @@ inline int Bytes::memncmp(const char* lhs, const char* rhs, size_t lhs_count, si } } } // namespace ffi + +// Expose to the tvm namespace for usability +// Rationale: no ambiguity even in root +using ffi::Bytes; +using ffi::String; } // namespace tvm namespace std { diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index d350aea82ac7..02c9a90edcfd 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -43,25 +43,25 @@ namespace ffi { * * - CopyToAnyView: Convert a value T to AnyView * - MoveToAny: Move a value to Any - * - CheckAnyStorage: Check if a Any stores a result of MoveToAny of current T. - * - CopyFromAnyStorageAfterCheck: Copy a value T from Any storage after we pass CheckAnyStorage. - * - MoveFromAnyStorageAfterCheck: Move a value T from Any storage after we pass CheckAnyStorage. - * - TryConvertFromAnyView: Convert a AnyView to a T, we may apply type conversion. - * - GetMismatchTypeInfo: Get the type key of a type when TryConvertFromAnyView fails. + * - CheckAnyStrict: Check if a Any stores a result of CopyToAnyView of current T. + * - CopyFromAnyViewAfterCheck: Copy a value T from Any view after we pass CheckAnyStrict. + * - MoveFromAnyAfterCheck: Move a value T from Any storage after we pass CheckAnyStrict. + * - TryCastFromAnyView: Convert a AnyView to a T, we may apply type conversion. + * - GetMismatchTypeInfo: Get the type key of a type when TryCastFromAnyView fails. * - TypeStr: Get the type key of a type * - * It is possible that CheckAnyStorage is false but TryConvertFromAnyView still works. + * It is possible that CheckAnyStrict is false but TryCastFromAnyView still works. * - * For example, when Any x stores int, TypeTraits::CheckAnyStorage(x) will be false, - * but TypeTraits::TryConvertFromAnyView(x) will return a corresponding float value + * For example, when Any x stores int, TypeTraits::CheckAnyStrict(x) will be false, + * but TypeTraits::TryCastFromAnyView(x) will return a corresponding float value * via type conversion. * - * CheckAnyStorage is mainly used in recursive container such as Array to + * CheckAnyStrict is mainly used in recursive container such as Array to * decide if a new Array needed to be created via recursive conversion, * or we can use the current container as is when converting to Array. * * A container array: Array satisfies the following invariant: - * - `all(TypeTraits::CheckAnyStorage(x) for x in the array)`. + * - `all(TypeTraits::CheckAnyStrict(x) for x in the array)`. */ template struct TypeTraits { @@ -85,7 +85,7 @@ struct TypeTraitsBase { static constexpr bool convert_enabled = true; static constexpr bool storage_enabled = true; // get mismatched type when result mismatches the trait. - // this function is called after TryConvertFromAnyView fails + // this function is called after TryCastFromAnyView fails // to get more detailed type information in runtime // especially when the error involves nested container type static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* source) { @@ -132,17 +132,17 @@ struct TypeTraits : public TypeTraitsBase { result->v_int64 = 0; } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { return src->type_index == TypeIndex::kTVMFFINone; } - static TVM_FFI_INLINE std::nullptr_t CopyFromAnyStorageAfterCheck(const TVMFFIAny*) { + static TVM_FFI_INLINE std::nullptr_t CopyFromAnyViewAfterCheck(const TVMFFIAny*) { return nullptr; } - static TVM_FFI_INLINE std::nullptr_t MoveFromAnyStorageAfterCheck(TVMFFIAny*) { return nullptr; } + static TVM_FFI_INLINE std::nullptr_t MoveFromAnyAfterCheck(TVMFFIAny*) { return nullptr; } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFINone) { return nullptr; } @@ -179,20 +179,20 @@ struct TypeTraits : public TypeTraitsBase { CopyToAnyView(src, result); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { return src->type_index == TypeIndex::kTVMFFIBool; } - static TVM_FFI_INLINE StrictBool CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return static_cast(src->v_int64); } - static TVM_FFI_INLINE StrictBool MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE StrictBool MoveFromAnyAfterCheck(TVMFFIAny* src) { // POD type, we can just copy the value - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIBool) { return StrictBool(static_cast(src->v_int64)); } @@ -214,20 +214,20 @@ struct TypeTraits : public TypeTraitsBase { static TVM_FFI_INLINE void MoveToAny(bool src, TVMFFIAny* result) { CopyToAnyView(src, result); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { return src->type_index == TypeIndex::kTVMFFIBool; } - static TVM_FFI_INLINE bool CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return static_cast(src->v_int64); } - static TVM_FFI_INLINE bool MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE bool MoveFromAnyAfterCheck(TVMFFIAny* src) { // POD type, we can just copy the value - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { return static_cast(src->v_int64); } @@ -249,21 +249,21 @@ struct TypeTraits>> : public TypeT static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) { CopyToAnyView(src, result); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { - // NOTE: CheckAnyStorage is always strict and should be consistent with MoveToAny + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny return src->type_index == TypeIndex::kTVMFFIInt; } - static TVM_FFI_INLINE Int CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE Int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return static_cast(src->v_int64); } - static TVM_FFI_INLINE Int MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE Int MoveFromAnyAfterCheck(TVMFFIAny* src) { // POD type, we can just copy the value - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { return Int(src->v_int64); } @@ -286,21 +286,21 @@ struct TypeTraits>> static TVM_FFI_INLINE void MoveToAny(Float src, TVMFFIAny* result) { CopyToAnyView(src, result); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { - // NOTE: CheckAnyStorage is always strict and should be consistent with MoveToAny + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny return src->type_index == TypeIndex::kTVMFFIFloat; } - static TVM_FFI_INLINE Float CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return static_cast(src->v_float64); } - static TVM_FFI_INLINE Float MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE Float MoveFromAnyAfterCheck(TVMFFIAny* src) { // POD type, we can just copy the value - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIFloat) { return Float(src->v_float64); } else if (src->type_index == TypeIndex::kTVMFFIInt || @@ -326,21 +326,19 @@ struct TypeTraits : public TypeTraitsBase { static TVM_FFI_INLINE void MoveToAny(void* src, TVMFFIAny* result) { CopyToAnyView(src, result); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { - // NOTE: CheckAnyStorage is always strict and should be consistent with MoveToAny + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny return src->type_index == TypeIndex::kTVMFFIOpaquePtr; } - static TVM_FFI_INLINE void* CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { - return src->v_ptr; - } + static TVM_FFI_INLINE void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return src->v_ptr; } - static TVM_FFI_INLINE void* MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE void* MoveFromAnyAfterCheck(TVMFFIAny* src) { // POD type, we can just copy the value - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { return static_cast(src->v_ptr); } @@ -368,20 +366,20 @@ struct TypeTraits : public TypeTraitsBase { result->v_device = src; } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { return src->type_index == TypeIndex::kTVMFFIDevice; } - static TVM_FFI_INLINE DLDevice CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return src->v_device; } - static TVM_FFI_INLINE DLDevice MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE DLDevice MoveFromAnyAfterCheck(TVMFFIAny* src) { // POD type, we can just copy the value - return CopyFromAnyStorageAfterCheck(src); + return CopyFromAnyViewAfterCheck(src); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIDevice) { return src->v_device; } @@ -404,12 +402,20 @@ struct TypeTraits : public TypeTraitsBase { result->v_ptr = src; } + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; + } + + static TVM_FFI_INLINE DLTensor* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return static_cast(src->v_ptr); + } + static TVM_FFI_INLINE void MoveToAny(DLTensor*, TVMFFIAny*) { TVM_FFI_THROW(RuntimeError) << "DLTensor* cannot be held in Any as it does not retain ownership, use NDArray instead"; } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { return static_cast(src->v_ptr); } else if (src->type_index == TypeIndex::kTVMFFINDArray) { @@ -458,7 +464,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { result->v_obj = obj_ptr; } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) return true; } @@ -466,7 +472,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { details::IsObjectInstance(src->type_index)); } - static TVM_FFI_INLINE TObjRef CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { return TObjRef(ObjectPtr(nullptr)); @@ -475,7 +481,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); } - static TVM_FFI_INLINE TObjRef MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE TObjRef MoveFromAnyAfterCheck(TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { return TObjRef(ObjectPtr(nullptr)); @@ -488,7 +494,7 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase { return TObjRef(std::move(obj_ptr)); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == TypeIndex::kTVMFFINone) { return TObjRef(ObjectPtr(nullptr)); @@ -525,7 +531,7 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase { // disable container for FallbackOnlyTraitsBase static constexpr bool storage_enabled = false; - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { return TryFallbackTypes(src); } @@ -534,7 +540,7 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase { static_assert(!std::is_same_v, "Using bool as FallbackType can cause bug because int will be detected as bool, " "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryConvertFromAnyView(src)) { + if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); } if constexpr (sizeof...(Rest) > 0) { @@ -557,11 +563,11 @@ struct FallbackOnlyTraitsBase : public TypeTraitsBase { */ template struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase { - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { - if (auto opt_obj = ObjectRefTypeTraitsBase::TryConvertFromAnyView(src)) { + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { + if (auto opt_obj = ObjectRefTypeTraitsBase::TryCastFromAnyView(src)) { return opt_obj.value(); } - // apply fallback types in TryConvertFromAnyView + // apply fallback types in TryCastFromAnyView return TryFallbackTypes(src); } @@ -570,7 +576,7 @@ struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase, "Using bool as FallbackType can cause bug because int will be detected as bool, " "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryConvertFromAnyView(src)) { + if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); } if constexpr (sizeof...(Rest) > 0) { @@ -601,17 +607,17 @@ struct TypeTraitsv_obj); } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && details::IsObjectInstance(src->type_index); } - static TVM_FFI_INLINE const TObject* CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return details::ObjectUnsafe::RawObjectPtrFromUnowned(src->v_obj); } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView(const TVMFFIAny* src) { - if (CheckAnyStorage(src)) return CopyFromAnyStorageAfterCheck(src); + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { + if (CheckAnyStrict(src)) return CopyFromAnyViewAfterCheck(src); return std::nullopt; } @@ -639,28 +645,28 @@ struct TypeTraits> : public TypeTraitsBase { } } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFINone) return true; - return TypeTraits::CheckAnyStorage(src); + return TypeTraits::CheckAnyStrict(src); } - static TVM_FFI_INLINE Optional CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE Optional CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFINone) { return Optional(std::nullopt); } - return TypeTraits::CopyFromAnyStorageAfterCheck(src); + return TypeTraits::CopyFromAnyViewAfterCheck(src); } - static TVM_FFI_INLINE Optional MoveFromAnyStorageAfterCheck(TVMFFIAny* src) { + static TVM_FFI_INLINE Optional MoveFromAnyAfterCheck(TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFINone) { return Optional(std::nullopt); } - return TypeTraits::MoveFromAnyStorageAfterCheck(src); + return TypeTraits::MoveFromAnyAfterCheck(src); } - static TVM_FFI_INLINE std::optional> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional> TryCastFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFINone) return Optional(std::nullopt); - if (std::optional opt = TypeTraits::TryConvertFromAnyView(src)) { + if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { return Optional(*std::move(opt)); } else { // important to be explicit here diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py new file mode 100644 index 000000000000..b19f566364e4 --- /dev/null +++ b/ffi/scripts/benchmark_dlpack.py @@ -0,0 +1,345 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This script is used to benchmark the API overhead of different +python FFI API calling overhead, through DLPack API. + +Specifically, we would like to understand the overall overhead +python/C++ API calls. The general goal is to understand the overall +space and get a sense of what are the possible operations. + +We pick function f(x, y, z) where x, y, z are length 1 tensors. +The benchmark is running in eager mode so we can see what is possible. +It is orthogonal to other optimizations. For example cudagraph can +eliminate these overheads completely. So the goal is to get a sense +of what is possible under eager mode. + +Summary of some takeaways: +- numpy.add roughly takes 0.36 us per call, which gives roughly what can + be done in python env. +- torch.add on gpu takes about 3.7us per call, giving us an idea of what + roughly we need to get to in eager mode. +- + +""" +import torch +import numpy as np +from tvm import ffi as tvm_ffi +import time + + +def print_speed(name, speed): + print(f"{name:<40} {speed} sec/call") + + +def print_error(name, error): + print(f"{name:<40} {error}") + + +def baseline_torch_add(repeat): + """Run torch.add with one element""" + + def run_bench(device): + x = torch.arange(1, device=device) + y = torch.arange(1, device=device) + z = torch.arange(1, device=device) + + torch.add(x, y, out=z) + if device == "cuda": + torch.cuda.synchronize() + start = time.time() + for i in range(repeat): + torch.add(x, y, out=z) + # note we deliberately do not use torch.cuda.synchronize() + # because we want to see the overhead of the FFI call. + end = time.time() + print_speed(f"torch.add[{device}]", (end - start) / repeat) + + # rough take away: add on cuda roughly takes 3e-6 sec/call + run_bench("cpu") + run_bench("cuda") + + +def baseline_numpy_add(repeat): + """Run numpy.add with one element""" + x = np.arange(1) + y = np.arange(1) + z = np.arange(1) + + np.add(x, y, out=z) + start = time.time() + for i in range(repeat): + np.add(x, y, out=z) + end = time.time() + speed = (end - start) / repeat + print_speed("numpy.add", speed) + + +def baseline_cupy_add(repeat): + """Run cupy.add with one element""" + try: + import cupy + except ImportError: + # skip if cupy is not installed + return + x = cupy.arange(1) + y = cupy.arange(1) + z = cupy.arange(1) + + cupy.add(x, y, out=z) + start = time.time() + for i in range(repeat): + cupy.add(x, y, out=z) + end = time.time() + speed = (end - start) / repeat + print_speed("cupy.add", speed) + + +def tvm_ffi_nop(repeat): + """Overhead of tvm FFI python call via calling a NOP. + + testing.nop is defined in c++ and do nothing. + """ + nop = tvm_ffi.get_global_func("testing.nop") + x = tvm_ffi.from_dlpack(torch.arange(1)) + y = tvm_ffi.from_dlpack(torch.arange(1)) + z = tvm_ffi.from_dlpack(torch.arange(1)) + nop(x, y, z) + start = time.time() + for i in range(repeat): + y = tvm_ffi.from_dlpack(x) + end = time.time() + print_speed("tvm.ffi.nop", (end - start) / repeat) + + +def bench_ffi_nop_from_dlpack(name, x, y, z, repeat): + """run dlpack conversion + tvm.ffi.nop + + Measures overhead of running dlpack for each args then invoke + """ + nop = tvm_ffi.get_global_func("testing.nop") + tx = tvm_ffi.from_dlpack(x) + ty = tvm_ffi.from_dlpack(y) + tz = tvm_ffi.from_dlpack(z) + nop(tx, ty, tz) + + start = time.time() + for i in range(repeat): + tx = tvm_ffi.from_dlpack(x) + ty = tvm_ffi.from_dlpack(y) + tz = tvm_ffi.from_dlpack(z) + nop(tx, ty, tz) + end = time.time() + print_speed(name, (end - start) / repeat) + + +def tvm_ffi_nop_from_torch_dlpack(repeat): + """run dlpack conversion + tvm.ffi.nop + + Measures overhead of running dlpack for each args then invoke + """ + x = torch.arange(1) + y = torch.arange(1) + z = torch.arange(1) + bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(torch)", x, y, z, repeat) + + +def tvm_ffi_nop_from_numpy_dlpack(repeat): + """run dlpack conversion + tvm.ffi.nop + + Measures overhead of running dlpack for each args then invoke + """ + x = np.arange(1) + y = np.arange(1) + z = np.arange(1) + bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(numpy)", x, y, z, repeat) + + +def tvm_ffi_self_dlpack_nop(repeat): + """run dlpack conversion + tvm.ffi.nop + + Measures overhead of running dlpack for each args then invoke + """ + x = tvm_ffi.from_dlpack(torch.arange(1)) + y = tvm_ffi.from_dlpack(torch.arange(1)) + z = tvm_ffi.from_dlpack(torch.arange(1)) + bench_ffi_nop_from_dlpack("tvm.ffi.nop+from_dlpack(tvm)", x, y, z, repeat) + + +def bench_ffi_nop_from_dlpack(name, x, y, z, repeat): + """run dlpack conversion + tvm.ffi.nop + + Measures overhead of running dlpack for each args then invoke + """ + nop = tvm_ffi.get_global_func("testing.nop") + tx = tvm_ffi.from_dlpack(x) + ty = tvm_ffi.from_dlpack(y) + tz = tvm_ffi.from_dlpack(z) + nop(tx, ty, tz) + + start = time.time() + for i in range(repeat): + tx = tvm_ffi.from_dlpack(x) + ty = tvm_ffi.from_dlpack(y) + tz = tvm_ffi.from_dlpack(z) + nop(tx, ty, tz) + end = time.time() + print_speed(name, (end - start) / repeat) + + +def tvm_ffi_nop_from_torch_utils_to_dlpack(repeat): + """ + Measures overhead of running dlpack for each args then invoke + but uses the legacy torch.utils.dlpack.to_dlpack API + + This helps to measure possible implementation overhead of torch. + """ + nop = tvm_ffi.get_global_func("testing.nop") + x = torch.arange(1) + y = torch.arange(1) + z = torch.arange(1) + + tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y)) + tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z)) + nop(tx, ty, tz) + + start = time.time() + for i in range(repeat): + tx = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + ty = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(y)) + tz = tvm_ffi.from_dlpack(torch.utils.dlpack.to_dlpack(z)) + nop(tx, ty, tz) + end = time.time() + speed = (end - start) / repeat + print_speed("tvm.ffi.nop+from_dlpack(torch.utils)", speed) + + +def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat): + """ + Measures overhead of running dlpack via auto convert by directly + take torch.Tensor as inputs. + """ + nop = tvm_ffi.get_global_func("testing.nop") + nop(x, y, z) + start = time.time() + for i in range(repeat): + nop(x, y, z) + end = time.time() + speed = (end - start) / repeat + print_speed(name, speed) + + +def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"): + """ + Measures overhead of running dlpack via auto convert by directly + take torch.Tensor as inputs. + """ + # use larger to ensure alignment req is met + x = torch.arange(1, device=device) + y = torch.arange(1, device=device) + z = torch.arange(1, device=device) + bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat) + + +def tvm_ffi_nop_autodlpack_from_numpy(repeat): + """ + Measures overhead of running dlpack via auto convert by directly + take numpy.ndarray as inputs. + """ + # use larger to ensure alignment req is met + x = np.arange(256) + y = np.arange(256) + z = np.arange(256) + bench_tvm_ffi_nop_autodlpack("tvm.ffi.nop.autodlpack(numpy)", x, y, z, repeat) + + +def bench_to_dlpack(x, name, repeat): + x.__dlpack__() + start = time.time() + for i in range(repeat): + x.__dlpack__() + end = time.time() + speed = (end - start) / repeat + print_speed(name, speed) + + +def bench_to_dlpack_versioned(x, name, repeat, max_version=(1, 1)): + """ + Measures overhead of running dlpack with latest 1.1. + """ + try: + x.__dlpack__(max_version=max_version) + start = time.time() + for i in range(repeat): + x.__dlpack__(max_version=max_version) + end = time.time() + speed = (end - start) / repeat + print_speed(name, speed) + except Exception as e: + print_error(name, e) + + +def bench_torch_utils_to_dlpack(repeat): + """ + Measures overhead of running torch.utils.dlpack.to_dlpack + """ + x = torch.arange(1) + torch.utils.dlpack.to_dlpack(x) + start = time.time() + for i in range(repeat): + torch.utils.dlpack.to_dlpack(x) + end = time.time() + speed = (end - start) / repeat + print_speed("torch.utils.dlpack.to_dlpack", speed) + + +def main(): + repeat = 10000 + print("-----------------------------") + print("Benchmark f(x, y, z) overhead") + print("-----------------------------") + baseline_numpy_add(repeat) + baseline_torch_add(repeat) + baseline_cupy_add(repeat) + tvm_ffi_nop(repeat) + tvm_ffi_nop_from_torch_dlpack(repeat) + tvm_ffi_nop_from_numpy_dlpack(repeat) + tvm_ffi_self_dlpack_nop(repeat) + tvm_ffi_nop_from_torch_utils_to_dlpack(repeat) + tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu") + tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda") + tvm_ffi_nop_autodlpack_from_numpy(repeat) + print("-------------------------------") + print("Benchmark x.__dlpack__ overhead") + print("-------------------------------") + bench_torch_utils_to_dlpack(repeat) + bench_to_dlpack(torch.arange(1), "torch.__dlpack__", repeat) + bench_to_dlpack(np.arange(1), "numpy.__dlpack__", repeat) + bench_to_dlpack(tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__", repeat) + print("---------------------------------------------------") + print("Benchmark x.__dlpack__(max_version=(1,1)) overhead") + print("---------------------------------------------------") + bench_to_dlpack_versioned(torch.arange(1), "torch.__dlpack__(max_version=(1,1))", repeat) + bench_to_dlpack_versioned(np.arange(1), "numpy.__dlpack__(max_version=(1,1))", repeat) + bench_to_dlpack_versioned( + tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat + ) + + +if __name__ == "__main__": + main() diff --git a/ffi/src/ffi/dtype.cc b/ffi/src/ffi/dtype.cc index 7661ab4b97b1..cb0bd4959735 100644 --- a/ffi/src/ffi/dtype.cc +++ b/ffi/src/ffi/dtype.cc @@ -320,9 +320,9 @@ int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out) { TVM_FFI_SAFE_CALL_END(); } -int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) { +int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) { TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(dtype)); + tvm::ffi::String out_str(tvm::ffi::DLDataTypeToString_(*dtype)); *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out_str)); TVM_FFI_SAFE_CALL_END(); } diff --git a/ffi/src/ffi/ndarray.cc b/ffi/src/ffi/ndarray.cc index d4c1470566bf..f3c48c8ad56f 100644 --- a/ffi/src/ffi/ndarray.cc +++ b/ffi/src/ffi/ndarray.cc @@ -32,7 +32,7 @@ TVM_FFI_REGISTER_GLOBAL("ffi.Shape").set_body_packed([](ffi::PackedArgs args, An int64_t* mutable_data; ObjectPtr shape = details::MakeEmptyShape(args.size(), &mutable_data); for (int i = 0; i < args.size(); ++i) { - if (auto opt_int = args[i].as()) { + if (auto opt_int = args[i].try_cast()) { mutable_data[i] = *opt_int; } else { TVM_FFI_THROW(ValueError) << "Expect shape to take list of int arguments"; diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc index 816ae28e0e9c..3ad81cd11803 100644 --- a/ffi/tests/cpp/test_any.cc +++ b/ffi/tests/cpp/test_any.cc @@ -332,11 +332,45 @@ TEST(Any, ObjectRefWithFallbackTraits) { EXPECT_EQ(v9->value, 0); } +TEST(Any, CastVsAs) { + AnyView view0 = 1; + // as only runs strict check + auto opt_v0 = view0.as(); + EXPECT_TRUE(opt_v0.has_value()); + EXPECT_EQ(opt_v0.value(), 1); + + auto opt_v1 = view0.as(); + EXPECT_TRUE(!opt_v1.has_value()); + auto opt_v2 = view0.as(); + EXPECT_TRUE(!opt_v2.has_value()); + + // try_cast will try run the conversion. + auto opt_v3 = view0.try_cast(); + EXPECT_TRUE(opt_v3.has_value()); + EXPECT_EQ(opt_v3.value(), 1); + auto opt_v4 = view0.try_cast(); + EXPECT_TRUE(opt_v4.has_value()); + EXPECT_EQ(opt_v4.value(), 1); + + Any any1 = true; + auto opt_v5 = any1.as(); + EXPECT_TRUE(opt_v5.has_value()); + EXPECT_EQ(opt_v5.value(), 1); + + auto opt_v6 = any1.try_cast(); + EXPECT_TRUE(opt_v6.has_value()); + EXPECT_EQ(opt_v6.value(), 1); + + auto opt_v7 = any1.try_cast(); + EXPECT_TRUE(opt_v7.has_value()); +} + TEST(Any, ObjectMove) { Any any1 = TPrimExpr("float32", 3.14); auto v0 = std::move(any1).cast(); EXPECT_EQ(v0->value, 3.14); EXPECT_EQ(v0.use_count(), 1); + EXPECT_TRUE(any1 == nullptr); } } // namespace diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc index e31df8761db0..620f729a6678 100644 --- a/ffi/tests/cpp/test_dtype.cc +++ b/ffi/tests/cpp/test_dtype.cc @@ -114,14 +114,14 @@ TEST(DataType, AnyConversion) { TEST(DataType, AnyConversionWithString) { AnyView view0 = "float32"; - Optional opt_v0 = view0.as(); + Optional opt_v0 = view0.try_cast(); DLDataType dtype_v0 = opt_v0.value(); EXPECT_EQ(dtype_v0.code, kDLFloat); EXPECT_EQ(dtype_v0.bits, 32); EXPECT_EQ(dtype_v0.lanes, 1); Any any = String("bfloat16x2"); - Optional opt_v1 = any.as(); + Optional opt_v1 = any.try_cast(); EXPECT_EQ(opt_v1.value().code, kDLBfloat); EXPECT_EQ(opt_v1.value().bits, 16); EXPECT_EQ(opt_v1.value().lanes, 2); diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc index bd0b58b7c46e..b7c977fd344c 100644 --- a/ffi/tests/cpp/test_map.cc +++ b/ffi/tests/cpp/test_map.cc @@ -243,7 +243,7 @@ TEST(Map, AnyConvertCheck) { ::tvm::ffi::Error); } -TEST(Map, ffi::FunctionGetItem) { +TEST(Map, FunctionGetItem) { Function f = Function::FromTyped([](const MapObj* n, const Any& k) -> Any { return n->at(k); }, "map_get_item"); Map map{{"x", 1}, {"y", 2}}; diff --git a/ffi/tests/cpp/test_string.cc b/ffi/tests/cpp/test_string.cc index 847ed6f9559c..a74102a95349 100644 --- a/ffi/tests/cpp/test_string.cc +++ b/ffi/tests/cpp/test_string.cc @@ -275,22 +275,23 @@ TEST(String, Any) { Any b = view; EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIStr); EXPECT_EQ(b.as().value(), "hello"); - EXPECT_EQ(b.as().value(), "hello"); + EXPECT_TRUE(b.as().has_value()); + EXPECT_EQ(b.try_cast().value(), "hello"); std::string s_world = "world"; view = s_world; - EXPECT_EQ(view.as().value(), "world"); + EXPECT_EQ(view.try_cast().value(), "world"); String s{"hello"}; Any a = s; EXPECT_EQ(a.type_index(), TypeIndex::kTVMFFIStr); EXPECT_EQ(a.as().value(), "hello"); - EXPECT_EQ(a.as().value(), "hello"); + EXPECT_EQ(a.try_cast().value(), "hello"); Any c = "helloworld"; EXPECT_EQ(c.type_index(), TypeIndex::kTVMFFIStr); EXPECT_EQ(c.as().value(), "helloworld"); - EXPECT_EQ(c.as().value(), "helloworld"); + EXPECT_EQ(c.try_cast().value(), "helloworld"); } TEST(String, Bytes) { @@ -312,52 +313,52 @@ TEST(String, BytesAny) { AnyView view = &arr; EXPECT_EQ(view.type_index(), TypeIndex::kTVMFFIByteArrayPtr); - EXPECT_EQ(view.as().value().operator std::string(), s); + EXPECT_EQ(view.try_cast().value().operator std::string(), s); Any b = view; EXPECT_EQ(b.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(b.as().value().operator std::string(), s); - EXPECT_EQ(b.as().value(), s); + EXPECT_EQ(b.try_cast().value().operator std::string(), s); + EXPECT_EQ(b.cast(), s); } TEST(String, StdString) { std::string s1 = "test_string"; AnyView view1 = s1; EXPECT_EQ(view1.type_index(), TypeIndex::kTVMFFIRawStr); - EXPECT_EQ(view1.as().value(), s1); + EXPECT_EQ(view1.try_cast().value(), s1); TVMFFIByteArray arr1{s1.data(), static_cast(s1.size())}; AnyView view2 = &arr1; EXPECT_EQ(view2.type_index(), TypeIndex::kTVMFFIByteArrayPtr); - EXPECT_EQ(view2.as().value(), s1); + EXPECT_EQ(view2.try_cast().value(), s1); Bytes bytes1 = s1; AnyView view3 = bytes1; EXPECT_EQ(view3.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(view3.as().value(), s1); + EXPECT_EQ(view3.try_cast().value(), s1); String string1 = s1; AnyView view4 = string1; EXPECT_EQ(view4.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(view4.as().value(), s1); + EXPECT_EQ(view4.try_cast().value(), s1); // Test with Any Any any1 = s1; EXPECT_EQ(any1.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(any1.as().value(), s1); + EXPECT_EQ(any1.try_cast().value(), s1); Any any2 = &arr1; EXPECT_EQ(any2.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(any2.as().value(), s1); + EXPECT_EQ(any2.try_cast().value(), s1); Any any3 = bytes1; EXPECT_EQ(any3.type_index(), TypeIndex::kTVMFFIBytes); - EXPECT_EQ(any3.as().value(), s1); + EXPECT_EQ(any3.try_cast().value(), s1); Any any4 = string1; EXPECT_EQ(any4.type_index(), TypeIndex::kTVMFFIStr); - EXPECT_EQ(any4.as().value(), s1); + EXPECT_EQ(any4.try_cast().value(), s1); } TEST(String, CAPIAccessor) { diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc index 79eeb488643d..e0f69d820018 100644 --- a/ffi/tests/cpp/test_tuple.cc +++ b/ffi/tests/cpp/test_tuple.cc @@ -136,4 +136,32 @@ TEST(Tuple, Upcast) { static_assert(details::type_contains_v, Tuple>); static_assert(details::type_contains_v, Tuple>); } + +TEST(Tuple, ArrayIterForwarding) { + Tuple t0(1, 2); + Tuple t1(3, 4); + Array> arr0 = {t0, t1}; + std::vector> vec0 = {t0}; + vec0.insert(vec0.end(), arr0.begin(), arr0.end()); + EXPECT_EQ(vec0.size(), 3); + EXPECT_EQ(vec0[0].get<0>()->value, 1); + EXPECT_EQ(vec0[0].get<1>()->value, 2); + EXPECT_EQ(vec0[1].get<0>()->value, 1); + EXPECT_EQ(vec0[1].get<1>()->value, 2); + EXPECT_EQ(vec0[2].get<0>()->value, 3); + EXPECT_EQ(vec0[2].get<1>()->value, 4); +} + +TEST(Tuple, ArrayIterForwardSingleElem) { + Tuple t0(1); + Tuple t1(2); + Array> arr0 = {t0, t1}; + std::vector> vec0 = {t0}; + vec0.insert(vec0.end(), arr0.begin(), arr0.end()); + EXPECT_EQ(vec0.size(), 3); + EXPECT_EQ(vec0[0].get<0>()->value, 1); + EXPECT_EQ(vec0[1].get<0>()->value, 1); + EXPECT_EQ(vec0[2].get<0>()->value, 2); +} + } // namespace diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc index ee49ac75d15f..451913c9926d 100644 --- a/ffi/tests/cpp/test_variant.cc +++ b/ffi/tests/cpp/test_variant.cc @@ -32,7 +32,6 @@ using namespace tvm::ffi::testing; TEST(Variant, Basic) { Variant v1 = 1; EXPECT_EQ(v1.get(), 1); - EXPECT_EQ(v1.as().value(), 1.0f); Variant v2 = 2.0f; EXPECT_EQ(v2.get(), 2.0f); @@ -134,4 +133,31 @@ TEST(Variant, Upcast) { EXPECT_EQ(a1[0].get(), 1); } +TEST(Variant, AllObjectRef) { + Variant> v0 = TInt(1); + EXPECT_EQ(v0.get()->value, 1); + static_assert(std::is_base_of_v); + Any any0 = v0; + EXPECT_EQ(any0.cast()->value, 1); + auto v2 = any0.cast>>(); + EXPECT_TRUE(v0.same_as(v2)); + // assignment operator + v0 = Array({TInt(2), TInt(3)}); + EXPECT_EQ(v0.get>().size(), 2); + EXPECT_EQ(v0.get>()[0]->value, 2); + EXPECT_EQ(v0.get>()[1]->value, 3); + EXPECT_EQ(sizeof(v0), sizeof(ObjectRef)); +} + +TEST(Variant, PODSameAs) { + Variant v0 = 1; + Variant v1 = 1; + EXPECT_TRUE(v0.same_as(v1)); + String s = String("hello"); + v0 = s; + v1 = s; + EXPECT_TRUE(v0.same_as(v1)); + v1 = String("hello"); + EXPECT_TRUE(!v0.same_as(v1)); +} } // namespace diff --git a/golang/Makefile b/golang/Makefile deleted file mode 100644 index 76ac371b628d..000000000000 --- a/golang/Makefile +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -.PHONY: clean all - -TVM_BASE = $(CURDIR)/../ -TARGET = gotvm -LIBS = -lm -ldl -NATIVE_SRC = tvm_runtime_pack.cc - -GOPATH=$(CURDIR)/gopath -GOPATHDIR=${GOPATH}/src/${TARGET}/ -CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" -CGO_CXXFLAGS="-std=c++17 -DDMLC_USE_LOGGING_LIBRARY= -DTVM_USE_LIBBACKTRACE=0" -CGO_CFLAGS="-I${TVM_BASE}" -CGO_LDFLAGS="-ldl -lm" - -all: - @mkdir gopath 2>/dev/null || true - @mkdir gopath/src 2>/dev/null || true - @mkdir gopath/src/$(TARGET) 2>/dev/null || true - @cp src/$(TARGET).cc gopath/src/$(TARGET) - @cp src/$(TARGET).h gopath/src/$(TARGET) - @cp src/$(NATIVE_SRC) gopath/src/$(TARGET) - @cp src/*.go gopath/src/$(TARGET) - @export GOPATH=$(GOPATH); \ - export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ - export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \ - export CGO_CFLAGS=$(CGO_CFLAGS); \ - export CGO_LDFLAGS=$(CGO_LDFLAGS); \ - (cd $(GOPATHDIR) && go clean -cache \ - && golint && go build -o $(TARGET).a \ - && go install) - @find . -name gotvm.a - @#mkdir gopath/doc 2>/dev/null || true - @#godoc -html -goroot gopath/ gotvm | grep -v "for documentation on the gotvm command" > gopath/doc/gotvm.html - @#echo "Run 'godoc -http=:6060 -goroot=./gopath' for documentation" - -samples: all - cp gopath/pkg/linux_amd64/gotvm.a sample/ -rfa - make -C sample - -tests: all - @(cd sample; python3 deploy.py) - @export GOPATH=$(GOPATH); \ - export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ - export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \ - export CGO_CFLAGS=$(CGO_CFLAGS); \ - export CGO_LDFLAGS=$(CGO_LDFLAGS); \ - (cd $(GOPATHDIR) \ - && cp ../../../sample/deploy.so . \ - && go test -v) - -clean: - @if [ -d $(GOPATHDIR) ] ; then \ - export GOPATH=$(GOPATH); \ - export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ - export CGO_CFLAGS=$(CGO_CFLAGS); \ - export CGO_LDFLAGS=$(CGO_LDFLAGS); \ - (cd $(GOPATHDIR) && go clean -cache); fi - @rm -rf gopath - @make -C sample clean - -lint: - @(cd src; golint) - @python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.cc - @python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.h diff --git a/golang/README.md b/golang/README.md deleted file mode 100644 index ee3ea8cc2e98..000000000000 --- a/golang/README.md +++ /dev/null @@ -1,126 +0,0 @@ - - - - - - - - - - - - - - - - - -# gotvm - Golang Frontend for TVM Runtime - -This folder contain golang interface for TVM runtime. It brings TVM runtime to Golang. - -- It enable c runtime api of tvm exposed to golang. -- It enables module loading (lib, graph and params) and inference operations. - -## Installation - -### Requirements - -- go compiler (https://golang.org/) version 0.10 or above. - -### Modules - -- src - Module that generates golang package corresponding to the c runtime api exposed from tvm source tree. - This process build golang package _gotvm.a_ - -- samples - Sample golang reference application to inference through gotvm package. - -### Build - -Once the Requirements are installed - -To build _gotvm_ package - -```bash -make -``` - -To build and run internal tests - -```bash -make tests -``` - -To build sample apps. - -```bash -make samples -``` - -## Run - -To Demonstrates sample TVM module compilation using python and deploy via golang. -```bash -./simple -``` - -To deploy a realtime module with lib, graph and param. -```bash -python3 gen_mobilenet_lib.py - -./complex -``` - -To demonstrate go function closure conversion to packed function handle. - -```bash -./pack_func_convert -``` - -To demonstrate a packed function handle given as an argument. - -```bash -./pack_func_handle_arg -``` - -To register go function with runtime as a global function. - -```bash -./pack_func_register -``` - -To demonstrate function closure passed as argument to a function call. - -```bash -./pack_func_closure_arg -``` - -To demonstrate function closure returned from a packed function. - -```bash -./pack_func_closure_return -``` - -## Documentation -gotvm.go is documented with sufficient information about gotvm package. -A html version documentation can be accessed by running below command after building runtime. - -```bash -godoc -http=:6060 -goroot=./gopath -``` -After above command try http://127.0.0.1:6060 from any browser. - -Also please refer to the sample applications under sample folder. - -## Docker -Docker setup may need below additions for dependencies and environment preparation. - -Please refer ```docker/install/ubuntu_install_golang.sh``` for the packages dependencies. - -go compiler 1.10 on ubuntu doesn't install on standard path, hence an explicit export may be needed as shown below. - -```bash -export PATH="/usr/lib/go-1.10/bin:$PATH" -``` diff --git a/golang/sample/Makefile b/golang/sample/Makefile deleted file mode 100644 index fd738b6f979f..000000000000 --- a/golang/sample/Makefile +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -.PHONY: clean all - -SOURCES=$(wildcard *.go) -EXECUTABLE=$(patsubst %.go, %, $(SOURCES)) - -all: $(EXECUTABLE) - @golint - @python3 deploy.py - -%: %.o - @go tool link -linkmode external -extld "g++" -extldflags "-ldl" -o $@ $< - -%.o: %.go - @go tool compile -pack -o $@ $< - -clean: - @rm -f $(EXECUTABLE) *.so *.o *.a *.json *.params diff --git a/golang/sample/complex.go b/golang/sample/complex.go deleted file mode 100644 index c048207b8b5e..000000000000 --- a/golang/sample/complex.go +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application deployment over tvm. - * \file complex.go - */ - -package main - -import ( - "fmt" - "io/ioutil" - "math/rand" - "./gotvm" - "runtime" -) - -// NNVM compiled model paths. -const ( - modLib = "./mobilenet.so" - modJSON = "./mobilenet.json" - modParams = "./mobilenet.params" -) - -// main -func main() { - defer runtime.GC() - // Welcome - fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion) - fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion) - - // Query global functions available - funcNames, err := gotvm.FuncListGlobalNames() - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Global Functions:%v\n", funcNames) - - // Import tvm module (so) - modp, err := gotvm.LoadModuleFromFile(modLib) - if err != nil { - fmt.Print(err) - fmt.Printf("Please copy tvm compiled modules here and update the sample.go accordingly.\n") - fmt.Printf("You may need to update modLib, modJSON, modParams, tshapeIn, tshapeOut\n") - return - } - fmt.Printf("Module Imported:%p\n", modp) - bytes, err := ioutil.ReadFile(modJSON) - if err != nil { - fmt.Print(err) - return - } - jsonStr := string(bytes) - - // Load module on tvm runtime - call tvm.graph_executor.create - funp, err := gotvm.GetGlobalFunction("tvm.graph_executor.create") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Calling tvm.graph_executor.create\n") - // Call function - graphrt, err := funp.Invoke(jsonStr, modp, (int64)(gotvm.KDLCPU), (int64)(0)) - if err != nil { - fmt.Print(err) - return - } - graphmod := graphrt.AsModule() - fmt.Printf("Graph executor Created\n") - - // Array allocation attributes - tshapeIn := []int64{1, 224, 224, 3} - tshapeOut := []int64{1, 1001} - - // Allocate input Array - inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) - if err != nil { - fmt.Print(err) - return - } - - // Allocate output Array - out, err := gotvm.Empty(tshapeOut) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Input and Output Arrays allocated\n") - - // Get module function from graph executor : load_params - // Read params - bytes, err = ioutil.ReadFile(modParams) - if err != nil { - fmt.Print(err) - } - - // Load Params - funp, err = graphmod.GetFunction("load_params") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Func load_params:%p\n", funp) - - // Call function - _, err = funp.Invoke(bytes) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Module params loaded\n") - - // Set some data in input Array - inSlice := make([]float32, (224 * 224 * 3)) - rand.Seed(10) - rand.Shuffle(len(inSlice), func(i, j int) {inSlice[i], - inSlice[j] = rand.Float32(), - rand.Float32() }) - inX.CopyFrom(inSlice) - - // Set Input - funp, err = graphmod.GetFunction("set_input") - if err != nil { - fmt.Print(err) - return - } - - // Call function - _, err = funp.Invoke("input", inX) - if err != nil { - fmt.Print(err) - return - } - - fmt.Printf("Module input is set\n") - - // Run - funp, err = graphmod.GetFunction("run") - if err != nil { - fmt.Print(err) - return - } - - // Call function - _, err = funp.Invoke() - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Module Executed \n") - - // Call runtime function get_output - funp, err = graphmod.GetFunction("get_output") - if err != nil { - fmt.Print(err) - return - } - - // Call function - _, err = funp.Invoke(int64(0), out) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Got Module Output \n") - - // Print results - outIntf, _ := out.AsSlice() - outSlice := outIntf.([]float32) - fmt.Printf("Result:%v\n", outSlice[:10]) -} diff --git a/golang/sample/pack_func_closure_arg.go b/golang/sample/pack_func_closure_arg.go deleted file mode 100644 index ff2d1e2754c4..000000000000 --- a/golang/sample/pack_func_closure_arg.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate go-closure given to a packed function argument. - * \file pack_func_closure_arg.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - - -// sampleFunctionArg receives a Packed Function handle and calls it. -func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - // Call Packed Function - retVal, err = pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64()) - return -} - -// main -func main() { - // Not passing a function name implicitely - // picks the name from reflection as "main.sampleDunctionArg" - gotvm.RegisterFunction(sampleFunctionArg); - fmt.Printf("Registered: sampleFunctionArg\n") - - // Get registered global function. - funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("GetGlobalFunction: main.sampleFunctionArg - Success\n") - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - // Call function - result, err := funp.Invoke(funccall, 30, 50) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Invoked sampleFunctionArg with function closure arg : Result:%v\n", result.AsInt64()) -} diff --git a/golang/sample/pack_func_closure_return.go b/golang/sample/pack_func_closure_return.go deleted file mode 100644 index e010b9395361..000000000000 --- a/golang/sample/pack_func_closure_return.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate go-closure returned from a callback function. - * \file pack_func_closure_return.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - -// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. -func sampleFunctionCb(args ...*gotvm.Value) (retVal interface{}, err error) { - funccall := func (cargs ...*gotvm.Value) (fret interface{}, ferr error) { - for _, v := range cargs { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := cargs[0].AsInt64() - val2 := cargs[1].AsInt64() - fret = int64(val1+val2) - return - } - retVal = funccall - return -} - -// main -func main() { - // Not passing a function name implicitely - // picks the name from reflection as "main.sampleDunctionCb" - gotvm.RegisterFunction(sampleFunctionCb); - fmt.Printf("Registered: sampleFunctionCb\n") - - // Get registered global function - funp, err := gotvm.GetGlobalFunction("main.sampleFunctionCb") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("GetGlobalFunction: main.sampleFunctionCb - Success\n") - - // Call function - result, err := funp.Invoke() - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Invoked main.sampleFunctionCb via Function handle\n") - - pfunc := result.AsFunction() - fmt.Printf("Function Handle received via Packed Function call:%T - %v \n", pfunc, pfunc) - - pfuncRet, err := pfunc.Invoke(30, 40) - fmt.Printf("Invoked closure inside sampleFunctionCb result:%v\n", pfuncRet.AsInt64()) -} diff --git a/golang/sample/pack_func_convert.go b/golang/sample/pack_func_convert.go deleted file mode 100644 index b6d1fbf24d46..000000000000 --- a/golang/sample/pack_func_convert.go +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate function conversion to packed function. - * \file pack_func_convert.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - -// sampleCb is a simple golang callback function like C = A + B. -func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return -} - -// main -func main() { - // Welcome - - // Simple convert to a packed function - fhandle, err := gotvm.ConvertFunction(sampleCb) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Converted function\n") - - retVal, err := fhandle.Invoke(10, 20) - fmt.Printf("Invoke Completed\n") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Result:%v\n", retVal.AsInt64()) -} diff --git a/golang/sample/pack_func_handle_arg.go b/golang/sample/pack_func_handle_arg.go deleted file mode 100644 index d5a3f074946e..000000000000 --- a/golang/sample/pack_func_handle_arg.go +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate converted packed - * function handle passed to another packed function. - * \file pack_func_handle_arg.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - -// sampleCb is a simple golang callback function like C = A + B. -func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return -} - -// sampleFunctionArg receives a Packed Function handle and calls it. -func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - - // Call Packed Function - retVal, err = pfunc.Invoke(args[1], args[2]) - return -} - -// main -func main() { - // Simple convert to a packed function - fhandle, err := gotvm.ConvertFunction(sampleCb) - if err != nil { - fmt.Print(err) - return - } - - gotvm.RegisterFunction(sampleFunctionArg); - fmt.Printf("Registered: sampleFunctionArg\n") - - funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg") - if err != nil { - fmt.Print(err) - return - } - - retVal, err := funp.Invoke(fhandle, 10, 20) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Result:%v\n", retVal.AsInt64()) -} diff --git a/golang/sample/pack_func_register.go b/golang/sample/pack_func_register.go deleted file mode 100644 index ac4ea438dbef..000000000000 --- a/golang/sample/pack_func_register.go +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate function register into TVM global functions. - * \file pack_func_register.go - */ - -package main - -import ( - "fmt" - "./gotvm" - "strings" -) - -// sampleCb is a simple golang callback function like C = A + B. -func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return -} - -// main -func main() { - // Register sampleCb with TVM packed function system and call and check Global Function List. - gotvm.RegisterFunction(sampleCb, "sampleCb"); - // Query global functions available - funcNames, err := gotvm.FuncListGlobalNames() - if err != nil { - fmt.Print(err) - return - } - - found := 0 - for ii := range (funcNames) { - if strings.Compare(funcNames[ii], "sampleCb") == 0 { - found = 1 - } - } - if found == 0 { - fmt.Printf("Function registerd but, not listed\n") - return - } - - - // Get "sampleCb" and verify the call. - funp, err := gotvm.GetGlobalFunction("sampleCb") - if err != nil { - fmt.Print(err) - return - } - - // Call function - result, err := funp.Invoke((int64)(10), (int64)(20)) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("sampleCb result: %v\n", result.AsInt64()) -} diff --git a/golang/sample/simple.go b/golang/sample/simple.go deleted file mode 100644 index 7bb503db4598..000000000000 --- a/golang/sample/simple.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application deployment over tvm. - * \file simple.go - */ - -package main - -import ( - "fmt" - "runtime" - "./gotvm" - "math/rand" -) - -// NNVM compiled model paths. -const ( - modLib = "./deploy.so" -) - -// main -func main() { - // Welcome - defer runtime.GC() - fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion) - fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion) - - // Import tvm module (so) - modp, _ := gotvm.LoadModuleFromFile(modLib) - fmt.Printf("Module Imported\n") - - - // Allocate Array for inputs and outputs. - // Allocation by explicit type and device. - tshapeIn := []int64{4} - inX, _ := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) - - // Default allocation on CPU - inY, _ := gotvm.Empty(tshapeIn, "float32") - - // Default allocation to type "float32" and on CPU - out, _ := gotvm.Empty(tshapeIn) - fmt.Printf("Input and Output Arrays allocated\n") - - // Fill Input Data : inX , inY - inXSlice := make([]float32, 4) - inYSlice := make([]float32, 4) - for i := range inXSlice { - inXSlice[i] = rand.Float32() - inYSlice[i] = rand.Float32() - } - - - // Copy the data on target memory through runtime CopyFrom api. - inX.CopyFrom(inXSlice) - inY.CopyFrom(inYSlice) - fmt.Printf("X: %v\n", inXSlice) - fmt.Printf("Y: %v\n", inYSlice) - - // Get function "myadd" - funp, _ := modp.GetFunction("myadd") - - // Call function - funp.Invoke(inX, inY, out) - fmt.Printf("Module function myadd executed\n") - - // Get the output tensor as an interface holding a slice through runtime CopyTo api. - outSlice, _ := out.AsSlice() - - // Print results - fmt.Printf("Result:%v\n", outSlice.([]float32)) -} diff --git a/golang/src/array_test.go b/golang/src/array_test.go deleted file mode 100644 index a2636a8b0f20..000000000000 --- a/golang/src/array_test.go +++ /dev/null @@ -1,614 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file array_test.go - */ - - -package gotvm - -import ( - "testing" - "unsafe" - "math/rand" -) - -// Create an array and check size. -func TestArrayCreateSize(t *testing.T) { - _, err := Empty([]int64{4}) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 5, 6}) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{}) - if err == nil { - t.Error("Expected err for empty Array created, but didn't got !!") - return - } -} - -// Check array creation via various different arguments. -func TestArrayCreateArgs(t *testing.T) { - _, err := Empty([]int64{4, 2}, "float32", CPU(0)) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 2}, "float32") - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 2}, CPU(0)) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 2}, CPU(0), "float32") - if err != nil { - t.Error(err.Error()) - return - } -} - -// Create an array and check the NDim. -func TestArrayNDim(t *testing.T) { - arr, err := Empty([]int64{4, 5, 6}) - if err != nil { - t.Error(err.Error()) - return - } - - if 3 != arr.GetNdim() { - t.Errorf("GetNdim failed Expected: 3 Got :%v\n", arr.GetNdim()) - return - } -} - -// Create an array and check Shape. -func TestArrayShape(t *testing.T) { - arr, err := Empty([]int64{4, 5, 6}) - if err != nil { - t.Error(err.Error()) - return - } - - shape := arr.GetShape() - if len(shape) != 3 { - t.Errorf("Shape slice expected: 3 Got :%v\n", len(shape)) - return - } - - if shape[0] != 4 || shape[1] != 5 || shape[2] != 6 { - t.Errorf("Shape values expected {4, 5, 6} Got : %v\n", shape); - return - } -} - -// Create an array and check created Device. -func TestArrayDevice(t *testing.T) { - // TODO: Could some test cases for other targets - arr, err := Empty([]int64{4}, CPU(0)) - if err != nil { - t.Error(err.Error()) - return - } - - dev := arr.GetDevice() - if dev.DeviceType != KDLCPU { - t.Errorf("Dev DeviceType expected: %v Got :%v\n", KDLCPU, dev.DeviceType) - return - } - if dev.DeviceID != 0 { - t.Errorf("Dev DeviceID expected: %v Got :%v\n", KDLCPU, dev.DeviceID) - return - } - - arr, err = Empty([]int64{4}, CPU(2)) - if err != nil { - t.Error(err.Error()) - return - } - - dev = arr.GetDevice() - if dev.DeviceType != KDLCPU { - t.Errorf("Dev DeviceType expected: %v Got :%v\n", KDLCPU, dev.DeviceType) - return - } - if dev.DeviceID != 2 { - t.Errorf("Dev DeviceID expected: %v Got :%v\n", KDLCPU, dev.DeviceID) - return - } -} - -// Create array of different dtypes and check dtypes. -func TestArrayDType(t *testing.T) { - for _, dtype := range []string{"int8", "int16", "int32", "int64", - "uint8", "uint16", "uint32", "uint64", - "float32", "float64"} { - arr, err := Empty([]int64{4}, dtype) - if err != nil { - t.Error(err.Error()) - return - } - - if dtype != arr.GetDType() { - t.Errorf("Dtype expected: %v Got :%v\n", dtype, arr.GetDType()) - return - } - } -} - -// Copy Int8 data to created Array and verify. -func TestArrayCopySliceInt8(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int8") - - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen) - rand.Read(bdata) - data := (*[1<<31]int8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []int8: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - - dataRet := ret.([]int8) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Int16 data to created Array and verify. -func TestArrayCopySliceInt16(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int16") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*2) - rand.Read(bdata) - data := (*[1<<31]int16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - switch ret.(type) { - case []int16: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - - dataRet := ret.([]int16) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Int32 data to created Array and verify. -func TestArrayCopySliceInt32(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int32") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*4) - rand.Read(bdata) - data := (*[1<<31]int32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []int32: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]int32) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Int64 data to created Array and verify. -func TestArrayCopySliceInt64(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int64") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*8) - rand.Read(bdata) - data := (*[1<<31]int64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []int64: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]int64) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt8 data to created Array and verify. -func TestArrayCopySliceUInt8(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint8") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen) - rand.Read(bdata) - data := (*[1<<31]uint8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint8: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint8) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt16 data to created Array and verify. -func TestArrayCopySliceUInt16(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint16") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*2) - rand.Read(bdata) - data := (*[1<<31]uint16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint16: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint16) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt32 data to created Array and verify. -func TestArrayCopySliceUInt32(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint32") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*4) - rand.Read(bdata) - data := (*[1<<31]uint32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint32: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint32) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt64 data to created Array and verify. -func TestArrayCopySliceUInt64(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint64") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*8) - rand.Read(bdata) - data := (*[1<<31]uint64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint64: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint64) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Float32 data to created Array and verify. -func TestArrayCopySliceFloat32(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "float32") - if err != nil { - t.Error(err.Error()) - return - } - - data := make([]float32, dlen) - - for i := range data { - data[i] = rand.Float32() - } - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []float32: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]float32) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v \nGot :%v \n", data, dataRet) - return - } - } -} - -// Copy Float64 data to created Array and verify. -func TestArrayCopySliceFloat64(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "float64") - if err != nil { - t.Error(err.Error()) - return - } - - data := make([]float64, dlen) - - for i := range data { - data[i] = rand.Float64() - } - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []float64: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]float64) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} diff --git a/golang/src/bytearray.go b/golang/src/bytearray.go deleted file mode 100644 index 4dcecef4a9b7..000000000000 --- a/golang/src/bytearray.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMByteArray interface. - * \file bytearray.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" -) - -// ByteArray type wraps the TVMByteArray of C runtime API. -// -// This can be used to hold raw data like params of a model. -type ByteArray uintptr - -// nativeCPtr returns the type freed unitptr for ByteArray. -func (tbytearray ByteArray) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(tbytearray) - return -} - -// SetData is used to intialize ByteArray from a golang string object. -// -// This method initialize both data and data size of the underlaying object. -// This function handles freeing old data object if any before allocating new. -// -// `val` is the golang string object from which the ByteArray is initialized. -func (tbytearray ByteArray) setData(val string) { - bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data - if bufPtr == (*C.char)(C.NULL) { - C.free(unsafe.Pointer(bufPtr)) - } - - ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data = C.CString(val) - ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size = C.ulong(len(val)) -} - -// getData returns the golang byte slice corresponding to the ByteArray. -func (tbytearray ByteArray) getData() (retVal []byte) { - val := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data - blen := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size - retVal = C.GoBytes(unsafe.Pointer(val), C.int(blen)) - return -} - -// newByteArray initilizes the native TVMByteArray object with given byte slice -// -//`val` is the golang byte array used to initialize. -// -// returns newly created ByteArray. -func newByteArray(val []byte) (retVal ByteArray) { - handle := ByteArray(C.malloc(C.sizeof_TVMByteArray)) - ((*C.TVMByteArray)(unsafe.Pointer(handle))).data = (*C.char)(C.NULL) - ((*C.TVMByteArray)(unsafe.Pointer(handle))).size = 0 - handle.setData(string(val)) - retVal = handle - return -} - -// deleteTVMByteArray releases the allocated native object of ByteArray. -// -// This delete handles freeing of underlaying native data object too. -func (tbytearray ByteArray) deleteTVMByteArray() { - bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data - C.free(unsafe.Pointer(bufPtr)) - C.free(unsafe.Pointer(tbytearray.nativeCPtr())) -} diff --git a/golang/src/bytearray_test.go b/golang/src/bytearray_test.go deleted file mode 100644 index c4047c50a605..000000000000 --- a/golang/src/bytearray_test.go +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file bytearray_test.go - */ - - -package gotvm - -import ( - "testing" - "math/rand" -) - -// Check ByteArray creation from byte slice and verify the data. -func TestByteArrayGet(t *testing.T) { - data := make([]byte, 1024) - rand.Read(data) - - barr := newByteArray(data) - dataRet := barr.getData() - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v at : %v\n", data[i], dataRet[i], i) - return - } - } -} diff --git a/golang/src/device.go b/golang/src/device.go deleted file mode 100644 index 2918cf6a0f0f..000000000000 --- a/golang/src/device.go +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for Device interface - * \file device.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -// KDLCPU is golang enum correspond to TVM device type kDLCPU. -var KDLCPU = int32(C.kDLCPU) -// kDLCUDA is golang enum correspond to TVM device type kDLCUDA. -var kDLCUDA = int32(C.kDLCUDA) -// kDLCUDAHost is golang enum correspond to TVM device type kDLCUDAHost. -var kDLCUDAHost = int32(C.kDLCUDAHost) -// KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL. -var KDLOpenCL = int32(C.kDLOpenCL) -// KDLMetal is golang enum correspond to TVM device type kDLMetal. -var KDLMetal = int32(C.kDLMetal) -// KDLVPI is golang enum correspond to TVM device type kDLVPI. -var KDLVPI = int32(C.kDLVPI) -// KDLROCM is golang enum correspond to TVM device type kDLROCM. -var KDLROCM = int32(C.kDLROCM) -// KDLVulkan is golang enum correspond to TVM device type kDLVulkan. -var KDLVulkan = int32(C.kDLVulkan) -// KExtDev is golang enum correspond to TVM device type kDLExtDev. -var KExtDev = int32(C.kDLExtDev) - -// Device dtype corresponding to Device aka DLDevice -type Device struct { - DeviceType int32 - DeviceID int32 -} - -// CPU returns the Device object for CPU target on given index -func CPU(index int32) Device { - return Device{KDLCPU, index} -} - -// CUDA returns the Device object for CUDA target on given index -func CUDA(index int32) Device { - return Device{kDLCUDA, index} -} - -// CUDAHost returns the Device object for CUDAHost target on given index -func CUDAHost(index int32) Device { - return Device{kDLCUDAHost, index} -} - -// OpenCL returns the Device object for OpenCL target on given index -func OpenCL(index int32) Device { - return Device{KDLOpenCL, index} -} - -// Metal returns the Device object for Metal target on given index -func Metal(index int32) Device { - return Device{KDLMetal, index} -} - -// VPI returns the Device object for VPI target on given index -func VPI(index int32) Device { - return Device{KDLVPI, index} -} - -// ROCM returns the Device object for ROCM target on given index -func ROCM(index int32) Device { - return Device{KDLROCM, index} -} - -// Vulkan returns the Device object for Vulkan target on given index -func Vulkan(index int32) Device { - return Device{KDLVulkan, index} -} diff --git a/golang/src/error.go b/golang/src/error.go deleted file mode 100644 index edd8116a3612..000000000000 --- a/golang/src/error.go +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for error related API interface. - * \file error.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" -) - -// getTVMLastError returns the detailed error string for any api called in TVM runtime. -// -// This is useful when any api returns non zero value. -// -// Returns golang string for the corresponding native error message. -func getTVMLastError() (retVal string) { - errStr := C.TVMGetLastError() - retVal = C.GoString(errStr) - return -} - -func setTVMLastError(errStr string) { - cstr := C.CString(errStr) - C.TVMAPISetLastError(cstr) - C.free(unsafe.Pointer(cstr)) -} diff --git a/golang/src/error_test.go b/golang/src/error_test.go deleted file mode 100644 index 3fe912db110e..000000000000 --- a/golang/src/error_test.go +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file error_test.go - */ - - -package gotvm - -import ( - "testing" - "strings" -) - -// Check err receiving from TVM global function. -func TestErrorTest(t *testing.T) { - _, err := LoadModuleFromFile("dummy.so") - if err == nil { - t.Error("Expected an error, but not received\n") - return - } - - errStr := err.Error() - if !(strings.Contains(errStr, string("cannot open shared object"))) { - t.Error("Ah! TVM didn't report an error\n") - } -} diff --git a/golang/src/function.go b/golang/src/function.go deleted file mode 100644 index 7b1c5d27d429..000000000000 --- a/golang/src/function.go +++ /dev/null @@ -1,383 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMFunction interface. - * \file function.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" - "encoding/binary" - "errors" - "runtime" - "reflect" - "fmt" -) - -// Function type in golang hold pointer for the TVMFunction handle. -type Function uintptr - -// nativeCPtr returns type freed uintptr for the Function. -func (tvmfunction Function) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(tvmfunction) - return -} - -// Invoke calls the TVM packed function referred by the handle with given arguments. -func (tvmfunction *Function) Invoke(args ...interface{}) (retVal *Value, err error) { - funccall := func (fargs ...interface{}) (*Value, error) { - return callNativeFunction(tvmfunction, fargs) - } - // Check is any args are contain any ValueArray - // Possible is it's a args forward from one packed function to another. - valueArrayFound := false - for ii := range args { - switch args[ii].(type) { - case []*Value: - valueArrayFound = true - } - } - - if !valueArrayFound { - return funccall(args...) - } - if len(args) != 1 { - err = fmt.Errorf("Not supported if packed function args are a mix of []Value and other types") - return - } - - valArray := args[0].([]*Value) - if len(valArray) > 0 { - newArgs := make([]interface{}, len(valArray)) - for ii := range valArray { - newVal := newTVMValue() - newVal.moveFrom(valArray[ii]) - newArgs[ii] = newVal - } - - return funccall(newArgs...) - } - return funccall() -} - -// FuncListGlobalNames is used to query global callable packed function names from TVM. -// -// returns slice of string holding function names and error if any. -func FuncListGlobalNames() (retVal []string, err error) { - var str string - ret := (int32)(C._TVMFuncListGlobalNames(unsafe.Pointer((&str)))) - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - str = goStringFromNative(*(*string)(unsafe.Pointer(&str))) - bin := binary.LittleEndian - size := bin.Uint64([]byte(str[:8])) - str = str[8:] - retVal = make([]string, size) - for i := range retVal { - len := bin.Uint64([]byte(str[:8])) - str = str[8:] - retVal[i] = str[:len] - str = str[len:] - } - return -} - -// GetGlobalFunction is to get handle to the given global function name. -// -// `funcname` is the name of global packed function. -// -// returns a function closure with signature -// func (args ...interface{}) (interface{}, error) and error if any. -// -// The closure function can be used to call Function with arguments directly. -// -// Variadic arguments can be any type which can be embed into Value. -func GetGlobalFunction(funcname string) (retVal *Function, err error) { - var funp uintptr - - cfuncname := C.CString(funcname) - ret := (int32)(C.TVMFuncGetGlobal(cfuncname, - (*C.TVMFunctionHandle)(unsafe.Pointer(&funp)))) - C.free(unsafe.Pointer(cfuncname)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - handle := new(Function) - *handle = Function(funp) - finalizer := func(fhandle *Function) { - nativeTVMFuncFree(fhandle) - fhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// callNativeFunction is routine which calls gotvm native wrapper with given arguments. -// -// `handle` is the handle for Function. -// -// `args` are the variadic arguments to the Function. -// -// returns the interface for the return value from TVM if any and error if any. -func callNativeFunction(handle *Function, args []interface{}) (retVal *Value, err error) { - argsIn := make([]*Value, len(args)) - var typeCodes []int32 - if len(args) != 0 { - typeCodes = make([]int32, len(args)) - } else { - typeCodes = make([]int32, 1) - } - - for ii := range args { - argsIn[ii] = newTVMValue() - if typeCodes[ii], err = argsIn[ii].setValue(args[ii]); err != nil { - return - } - } - - retVal = newTVMValue() - argsOut := []*Value{retVal} - retTypeCode := KNull - err = nativeTVMFuncCall(handle, argsIn, typeCodes, argsOut, &retTypeCode) - if err != nil { - retVal = nil - return - } - retVal.isLocal = false - retVal.dtype = retTypeCode - return -} - -// nativeTVMFuncFree free the function handle allocated in TVM runtime. -// -// `funp` is the Function handle to be freed. -func nativeTVMFuncFree(funp *Function) (retVal int32) { - retVal = (int32) (C.TVMFuncFree(C.TVMFunctionHandle(funp.nativeCPtr()))) - return -} - -// nativeToGoSlice converts native TVMValue array to Golang slice of TVMValue -// -// -func nativeToGoSlice(nargValues (*C.void), argValues []*Value, typeCodes []int32) { - for ii := range argValues { - C._TVMValueNativeGet(unsafe.Pointer(argValues[ii].nativeCPtr()), - unsafe.Pointer(nargValues), - C.int(int32(ii))) - argValues[ii].dtype = typeCodes[ii] - } -} - -// nativeFromGoSlice converts golang slice of TVMValue to native TVMValue array. -// -// -func nativeFromGoSlice(argValues []*Value) (nptr (*C.void)) { - nargValues := ((uintptr)(C.malloc(C.ulong(C.sizeof_TVMValue * len(argValues))))) - for ii := range argValues { - C._TVMValueNativeSet(unsafe.Pointer(nargValues), - unsafe.Pointer(argValues[ii].nativeCPtr()), - C.int(int32(ii))) - } - nptr = (*C.void)(unsafe.Pointer(nargValues)) - return -} - -// nativeTVMFuncCall executes the function with given arguments -// -// `funp` Function handle to the packed function. -// -// `argValues` is the slice of Value which are arguments to the packed function. -// -// `typeCodes` is the alice of argument type codes corresponding to argValues. -// -// `retValues` is return argument which is slice of return values from the packed function. -// -// `retTypeCode` is int32 holding type codes for retValue -// -// Returns err indicating native error if any. -func nativeTVMFuncCall(funp *Function, argValues []*Value, typeCodes []int32, - retValues []*Value, retTypeCode *int32) (err error) { - nargValues := nativeFromGoSlice(argValues) - nretValues := nativeFromGoSlice(retValues) - result := (int32)(C.TVMFuncCall(C.TVMFunctionHandle(*funp), - (*C.TVMValue)(unsafe.Pointer(nargValues)), - (*C.int)(unsafe.Pointer(&(typeCodes[0]))), - C.int(len(argValues)), - (*C.TVMValue)(unsafe.Pointer(nretValues)), - (*C.int)(unsafe.Pointer(retTypeCode)))) - nativeToGoSlice(nargValues, argValues, typeCodes) - nativeToGoSlice(nretValues, retValues, (*[1<<31] int32)(unsafe.Pointer(retTypeCode))[:1:1]) - C.free(unsafe.Pointer(nargValues)) - C.free(unsafe.Pointer(nretValues)) - - if result != 0 { - err = errors.New(getTVMLastError()) - } - return -} - -// goCallBack is a structure holding the go callback function pointer. -// This wrapping is necessary as cgo doesn't support -// passing golang functions type conversion to native. -type goCallBack struct { - cb func (args ...*Value) (interface{}, error) -} - -//export goTVMCallback -func goTVMCallback(args C.native_voidp, typeCodes C.native_voidp, numArgs int32, - retArg C.native_voidp, resourceHandle C.native_voidp) (ret int32){ - fcb := (*goCallBack)(resourceHandle) - // Make Value Sice from native TVMValue pointer. - argValues := make([]*Value, numArgs) - - for ii := range argValues { - argValues[ii] = newTVMValue() - argValues[ii].isLocal = false - } - - // Prepare arguments for golang callback function - nativeToGoSlice((*C.void)(unsafe.Pointer(args)), argValues, - (*[1<<31] int32)(unsafe.Pointer(typeCodes))[:numArgs:numArgs]) - cbargs := argValues - - // Execute the callback - retVal, err := fcb.cb(cbargs...) - if err != nil { - errStr := err.Error() - setTVMLastError(errStr) - return -1 - } - - // It's possible a packed function directly return - // the return value of another packed function. - // - // Inside a packed func : - // ```return pfunc.Invoke(args)``` - // - // In this case pfunc returns nil which is - // returned as an interface holding nil *Value. - // Which becomes a valid retVal holding nil *Value. - isRetNull := false - switch retVal.(type) { - case *Value: - pRet := retVal.(*Value) - if pRet == nil { - isRetNull = true - } - } - - // Handle return value from callback function - if retVal != nil && !isRetNull { - var retTypeCode int32 - retValues := []*Value{newTVMValue()} - - retTypeCode, err = retValues[0].setValue(retVal) - if err != nil { - errStr := err.Error() - setTVMLastError(errStr) - return -1 - } - nretValues := nativeFromGoSlice(retValues) - - // Handle KStr, KBytes: Local finalizers shouldn't try freeing them. - retValues[0].isLocal = false - - apiRet := (int32) (C.TVMCFuncSetReturn(C.TVMRetValueHandle(retArg), - (*C.TVMValue)(unsafe.Pointer(nretValues)), - (*C.int)(unsafe.Pointer(&retTypeCode)), 1)) - C.free(unsafe.Pointer(nretValues)) - if apiRet != 0 { - errStr := string("TVMCFuncSetReturn failed ") - setTVMLastError(errStr) - } - } - return -} - -// ConvertFunction converts given golang function to TVM packed function. -// -// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})``` -// -// Returns Function handle and err if any. -func ConvertFunction(args ...interface{}) (retVal *Function, err error) { - function := args[0].(func (args ...*Value) (interface{}, error)) - fcb := &goCallBack{cb:function} - var funp uintptr - - result := (int32) (C._ConvertFunction(unsafe.Pointer(fcb), - unsafe.Pointer(&funp))) - if result != 0 { - err = errors.New(getTVMLastError()) - } - - handle := new(Function) - *handle = Function(funp) - finalizer := func(fhandle *Function) { - nativeTVMFuncFree(fhandle) - fhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// RegisterFunction registers the golang func in TVM runtime global space. -// -// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})``` -// -// `args[1]` Optional argument of function name with which it will be registered. -// If not passed we use function name from reflection. -// -// Returns err indicating native error if any. -func RegisterFunction(args ...interface{}) (err error) { - fhandle, err := ConvertFunction(args...) - if err != nil { - return - } - - funcname := runtime.FuncForPC(reflect.ValueOf(args[0]).Pointer()).Name() - if len(args) > 1 { - funcname = args[1].(string) - } - - cfuncname := C.CString(funcname) - result := (int32) (C.TVMFuncRegisterGlobal(cfuncname, - C.TVMFunctionHandle(*fhandle), - 0)); // Override = False - C.free(unsafe.Pointer(cfuncname)) - if result != 0 { - err = errors.New(getTVMLastError()) - } - // Clear the finalizer as we don't need to control it anymore. - runtime.SetFinalizer(fhandle, nil) - return -} diff --git a/golang/src/function_test.go b/golang/src/function_test.go deleted file mode 100644 index 0830d16419a2..000000000000 --- a/golang/src/function_test.go +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file function_test.go - */ - -package gotvm - -import ( - "testing" - "reflect" - "math/rand" - "strings" - "fmt" -) - -// Check global function list API -func TestFunctionGlobals(t *testing.T) { - funcNames, err := FuncListGlobalNames() - if err != nil { - t.Error(err.Error()) - return - } - if len(funcNames) < 1 { - t.Errorf("Global Function names received:%v\n", funcNames) - } -} - -// Check GetFunction API -func TestFunctionGlobalGet(t *testing.T) { - funp, err := GetGlobalFunction("tvm.graph_executor.create") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(funp).Kind() != reflect.Ptr { - t.Error("Function type mis matched\n") - return - } -} - -func TestFunctionModuleGet(t *testing.T) { - modp, err := LoadModuleFromFile("./deploy.so") - if err != nil { - t.Error(err.Error()) - return - } - funp, err := modp.GetFunction("myadd") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(funp).Kind() != reflect.Ptr { - t.Error("Function type mis matched\n") - return - } - - dlen := int64(1024) - shape := []int64{dlen} - inX, _ := Empty(shape) - inY, _ := Empty(shape) - out, _ := Empty(shape) - dataX := make([]float32, (dlen)) - dataY := make([]float32, (dlen)) - outExpected := make([]float32, (dlen)) - - for i := range dataX { - dataX[i] = rand.Float32() - dataY[i] = rand.Float32() - outExpected[i] = dataX[i] + dataY[i] - } - - inX.CopyFrom(dataX) - inY.CopyFrom(dataY) - - funp.Invoke(inX, inY, out) - outi, _ := out.AsSlice() - outSlice := outi.([]float32) - if len(outSlice) != len(outExpected) { - t.Errorf("Data expected Len: %v Got :%v\n", len(outExpected), len(outSlice)) - return - } - for i := range outSlice { - if outExpected[i] != outSlice[i] { - t.Errorf("Data expected: %v Got :%v at index %v\n", outExpected[i], outSlice[i], i) - return - } - } -} - -// Check FunctionConvert API -func TestFunctionConvert(t *testing.T) { - sampleCb := func (args ...*Value) (retVal interface{}, err error) { - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - fhandle, err := ConvertFunction(sampleCb) - if err != nil { - t.Error(err.Error()) - return - } - - retVal, err := fhandle.Invoke(10, 20) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsInt64() != int64(30) { - t.Errorf("Expected result :30 got:%v\n", retVal.AsInt64()) - return - } -} - -func TestFunctionError(t *testing.T) { - sampleCb := func (args ...*Value) (retVal interface{}, err error) { - err = fmt.Errorf("Sample Error XYZABC"); - return - } - - fhandle, err := ConvertFunction(sampleCb) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = fhandle.Invoke() - if err == nil { - t.Error("Expected error but didn't received\n") - return - } - - if !strings.Contains(err.Error(), string("Sample Error XYZABC")) { - t.Errorf("Expected Error should contain :\"Sample Error XYZABC\" got :%v\n", err.Error()) - } -} - -// Check FunctionRegister -func TestFunctionRegister(t *testing.T) { - sampleCb := func (args ...*Value) (retVal interface{}, err error) { - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - RegisterFunction(sampleCb, "TestFunctionRegister.sampleCb"); - // Query global functions available - funcNames, err := FuncListGlobalNames() - if err != nil { - t.Error(err.Error()) - return - } - - found := 0 - for ii := range (funcNames) { - if strings.Compare(funcNames[ii], "TestFunctionRegister.sampleCb") == 0 { - found = 1 - } - } - if found == 0 { - t.Error("Registered function not found in global function list.") - return - } - - // Get "sampleCb" and verify the call. - funp, err := GetGlobalFunction("TestFunctionRegister.sampleCb") - if err != nil { - t.Error(err.Error()) - return - } - - // Call function - result, err := funp.Invoke((int64)(10), (int64)(20)) - if err != nil { - t.Error(err.Error()) - return - } - if result.AsInt64() != int64(30) { - t.Errorf("Expected result :30 got:%v\n", result.AsInt64()) - return - } -} - -// Check packed function receiving go-closure as argument. -func TestFunctionClosureArg(t *testing.T) { - // sampleFunctionArg receives a Packed Function handle and calls it. - sampleFunctionArg := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - - // Call Packed Function by Value - ret, err := pfunc.Invoke(args[1], args[2]) - if err != nil { - return - } - - // Call Packed Function with extracted values - ret1, err := pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64()) - if err != nil { - return - } - if ret1.AsInt64() != ret.AsInt64() { - err = fmt.Errorf("Invoke with int64 didn't match with Value") - return - } - retVal = ret - return - } - - RegisterFunction(sampleFunctionArg, "TestFunctionClosureArg.sampleFunctionArg"); - funp, err := GetGlobalFunction("TestFunctionClosureArg.sampleFunctionArg") - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - // Call function - result, err := funp.Invoke(funccall, 30, 50) - if err != nil { - t.Error(err.Error()) - return - } - - if result.AsInt64() != int64(80) { - t.Errorf("Expected result :80 got:%v\n", result.AsInt64()) - return - } -} - -// Check packed function returning a go-closure. -func TestFunctionClosureReturn(t *testing.T) { - // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. - sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { - funccall := func (cargs ...*Value) (fret interface{}, ferr error) { - val1 := cargs[0].AsInt64() - val2 := cargs[1].AsInt64() - fret = int64(val1+val2) - return - } - retVal = funccall - return - } - - RegisterFunction(sampleFunctionCb, "TestFunctionClosureReturn.sampleFunctionCb"); - funp, err := GetGlobalFunction("TestFunctionClosureReturn.sampleFunctionCb") - if err != nil { - t.Error(err.Error()) - return - } - - // Call function - result, err := funp.Invoke() - if err != nil { - t.Error(err.Error()) - return - } - - pfunc := result.AsFunction() - pfuncRet, err := pfunc.Invoke(30, 40) - if err != nil { - t.Error(err.Error()) - return - } - if pfuncRet.AsInt64() != int64(70) { - t.Errorf("Expected result :70 got:%v\n", pfuncRet.AsInt64()) - return - } -} - -// Check packed function with no arguments and no return values. -func TestFunctionNoArgsReturns(t *testing.T) { - sampleFunction := func (args ...*Value) (retVal interface{}, err error) { - return - } - - fhandle, err := ConvertFunction(sampleFunction) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = fhandle.Invoke() - if err != nil { - t.Error(err.Error()) - return - } -} - -// Check packed function returning a go-closure with no arg and returns. -func TestFunctionNoArgsReturns2(t *testing.T) { - // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. - sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { - funccall := func (cargs ...*Value) (fret interface{}, ferr error) { - return - } - retVal = funccall - return - } - - funp, err := ConvertFunction(sampleFunctionCb) - if err != nil { - t.Error(err.Error()) - return - } - - // Call function - result, err := funp.Invoke() - if err != nil { - t.Error(err.Error()) - return - } - - pfunc := result.AsFunction() - _, err = pfunc.Invoke() - if err != nil { - t.Error(err.Error()) - return - } -} diff --git a/golang/src/gotvm.cc b/golang/src/gotvm.cc deleted file mode 100644 index d8919dafbfcb..000000000000 --- a/golang/src/gotvm.cc +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm native interface definition - * \file gotvm.cxx - */ - -// Standard includes -#include -#include -#include -#include -#include -#include - -// golang string compatible definition -typedef struct { - char* p; - int n; -} _gostring_; -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// TVM runtime C interface -#include -#include - -/*! - * \brief Convert native char array to _gostring_ structure. - * _gostring_ structure represents the same memory footprint as golang string object. - * - * \param p is char pointer to a char array. - * \param l is the size of the char array. this method exclusively need length as - * its possible to have a bytearray in a string. - * - * \return _gostring_ object corresponding to native char array. - * Caller is responsible to free the memory block allocated here. - */ -static _gostring_ _native_to_gostring(const char* p, size_t l) { - _gostring_ ret; - ret.p = reinterpret_cast(malloc(l)); - if (NULL == ret.p) { - ret.n = 0; - return ret; - } - memcpy(ret.p, p, l); - ret.n = l; - return ret; -} - -/*! - * \brief embeds a 64bit uint value inside a string to serialize the data. - * - * \param s is string object. - * \param off is the offset in the string object. - * \param v is the uint64_t value which need to embed into given string. - */ -static void putuint64(std::string* s, size_t off, uint64_t v) { - for (int i = 0; i < 8; i++) { - (*s)[off + i] = (v >> (i * 8)) & 0xff; - } -} - -// TVM runtime C interface wrappers - -/*! - * \brief Native interface to query TVM_VERSION in golang string format. - * - * \return char pointer to TVM-VERSION - */ -const char* _TVM_VERSION(void) { - const char* version = TVM_VERSION; - return version; -} - -/*! - * \brief Native interface for getting TVMGlobal function list. - * - * \param names return by argument to return the function names. - * We wrap all strings into single string joined by (len+string) - * which is unpacked and processed in golang. - * - * \return c_runtime_api return status. - */ -int _TVMFuncListGlobalNames(_gostring_* names) { - int names_size; - char** names_array; - int result; - - result = TVMFuncListGlobalNames(&names_size, (char const***)&names_array); - if (result) { - return result; - } - - size_t tot = 8; - for (int ii = 0; ii < names_size; ++ii) { - tot += 8 + strlen(names_array[ii]); - } - - std::string str; - str.resize(tot); - putuint64(&str, 0, names_size); - size_t off = 8; - for (int64_t ii = 0; ii < names_size; ++ii) { - putuint64(&str, off, strlen(names_array[ii])); - off += 8; - str.replace(off, strlen(names_array[ii]), names_array[ii]); - off += strlen(names_array[ii]); - } - *names = _native_to_gostring(str.data(), str.size()); - if (str.size() != names->n) { - TVMAPISetLastError("malloc failed during _native_to_gostring"); - result = 1; - } - return result; -} - -// Helpers for TVMValue - -/*! - * \brief Native helper to copy TVMValue from golang slice to native array. - * this helper is need as underlying memory for golang slice is not continuous. - * - * \param to_ptr is the native pointer of TVMValue array. - * \param from_ptr pointer to TVMValue in golang slice. - * \param array index in native array. - */ -void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { - TVMValue* from_p = reinterpret_cast(from_ptr); - TVMValue* to_p = reinterpret_cast(to_ptr); - memcpy(to_p + ind, from_p, sizeof(TVMValue)); -} - -/*! - * \brief Native helper to copy TVMValue from golang slice to native array. - * this helper is need as underlying memory for golang slice is not continuous. - * - * \param to_ptr pointer to TVMValue in golang slice. - * \param from_ptr is the native pointer of TVMValue array. - * \param array index in native array. - */ -void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) { - TVMValue* from_p = reinterpret_cast(from_ptr); - TVMValue* to_p = reinterpret_cast(to_ptr); - memcpy(to_p, from_p + ind, sizeof(TVMValue)); -} - -extern int goTVMCallback(void*, void*, int, void*, void*); - -/*! - * \brief _TVMCallback is the TVM runtime callback function for ffi::Functiontion system. - * - * \param args is an array of TVMValue - * \param type_codes is an array of int - * \param num_args is int representing number of in arguments - * \param ret is the return value handle to set the packed function return. - * \param resource_handle is the golang private data pointer. - * - * \returns the error status as TVM_DLL - */ -int _TVMCallback(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, - void* resource_handle) { - return goTVMCallback(args, type_codes, num_args, ret, resource_handle); -} - -/*! - * _TVMPackedCFuncFinalizer is finalizer for packed function system. - * - */ -void _TVMPackedCFuncFinalizer(void* resource_handle) { return; } - -/*! - * /brief _ConvertFunction creates a packed function for with given resource handle. - * - * /param fptr is the pointer to golang resource handle. - * /param *fhandle is the return argument holding packed function. - * - * /return is an int indicating the return status. - */ -int _ConvertFunction(void* fptr, TVMFunctionHandle* fhandle) { - int ret = TVMFuncCreateFromCFunc(_TVMCallback, fptr, _TVMPackedCFuncFinalizer, fhandle); - return ret; -} - -#ifdef __cplusplus -} -#endif diff --git a/golang/src/gotvm.h b/golang/src/gotvm.h deleted file mode 100644 index a053e39bd79a..000000000000 --- a/golang/src/gotvm.h +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm native interface declaration. - * \file gotvm.h - * - * These declarations are in cgo interface definition while calling API - * across golang and native C boundaries. - */ - -#ifndef GOTVM_GOTVM_H_ -#define GOTVM_GOTVM_H_ - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include -#include -#include - -// Some type definitions for golang "C" -typedef void* native_voidp; - -// Version -extern char* _TVM_VERSION(void); - -// Wrappers : For incompatible cgo API. -// To handle array of strings wrapped into __gostring__ -extern int _TVMFuncListGlobalNames(void*); -// To handle TVMValue slice to/from native sequential TVMValue array. -extern void _TVMValueNativeSet(void* to, void* from, int index); -extern void _TVMValueNativeGet(void* to, void* from, int index); - -// Callbacks -extern int _ConvertFunction(void* fptr, void* funp); - -#ifdef __cplusplus -} -#endif -#endif // GOTVM_GOTVM_H_ diff --git a/golang/src/gotvm_test.go b/golang/src/gotvm_test.go deleted file mode 100644 index 271b1899897b..000000000000 --- a/golang/src/gotvm_test.go +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file gotvm_test.go - */ - - -package gotvm - -import ( - "testing" - "reflect" -) - -// Check TVMVersion API -func TestTVMVersion(t *testing.T) { - if len(TVMVersion) == 0 { - t.Error("TVMVersion not set\n") - } - if reflect.TypeOf(TVMVersion).Kind() != reflect.String { - t.Error("TVMVersion type mismatch\n") - } -} - -// Check DLPackVersion API -func TestDLPackVersion(t *testing.T) { - if reflect.TypeOf(DLPackVersion).Kind() != reflect.Int { - t.Error("TVMVersion type mismatch\n") - } -} diff --git a/golang/src/module.go b/golang/src/module.go deleted file mode 100644 index 8ac09e369cae..000000000000 --- a/golang/src/module.go +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMModule interface. - * \file module.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "errors" - "runtime" - "unsafe" -) - -// Module type in golang hold pointer for the TVMModule handle. -// -// Module initialization happen through TVMModLoadFromFile api in TVM runtime. -type Module uintptr - -// nativeCPtr returns type freed uintptr for the Module. -func (tvmmodule *Module) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(*tvmmodule) - return -} - -// LoadModuleFromFile loads the given module in TVM runtime. -// -// `modpath` is the path to tvm module. -// -// `args` is an optional arguments of ["dll", "dylib", "dso", "so"] with default value "so" -// -// returns pointer to Module and err or if any. -func LoadModuleFromFile(modpath string, args ...interface{}) (retVal *Module, err error) { - modtype := "so" - if len(args) > 0 { - modtype = args[0].(string) - } - var modp uintptr - - cmodpath := C.CString(modpath) - cmodtype := C.CString(modtype) - - ret := (int32)(C.TVMModLoadFromFile(cmodpath, - cmodtype, - (*C.TVMModuleHandle)(unsafe.Pointer(&modp)))) - - C.free(unsafe.Pointer(cmodpath)) - C.free(unsafe.Pointer(cmodtype)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - handle := new(Module) - *handle = Module(modp) - finalizer := func(mhandle *Module) { - nativeTVMModFree(mhandle) - mhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// nativeTVMModFree free the module handle allocated in TVM runtime. -// -// `modp` is the Module handle to be freed. -func nativeTVMModFree(modp *Module) (retVal int32) { - retVal = (int32) (C.TVMModFree(C.TVMModuleHandle(modp.nativeCPtr()))) - return -} - -// GetFunction returns the function pointer from the module for given function name. -// -// `tvmmodule` is handle for Module -// -// `funcname` function name in module. -// -// `args` variadic args of `queryImport` -// -// returns function closure with signature -// func (args ...interface{}) (interface{}, error) and error if any. -// -// The closure function can be used to call Function with arguments directly. -// -// Variadic arguments can be any type which can be embed into Value. -func (tvmmodule *Module) GetFunction ( - funcname string, args ...interface{}) ( - retVal *Function, err error){ - queryImports := int32(1) - if len(args) > 0 { - queryImports = int32(args[1].(int)) - } - - var funp uintptr - cfuncname := C.CString(funcname) - ret := (int32)(C.TVMModGetFunction((C.TVMModuleHandle)(*tvmmodule), - cfuncname, - C.int(queryImports), - (*C.TVMFunctionHandle)(unsafe.Pointer(&funp)))) - C.free(unsafe.Pointer(cfuncname)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - handle := new(Function) - *handle = Function(funp) - finalizer := func(fhandle *Function) { - nativeTVMFuncFree(fhandle) - fhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} diff --git a/golang/src/module_test.go b/golang/src/module_test.go deleted file mode 100644 index 7e18a86c5b3a..000000000000 --- a/golang/src/module_test.go +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file module_test.go - */ - - -package gotvm - -import ( - "testing" - "reflect" -) - -// Check module loading - dll -func TestModuleTestLoad1(t *testing.T) { - // dll - mod, err := LoadModuleFromFile("./deploy.so", "dll") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading - dylib -func TestModuleTestLoad2(t *testing.T) { - // dylib - mod, err := LoadModuleFromFile("./deploy.so", "dylib") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -func TestModuleTestLoad3(t *testing.T) { - // dso - mod, err := LoadModuleFromFile("./deploy.so", "dso") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading - so -func TestModuleTestLoad4(t *testing.T) { - // so - mod, err := LoadModuleFromFile("./deploy.so", "so") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading - default (so) -func TestModuleTestLoad5(t *testing.T) { - // default type as so - mod, err := LoadModuleFromFile("./deploy.so") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading err -func TestModuleTestLoadErr(t *testing.T) { - // Unknown file should return error - _, err := LoadModuleFromFile("xyzabc.so") - if err == nil { - t.Error("Expected an error, but not received\n") - return - } -} diff --git a/golang/src/ndarray.go b/golang/src/ndarray.go deleted file mode 100644 index b1e71aef56bd..000000000000 --- a/golang/src/ndarray.go +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMArray aka DLTensor - * \file ndarray.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" - "fmt" - "errors" - "runtime" - "reflect" -) - -// Array type in golang hold pointer for the TVMArray object from dlpack. -// -// Array initialization happen through Empty api -type Array uintptr - -// nativeCPtr returns type freed uintptr for the Array. -func (parray Array) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(parray) - return -} - -func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) { - ret := C.TVMArrayCopyFromBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())), - data, - C.ulong(datalen)) - if ret != 0 { - err = errors.New(getTVMLastError()) - } - return -} - -// CopyFrom copies given golang data slice into Array. -// -// `val` is interface homding a slice of Array data type. -// -// returns err is any. -// TOD: Use reflections for better handling -func (parray Array) CopyFrom(val interface{}) (err error) { - var data unsafe.Pointer - var datalen int - dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype - - switch val.(type) { - case []int8: - sliceVal := val.([]int8) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []int16: - sliceVal := val.([]int16) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []int32: - sliceVal := val.([]int32) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []int64: - sliceVal := val.([]int64) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint8: - sliceVal := val.([]uint8) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint16: - sliceVal := val.([]uint16) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint32: - sliceVal := val.([]uint32) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint64: - sliceVal := val.([]uint64) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []float32: - sliceVal := val.([]float32) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []float64: - sliceVal := val.([]float64) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - default: - err = fmt.Errorf("Given type not supported : %v", reflect.TypeOf(val)) - return - } - return -} - -func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){ - ret := C.TVMArrayCopyToBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())), - unsafe.Pointer(data), - C.ulong(datalen)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - } - return -} - -// AsSlice returns the unitptr of for the data inside Array. -// -// returns the slice of array inside Array and err of any. -// TOD: Use reflections for better handling -func (parray Array) AsSlice() (retVal interface{}, err error) { - shape := parray.GetShape() - size := int64(1) - var data unsafe.Pointer - var datalen int - - for ii := range shape { - size *= shape[ii] - } - dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype - - switch parray.GetDType() { - case "int8": - sliceVal := make([]int8, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "int16": - sliceVal := make([]int16, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "int32": - sliceVal := make([]int32, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "int64": - sliceVal := make([]int64, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint8": - sliceVal := make([]uint8, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint16": - sliceVal := make([]uint16, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint32": - sliceVal := make([]uint32, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint64": - sliceVal := make([]uint64, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "float32": - sliceVal := make([]float32, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "float64": - sliceVal := make([]float64, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - default: - err = fmt.Errorf("Given type not supported : %v", parray.GetDType()) - return - } - return -} - -// GetNdim returns the number of dimentions in Array -func (parray Array) GetNdim() (retVal int32) { - retVal = int32(((*C.DLTensor)(unsafe.Pointer(parray))).ndim) - return -} - -// GetShape returns the number of dimentions in Array -func (parray Array) GetShape() (retVal []int64) { - shapePtr := (*C.int64_t)(((*C.DLTensor)(unsafe.Pointer(parray))).shape) - ndim := parray.GetNdim() - - shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim] - retVal = make([]int64, ndim) - copy(retVal, shapeSlice) - return -} - -// GetDType returns the number of dimentions in Array -func (parray Array) GetDType() (retVal string) { - ret := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype - retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret))) - return -} - -// GetDevice returns the number of dimentions in Array -func (parray Array) GetDevice() (retVal Device) { - ret := ((*C.DLTensor)(unsafe.Pointer(parray))).device - retVal = *(*Device)(unsafe.Pointer(&ret)) - return -} - -// nativeTVMArrayAlloc is used to allocate TVMArray from given attributes. -// -// `shape` is int64 slice holding shape of the Array to be created. -// -// `ndim` is the rank of the Array to be created. -// -// `dtypeCode`, `dtypeBits` and `dtypeLanes` describe the data type in Array. -// -// `deviceType` indicates the device on whose memory the Array to allocated. -// -// `deviceID` indicates device index if multiple devices of same type present. -// -// return argument holding native pointer to newly created Array and error is any. -func nativeTVMArrayAlloc(shape []int64, ndim int32, - dtypeCode int32, dtypeBits int32, dtypeLanes int32, - deviceType int32, deviceID int32) (retVal uintptr, err error) { - ret := (int32)(C.TVMArrayAlloc((*C.long)(&(shape[0])), - C.int(ndim), - C.int(dtypeCode), - C.int(dtypeBits), - C.int(dtypeLanes), - C.int(deviceType), - C.int(deviceID), - (*C.TVMArrayHandle)(unsafe.Pointer(&retVal)))) - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - return -} - -// Empty is used to allocate TVM empty array of given epecification. -// -// `shape` is int64 slice holding shape of the Array -// -// `args` is variadic args for -// -// `args[0]` is string for data type. Default value is 'float32' -// -// `args[1]` is Device. Default value is '{KDLCPU, 0}' -// -// returns pointer to Array on successful execution and error if any. -func Empty(shape []int64, args ...interface{}) (parray *Array, err error) { - typeName := "float32" - dev := Device{KDLCPU, 0} - - if len(shape) < 1 { - err = fmt.Errorf("Invalid shape for Array creation: %v", len(shape)) - return - } - - for i, val := range args { - switch val.(type) { - case string: - typeName = args[i].(string) - case Device: - dev = args[i].(Device) - default: - err = fmt.Errorf("Invalid Optional Argument Type: %T", val) - return - } - } - - tvmType, err := dtypeToTVMType(typeName) - if err != nil { - return - } - ndim := int32(len(shape)) - newArray, err := nativeTVMArrayAlloc(shape, ndim, int32(tvmType.code), - int32(tvmType.bits), int32(tvmType.lanes), - dev.DeviceType, dev.DeviceID) - if err != nil { - return - } - handle := new(Array) - *handle = Array(newArray) - - finalizer := func (ahandle *Array) { - nativeTVMArrayFree(*ahandle) - ahandle = nil - } - runtime.SetFinalizer(handle, finalizer) - parray = handle - return -} - -// nativeTVMArrayFree is used to release the Array. -// -// `parray` is the Array handle. -// -// `ret` indicates the status of this api execution. -func nativeTVMArrayFree(parray Array) (retVal int32) { - retVal = (int32)(C.TVMArrayFree((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())))) - return -} diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc deleted file mode 100644 index 475abf5a3e36..000000000000 --- a/golang/src/tvm_runtime_pack.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief This is an all in one TVM runtime file. - * \file tvm_runtime_pack.cc - */ -#include "src/runtime/c_runtime_api.cc" -#include "src/runtime/container.cc" -#include "src/runtime/cpu_device_api.cc" -#include "src/runtime/file_utils.cc" -#include "src/runtime/library_module.cc" -#include "src/runtime/logging.cc" -#include "src/runtime/module.cc" -#include "src/runtime/ndarray.cc" -#include "src/runtime/object.cc" -#include "src/runtime/registry.cc" -#include "src/runtime/thread_pool.cc" -#include "src/runtime/threading_backend.cc" -#include "src/runtime/workspace_pool.cc" - -// NOTE: all the files after this are optional modules -// that you can include remove, depending on how much feature you use. - -// Likely we only need to enable one of the following -// If you use Module::Load, use dso_module -// For system packed library, use system_lib_module -#include "src/runtime/dso_library.cc" -#include "src/runtime/system_library.cc" - -// Graph executor -#include "src/runtime/memory/memory_manager.cc" - -// Uncomment the following lines to enable RPC -// #include "../../src/runtime/rpc/rpc_session.cc" -// #include "../../src/runtime/rpc/rpc_event_impl.cc" -// #include "../../src/runtime/rpc/rpc_server_env.cc" - -// These macros enables the device API when uncommented. -#define TVM_CUDA_RUNTIME 1 -#define TVM_METAL_RUNTIME 1 -#define TVM_OPENCL_RUNTIME 1 - -// Uncomment the following lines to enable Metal -// #include "../../src/runtime/metal/metal_device_api.mm" -// #include "../../src/runtime/metal/metal_module.mm" - -// Uncomment the following lines to enable CUDA -// #include "../../src/runtime/cuda/cuda_device_api.cc" -// #include "../../src/runtime/cuda/cuda_module.cc" - -// Uncomment the following lines to enable OpenCL -// #include "../../src/runtime/opencl/opencl_device_api.cc" -// #include "../../src/runtime/opencl/opencl_module.cc" -// #include "../src/runtime/source_utils.cc" diff --git a/golang/src/type.go b/golang/src/type.go deleted file mode 100644 index 6202e0baa875..000000000000 --- a/golang/src/type.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package for TVMType interface - * \file type.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "fmt" -) - -// pTVMType corresponding to data types. -type pTVMType struct { - code uint8 - bits uint8 - lanes uint16 -} - -// data type to pTVMType mapping -var dtypeMap = map[string] pTVMType { - "int8": pTVMType{0, 8, 1}, - "int16": pTVMType{0, 16, 1}, - "int32": pTVMType{0, 32, 1}, - "int64": pTVMType{0, 64, 1}, - "uint8": pTVMType{1, 8, 1}, - "uint16": pTVMType{1, 16, 1}, - "uint32": pTVMType{1, 32, 1}, - "uint64": pTVMType{1, 64, 1}, - "float32": pTVMType{2, 32, 1}, - "float64": pTVMType{2, 64, 1}, -} - -// dtypeFromTVMType return the pTVMType corresponding to given dtype -// -// `dtype` string for the given data type. -func dtypeFromTVMType(tvmtype pTVMType) (retVal string, err error) { - for k, v := range dtypeMap { - if v.code == tvmtype.code && v.bits == tvmtype.bits && v.lanes == tvmtype.lanes { - retVal = k - return - } - } - - err = fmt.Errorf("Cannot map TVMType:%v to dtype", tvmtype) - return -} - -// dtypeToTVMType return the pTVMType corresponding to given dtype -// -// `dtype` string for the given data type. -func dtypeToTVMType(args ...interface{}) (tvmtype pTVMType, err error) { - dtype := args[0].(string) - lanes := 1 - - if len(args) == 2 { - lanes = args[1].(int) - } - - for k, v := range dtypeMap { - if k == dtype { - tvmtype = v - tvmtype.lanes = uint16(lanes) - return - } - } - err = fmt.Errorf("Cannot map dtype:%v to TVMType", dtype) - return -} diff --git a/golang/src/utils.go b/golang/src/utils.go deleted file mode 100644 index 2da4138a1e66..000000000000 --- a/golang/src/utils.go +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for common utilities - * \file utils.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" -) - -// Native string map for go string -type nativeGoString struct { p uintptr; n int32 } - -func goStringFromNative (s string) (retStr string) { - p := *(*nativeGoString)(unsafe.Pointer(&s)) - retStr = string((*[0x7fffffff]byte)(unsafe.Pointer(p.p))[:p.n]) - C.free(unsafe.Pointer(p.p)) - return -} diff --git a/golang/src/value.go b/golang/src/value.go deleted file mode 100644 index 450cf4866ab0..000000000000 --- a/golang/src/value.go +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMValue interface - * \file value.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "fmt" - "runtime" - "unsafe" -) - -// KHandle is golang type code for TVM enum kTVMOpaqueHandle. -var KHandle = int32(C.kTVMOpaqueHandle) -// KNull is golang type code for TVM kTVMNullptr. -var KNull = int32(C.kTVMNullptr) -// KTVMType is golang type code for TVM kTVMDataType. -var KTVMType = int32(C.kTVMDataType) -// KDLDevice is golang type code for TVM kDLDevice. -var KDLDevice = int32(C.kDLDevice) -// KArrayHandle is golang type code for TVM kTVMDLTensorHandle. -var KArrayHandle = int32(C.kTVMDLTensorHandle) -// KObjectHandle is golang type code for TVM kTVMObjectHandle. -var KObjectHandle = int32(C.kTVMObjectHandle) -// KModuleHandle is gonag type code for TVM kTVMModuleHandle. -var KModuleHandle = int32(C.kTVMModuleHandle) -// KFuncHandle is gonalg type code for TVM kTVMPackedFuncHandle. -var KFuncHandle = int32(C.kTVMPackedFuncHandle) -// KStr is golang type code for TVM kTVMStr. -var KStr = int32(C.kTVMStr) -// KBytes is golang type code for TVM kTVMBytes. -var KBytes = int32(C.kTVMBytes) -// KNDArrayContainer is golang typecode for kTVMNDArrayHandle. -var KNDArrayContainer = int32(C.kTVMNDArrayHandle) -// KExtBegin is golang enum corresponding to TVM kTVMExtBegin. -var KExtBegin = int32(C.kTVMExtBegin) -// KNNVMFirst is golang enum corresponding to TVM kNNVMFirst. -var KNNVMFirst = int32(C.kTVMNNVMFirst) -// KNNVMLast is golang enum corresponding to TVM kNNVMLast. -var KNNVMLast = int32(C.kTVMNNVMLast) -// KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd. -var KExtReserveEnd = int32(C.kTVMExtReserveEnd) -// KExtEnd is golang enum corresponding to TVM kExtEnd. -var KExtEnd = int32(C.kTVMExtEnd) -// KDLInt is golang type code for TVM kDLInt. -var KDLInt = int32(C.kDLInt) -// KDLUInt is golang type code for TVM kDLUInt. -var KDLUInt = int32(C.kDLUInt) -// KDLFloat is golang type code for TVM kDLFloat. -var KDLFloat = int32(C.kDLFloat) - -// Value Typemap for union exposed by TVM runtime API. -// -// gotvm maps it to a uintptr and then dynamically allocates memory by newTVMValue method. -type Value struct { - nptr uintptr - dtype int32 - isLocal bool -} - -// AsInt64 returns the int64 value inside the Value. -func (tvmval *Value) AsInt64() (retVal int64) { - retVal = tvmval.getVInt64() - return -} - -// AsFloat64 returns the Float64 value inside the Value. -func (tvmval *Value) AsFloat64() (retVal float64) { - retVal = tvmval.getVFloat64() - return -} - -// AsModule returns the Module inside the Value. -func (tvmval *Value) AsModule() (retVal *Module) { - mhandle := tvmval.getVMHandle() - retVal = &mhandle - return -} - -// AsFunction returns the Function inside the Value. -func (tvmval *Value) AsFunction() (retVal *Function) { - fhandle := tvmval.getVFHandle() - retVal = &fhandle - - return -} - -// AsBytes returns the byte slice value inside the Value. -func (tvmval *Value) AsBytes() (retVal []byte) { - retVal = tvmval.getVBHandle().getData() - return -} - -// AsStr returns the golang string in the Value. -func (tvmval *Value) AsStr() (retVal string) { - str := tvmval.getVStr() - retVal = str - return -} - -// nativeCPtr return the unitptr corresponding to Value type. -func (tvmval *Value) nativeCPtr() (ret uintptr) { - ret = (uintptr)(tvmval.nptr) - return -} - -// moveFrom copies the tvmval from other Value object. -func (tvmval *Value) moveFrom(fromval *Value) () { - C.memcpy(unsafe.Pointer(tvmval.nativeCPtr()), - unsafe.Pointer(fromval.nativeCPtr()), - C.sizeof_TVMValue) - - // Move the dtype too. - tvmval.dtype = fromval.dtype - fromval.dtype = KNull - return -} - -// setVInt64 initializes the Value object with given int64 value. -// -// `val` is the int64 value to initialize the Value -func (tvmval *Value) setVInt64(val int64) { - valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = C.int64_t(val) - tvmval.dtype = KDLInt - return -} - - -// getVInt64 returns the int64 value inside the Value. -func (tvmval *Value) getVInt64() (retVal int64) { - valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = int64(*valp) - return -} - -// setVFloat64 initializes the Value object with given float64 value. -// -// `val` is the float64 value to initialize the Value. -func (tvmval *Value) setVFloat64(val float64) { - valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = C.double(val) - tvmval.dtype = KDLFloat - return -} - -// getVFloat64 returns the float64 value inside Value. -func (tvmval *Value) getVFloat64() (retVal float64) { - valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = float64(*valp) - return -} - -// setVHandle initializes the handle inside the Value. -// -// Can be used to store any uintptr type object like -// module handle, function handle and any object's nativeCPtr. -// -// `val` is the uintptr type of given handle. -func (tvmval *Value) setVHandle(val uintptr) { - valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = (*C.void)(unsafe.Pointer(val)) -} - -// getVHandle returns the uintptr handle -func (tvmval *Value) getVHandle() (retVal uintptr) { - valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = uintptr(unsafe.Pointer(*valp)) - return -} - -// setVStr intializes the Value with given golang string object. -// -// `val` is the golang string object used to initialize the Value. -func (tvmval *Value) setVStr(val string) { - valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = C.CString(val) - tvmval.dtype = KStr - return -} - - -// getVStr returns the golang string for the native string inside Value. -func (tvmval *Value) getVStr() (retVal string) { - valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = C.GoString(*valp) - return -} - -// unSetVStr release the memory allocated in setVStr -func (tvmval *Value) unSetVStr() { - valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) - C.free(unsafe.Pointer(*valp)) - tvmval.dtype = KNull -} - -// setVAHandle is used to set Array handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead too. -// This is a wrapper to accept Array directly. -func (tvmval *Value) setVAHandle(ptvmarray Array) { - tvmval.setVHandle(ptvmarray.nativeCPtr()) - tvmval.dtype = KArrayHandle - return -} - -// getVAHandle is used to get Array handle in Value. -func (tvmval *Value) getVAHandle() (retVal Array) { - retVal = (Array)(tvmval.getVHandle()) - return -} - -// setVMHandle is used to set Module handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead too. -// This is a wrapper to accept Module directly. -func (tvmval *Value) setVMHandle(tvmmodule Module) { - tvmval.setVHandle(tvmmodule.nativeCPtr()) - tvmval.dtype = KModuleHandle - return -} - -// getVMHandle is used to get Module handle in Value. -func (tvmval *Value) getVMHandle() (retVal Module) { - retVal = (Module)(tvmval.getVHandle()) - return -} - -// setVFHandle is used to set Function handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead. -// This is a wrapper to accept Function directly. -func (tvmval *Value) setVFHandle(tvmfunction Function) { - tvmval.setVHandle(tvmfunction.nativeCPtr()) - tvmval.dtype = KFuncHandle - return -} - -// getVFHandle is used to get Function handle in Value. -func (tvmval *Value) getVFHandle() (retVal Function) { - retVal = (Function)(tvmval.getVHandle()) - return -} - -// setVBHandle is used to set ByteArray handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead. -// This is a wrapper to accept ByteArray directly. -func (tvmval *Value) setVBHandle(tbytearray ByteArray) { - tvmval.setVHandle(tbytearray.nativeCPtr()) - tvmval.dtype = KBytes - return -} - -// getVBHandle is used to get ByteArray handle in Value. -func (tvmval *Value) getVBHandle() (retVal ByteArray) { - retVal = (ByteArray)(tvmval.getVHandle()) - return -} - -// setValue is used to set the given value in Value. -// -// `val` is value of types accepted by Value container or native union. -func (tvmval *Value) setValue(val interface{}) (retVal int32, err error) { - retVal = KNull - switch val.(type) { - case string: - tvmval.setVStr(val.(string)) - case uint8: - tvmval.setVInt64(int64(val.(uint8))) - case uint16: - tvmval.setVInt64(int64(val.(uint16))) - case uint32: - tvmval.setVInt64(int64(val.(uint32))) - case uint64: - tvmval.setVInt64(int64(val.(uint64))) - case int: - tvmval.setVInt64(int64(val.(int))) - case int8: - tvmval.setVInt64(int64(val.(int8))) - case int16: - tvmval.setVInt64(int64(val.(int16))) - case int32: - tvmval.setVInt64(int64(val.(int32))) - case int64: - tvmval.setVInt64(val.(int64)) - case float32: - tvmval.setVFloat64(float64(val.(float32))) - case float64: - tvmval.setVFloat64(val.(float64)) - case *Module: - tvmval.setVMHandle(*(val.(*Module))) - case *Function: - tvmval.setVFHandle(*(val.(*Function))) - case *ByteArray: - tvmval.setVBHandle(*(val.(*ByteArray))) - case []byte: - barray := newByteArray(val.([]byte)) - tvmval.setVBHandle(barray) - case *Array: - tvmval.setVAHandle(*(val.(*Array))) - case func (args ...*Value) (interface{}, error): - fhandle, apierr := ConvertFunction(val) - if apierr != nil { - err = fmt.Errorf("Given value Type not defined for Value: %v : %T", val, val); - return - } - tvmval.setVFHandle(*fhandle) - - // Clear the finalizer as we don't need to control it anymore. - runtime.SetFinalizer(fhandle, nil) - case *Value: - tvmval.moveFrom(val.(*Value)) - case Value: - fromval := val.(Value) - tvmval.moveFrom(&fromval) - default: - err = fmt.Errorf("Given value Type not defined for Value: %v : %T", val, val); - } - retVal = tvmval.dtype - return -} - -// newTVMValue initialize the TVMValue native object. -// -// This is intended to use as intermediate type between native and golang types. -// Allocated from FuncCall or Callback to handle conversions. -func newTVMValue() (retVal *Value) { - handle := new(Value) - - handle.nptr = (uintptr(C.malloc(C.sizeof_TVMValue))) - handle.dtype = KNull - handle.isLocal = true - finalizer := func(vhandle *Value) { - vhandle.deleteTVMValue() - vhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// deleteTVMValue free the native Value object which is allocated in newTVMValue. -func (tvmval Value) deleteTVMValue() { - if tvmval.isLocal == true { - if tvmval.dtype == KStr { - tvmval.unSetVStr() - } - if tvmval.dtype == KBytes { - tvmval.getVBHandle().deleteTVMByteArray() - } - } - - C.free(unsafe.Pointer(tvmval.nativeCPtr())) -} diff --git a/golang/src/value_test.go b/golang/src/value_test.go deleted file mode 100644 index ba502254cd20..000000000000 --- a/golang/src/value_test.go +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file value_test.go - */ - -package gotvm - -import ( - "testing" - "math/rand" - "strings" -) - -// Check Int64 Value looping via packed function calling another packed function. -func TestValueLoopInt64(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Int63() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - if retVal.AsInt64() != result { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -// Check Int32 Value looping via packed function calling another packed function. -func TestValueLoopInt32(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Int31() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsInt64() != int64(result) { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -// Check Float32 Value looping via packed function calling another packed function. -func TestValueLoopFloat32(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Float32() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsFloat64() != float64(result) { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -// Check Float64 Value looping via packed function calling another packed function. -func TestValueLoopFloat64(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Float64() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsFloat64() != result { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -func TestValueLoopString(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - argStr := args[1].AsStr() - // Call Packed Function by Value - return pfunc.Invoke(argStr) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0].AsStr() - return - } - - retVal, err := fhandle.Invoke(funccall, "TestString") - if err != nil { - t.Error(err.Error()) - return - } - - vStr := retVal.AsStr() - if strings.Compare(vStr, string("TestString")) != 0 { - t.Errorf("Expected : %v got:%v\n", string("TestString"), vStr) - return - } -} - -// Check []byte Value looping via packed function calling another packed function. -func TestValueLoopByteSlice(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - argBytes := args[1].AsBytes() - // Call Packed Function by Value - return pfunc.Invoke(argBytes) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0].AsBytes() - return - } - - result := make([]byte, 1024) - rand.Read(result) - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - received := retVal.AsBytes() - if len(result) != len(received) { - t.Errorf("Data expected Len: %v Got :%v\n", len(result), len(received)) - return - } - for i := range result { - if result[i] != received[i] { - t.Errorf("Data expected: %v Got :%v at index %v\n", result[i], received[i], i) - return - } - } -} diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index f09564d050ca..29d6282ca351 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -299,7 +299,8 @@ Map AsIntSet(const Map& var_dom); * \param var_dom The ranges of the variables * \param predicate The predicate for the affine map * \param analyzer The analyzer used - * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis + * \return std::nullopt if the detection fails, or an array of arith::IntSet as the result of + * analysis */ TVM_DLL Optional> EstimateRegionStrictBound(const Array& region, const Map& var_dom, @@ -313,7 +314,8 @@ TVM_DLL Optional> EstimateRegionStrictBound(const Array& re * \param var_dom The ranges of the variables * \param predicate The predicate for the affine map * \param analyzer The analyzer used - * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis + * \return std::nullopt if the detection fails, or an array of arith::IntSet as the result of + * analysis */ TVM_DLL Optional> EstimateRegionLowerBound(const Array& region, const Map& var_dom, diff --git a/include/tvm/ir/analysis.h b/include/tvm/ir/analysis.h index afe18792dee0..ad95f2f0ebb5 100644 --- a/include/tvm/ir/analysis.h +++ b/include/tvm/ir/analysis.h @@ -27,10 +27,10 @@ #ifndef TVM_IR_ANALYSIS_H_ #define TVM_IR_ANALYSIS_H_ +#include #include #include #include -#include namespace tvm { namespace ir { diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 6e004415f8e9..6378d6f74ac2 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -45,11 +45,11 @@ #define TVM_IR_ATTRS_H_ #include +#include +#include #include #include #include -#include -#include #include #include @@ -604,7 +604,7 @@ inline void SetValue(T* ptr, const ffi::AnyView& val) { template inline void SetIntValue(T* ptr, const ffi::AnyView& val) { - if (auto opt_int = val.as()) { + if (auto opt_int = val.try_cast()) { *ptr = static_cast(opt_int.value()); } else { IntImm expr = val.cast(); @@ -620,16 +620,12 @@ inline void SetValue(DataType* ptr, const ffi::AnyView& val) { template <> inline void SetValue(std::string* ptr, const ffi::AnyView& val) { - if (auto opt_str = val.as()) { - *ptr = opt_str.value(); - } else { - LOG(FATAL) << "Expect str"; - } + *ptr = val.cast(); } template <> inline void SetValue(double* ptr, const ffi::AnyView& val) { - if (auto opt_double = val.as()) { + if (auto opt_double = val.try_cast()) { *ptr = opt_double.value(); } else { ObjectRef expr = val.cast(); diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 52fab116360c..ab5cf31c6c86 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -24,8 +24,8 @@ #ifndef TVM_IR_ENV_FUNC_H_ #define TVM_IR_ENV_FUNC_H_ +#include #include -#include #include #include @@ -142,8 +142,8 @@ class TypedEnvFunc : public ObjectRef { if constexpr (std::is_same_v) { n->func(std::forward(args)...); } else { - Any res = n->func(std::forward(args)...); - if constexpr (std::is_same_v) { + ffi::Any res = n->func(std::forward(args)...); + if constexpr (std::is_same_v) { return res; } else { return std::move(res).cast(); diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index a3defa592af6..974747c77416 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,10 +24,10 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ +#include #include #include #include -#include #include #include @@ -38,7 +38,7 @@ namespace tvm { -using tvm::runtime::String; +using tvm::String; // Forward-declare VirtualDevice to avoid circular imports. class VirtualDevice; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index e19e3f3af124..fa51856a0104 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -24,11 +24,11 @@ #ifndef TVM_IR_FUNCTION_H_ #define TVM_IR_FUNCTION_H_ +#include +#include +#include #include #include -#include -#include -#include #include #include diff --git a/include/tvm/ir/instrument.h b/include/tvm/ir/instrument.h index 1b9eb9c1b7c8..b1ef86c12c58 100644 --- a/include/tvm/ir/instrument.h +++ b/include/tvm/ir/instrument.h @@ -26,8 +26,8 @@ #ifndef TVM_IR_INSTRUMENT_H_ #define TVM_IR_INSTRUMENT_H_ +#include #include -#include #include #include diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 66637f67d948..994f3a4bb86a 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -24,14 +24,14 @@ #ifndef TVM_IR_MODULE_H_ #define TVM_IR_MODULE_H_ +#include +#include +#include #include #include #include #include #include -#include -#include -#include #include #include diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 8eaa62a98120..9c758a52b384 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -25,13 +25,13 @@ #ifndef TVM_IR_OP_H_ #define TVM_IR_OP_H_ +#include #include #include #include #include #include #include -#include #include #include diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 7b79a2c89455..83e2f4f375bf 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -23,10 +23,9 @@ #ifndef TVM_IR_SOURCE_MAP_H_ #define TVM_IR_SOURCE_MAP_H_ +#include #include #include -#include -#include #include #include diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 0da882f3884d..8562fbaa8ff4 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -56,11 +56,11 @@ #ifndef TVM_IR_TRANSFORM_H_ #define TVM_IR_TRANSFORM_H_ +#include +#include #include #include #include -#include -#include #include #include @@ -273,10 +273,10 @@ class PassContext : public ObjectRef { auto type_key = ffi::TypeIndexToTypeKey(tindex); auto legalization = [=](ffi::Any value) -> ffi::Any { - if (auto opt_map = value.as>()) { + if (auto opt_map = value.try_cast>()) { return reflection->CreateObject(type_key, opt_map.value()); } else { - auto opt_val = value.as(); + auto opt_val = value.try_cast(); if (!opt_val.has_value()) { TVM_FFI_THROW(AttributeError) << "Expect config " << key << " to have type " << type_key << ", but instead get " @@ -365,7 +365,7 @@ class PassInfo : public ObjectRef { * \param required The passes that are required to perform the current pass. * \param traceable Boolean that tells whether the pass is traceable. */ - TVM_DLL PassInfo(int opt_level, String name, Array required, bool traceable); + TVM_DLL PassInfo(int opt_level, String name, Array required, bool traceable); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -538,7 +538,7 @@ class Sequential : public Pass { * \return The created module pass. */ TVM_DLL Pass CreateModulePass(std::function pass_func, - int opt_level, String name, Array required, + int opt_level, String name, Array required, bool traceable = false); /* diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 5faaa31ca23b..2e49a9c5185b 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -49,9 +49,9 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ +#include #include #include -#include #include #include diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h index ccf093126232..2768ed2737dc 100644 --- a/include/tvm/meta_schedule/arg_info.h +++ b/include/tvm/meta_schedule/arg_info.h @@ -19,10 +19,10 @@ #ifndef TVM_META_SCHEDULE_ARG_INFO_H_ #define TVM_META_SCHEDULE_ARG_INFO_H_ +#include #include #include #include -#include #include #include #include @@ -81,7 +81,7 @@ class TensorInfoNode : public ArgInfoNode { /*! \brief The data type of the tensor. */ runtime::DataType dtype; /*! \brief The shape of the tensor. */ - runtime::ShapeTuple shape; + ffi::Shape shape; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -106,7 +106,7 @@ class TensorInfo : public ArgInfo { * \param dtype The data type of the tensor argument. * \param shape The shape tuple of the tensor argument. */ - TVM_DLL explicit TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape); + TVM_DLL explicit TensorInfo(runtime::DataType dtype, ffi::Shape shape); /*! * \brief Parse the argument information from a JSON object. * \param json_obj The json object to parse. diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 76883feda76c..24e136f9d345 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -19,15 +19,15 @@ #ifndef TVM_META_SCHEDULE_BUILDER_H_ #define TVM_META_SCHEDULE_BUILDER_H_ +#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include #include #include -#include #include namespace tvm { @@ -66,7 +66,7 @@ class BuilderInput : public runtime::ObjectRef { * \param params Parameters for Relax build module. */ TVM_DLL explicit BuilderInput(IRModule mod, Target target, - Optional> params = NullOpt); + Optional> params = std::nullopt); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index 48a27340bff1..300f53e113bd 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -20,14 +20,14 @@ #ifndef TVM_META_SCHEDULE_COST_MODEL_H_ #define TVM_META_SCHEDULE_COST_MODEL_H_ +#include +#include +#include #include #include #include #include -#include -#include #include -#include #include #include diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 45c4a241e29d..570da2cf0650 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -19,14 +19,14 @@ #ifndef TVM_META_SCHEDULE_DATABASE_H_ #define TVM_META_SCHEDULE_DATABASE_H_ +#include +#include +#include #include #include #include #include -#include -#include #include -#include #include #include #include @@ -238,7 +238,7 @@ class DatabaseNode : public runtime::Object { * \param mod The IRModule to be searched for. * \param target The target to be searched for. * \param workload_name The name of the workload to be searched for. - * \return The best record of the given workload; NullOpt if not found. + * \return The best record of the given workload; std::nullopt if not found. */ virtual Optional QueryTuningRecord(const IRModule& mod, const Target& target, const String& workload_name); @@ -247,7 +247,7 @@ class DatabaseNode : public runtime::Object { * \param mod The IRModule to be searched for. * \param target The target to be searched for. * \param workload_name The name of the workload to be searched for. - * \return The schedule in the best schedule of the given workload; NullOpt if not found. + * \return The schedule in the best schedule of the given workload; std::nullopt if not found. */ virtual Optional QuerySchedule(const IRModule& mod, const Target& target, const String& workload_name); @@ -256,7 +256,7 @@ class DatabaseNode : public runtime::Object { * \param mod The IRModule to be searched for. * \param target The target to be searched for. * \param workload_name The name of the workload to be searched for. - * \return The IRModule in the best IRModule of the given workload; NullOpt if not found. + * \return The IRModule in the best IRModule of the given workload; std::nullopt if not found. */ virtual Optional QueryIRModule(const IRModule& mod, const Target& target, const String& workload_name); @@ -330,7 +330,7 @@ class PyDatabaseNode : public DatabaseNode { * \param mod The IRModule to be searched for. * \param target The target to be searched for. * \param workload_name The name of the workload to be searched for. - * \return The best record of the given workload; NullOpt if not found. + * \return The best record of the given workload; std::nullopt if not found. */ using FQueryTuningRecord = ffi::TypedFunction(const IRModule&, const Target&, const String&)>; @@ -339,7 +339,7 @@ class PyDatabaseNode : public DatabaseNode { * \param mod The IRModule to be searched for. * \param target The target to be searched for. * \param workload_name The name of the workload to be searched for. - * \return The schedule in the best schedule of the given workload; NullOpt if not found. + * \return The schedule in the best schedule of the given workload; std::nullopt if not found. */ using FQuerySchedule = ffi::TypedFunction(const IRModule&, const Target&, const String&)>; @@ -348,7 +348,7 @@ class PyDatabaseNode : public DatabaseNode { * \param mod The IRModule to be searched for. * \param target The target to be searched for. * \param workload_name The name of the workload to be searched for. - * \return The IRModule in the best IRModule of the given workload; NullOpt if not found. + * \return The IRModule in the best IRModule of the given workload; std::nullopt if not found. */ using FQueryIRModule = ffi::TypedFunction(const IRModule&, const Target&, const String&)>; diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index 239bf0dc5777..cfc1f29e8efb 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -19,10 +19,10 @@ #ifndef TVM_META_SCHEDULE_EXTRACTED_TASK_H_ #define TVM_META_SCHEDULE_EXTRACTED_TASK_H_ +#include +#include #include #include -#include -#include #include #include diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index 3e01faccaf28..e45cb4eab195 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -20,13 +20,13 @@ #ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ #define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ +#include +#include +#include #include #include -#include -#include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index 10356b6f5fb0..3a3d83cbf996 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -20,16 +20,16 @@ #ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ #define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ +#include +#include +#include #include #include #include #include #include #include -#include -#include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h index f7257b56d206..9bfc9d0da954 100644 --- a/include/tvm/meta_schedule/measure_candidate.h +++ b/include/tvm/meta_schedule/measure_candidate.h @@ -20,9 +20,9 @@ #ifndef TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_ #define TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_ +#include #include #include -#include #include #include diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index f8bf69180db5..0f8e446784f3 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -20,10 +20,10 @@ #ifndef TVM_META_SCHEDULE_MUTATOR_H_ #define TVM_META_SCHEDULE_MUTATOR_H_ +#include +#include #include -#include #include -#include #include #include #include diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 5a2b96caf81f..e8648f038e61 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -20,9 +20,9 @@ #ifndef TVM_META_SCHEDULE_POSTPROC_H_ #define TVM_META_SCHEDULE_POSTPROC_H_ +#include #include #include -#include #include namespace tvm { diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h index 91b7bfc45c09..6f8072b3f367 100644 --- a/include/tvm/meta_schedule/profiler.h +++ b/include/tvm/meta_schedule/profiler.h @@ -19,13 +19,13 @@ #ifndef TVM_META_SCHEDULE_PROFILER_H_ #define TVM_META_SCHEDULE_PROFILER_H_ +#include +#include +#include +#include #include #include -#include -#include -#include #include -#include #include #include diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index 0335f81cc16c..c8331a3a60e3 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -19,14 +19,14 @@ #ifndef TVM_META_SCHEDULE_RUNNER_H_ #define TVM_META_SCHEDULE_RUNNER_H_ +#include +#include +#include +#include #include #include #include -#include -#include -#include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 974254afc1b8..1a759c1b50fc 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -20,14 +20,14 @@ #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_ #define TVM_META_SCHEDULE_SCHEDULE_RULE_H_ +#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include #include -#include #include namespace tvm { @@ -57,8 +57,7 @@ class ScheduleRuleNode : public runtime::Object { * \param block The specific block to apply the schedule rule. * \return The list of schedules generated by applying the schedule rule. */ - virtual runtime::Array Apply(const tir::Schedule& sch, - const tir::BlockRV& block) = 0; + virtual Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) = 0; /*! * \brief Deep clone the schedule rule. @@ -141,13 +140,13 @@ class ScheduleRule : public runtime::ObjectRef { * - 'SSRSRS' on CPU * - 'SSSRRSRS' on GPU * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: - * - NullOpt on CPU + * - std::nullopt on CPU * - [blockIdx.x, vthread.x, threadIdx.x] on GPU - * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. - * NullOpt means disable vectorization - * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. - * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. + * \param max_innermost_factor The maximum size of the innermost factor. std::nullopt means no + * limit \param vector_load_lens The length of vector lane in vectorized cooperative fetching. + * std::nullopt means disable vectorization + * \param reuse_read Data reuse configuration for reading. std::nullopt means no reuse. + * \param reuse_write Data reuse configuration for writing. std::nullopt means no reuse. * \param filter_fn A function that can be passed to overwrite the default condition for applying * MultiLevelTiling to a block. Its signature must be (Schedule, BlockRV) -> bool. * This is useful if there is a need to apply MultiLevelTiling to an operation / block which is @@ -160,7 +159,7 @@ class ScheduleRule : public runtime::ObjectRef { Optional> vector_load_lens, // Optional> reuse_read, // Optional> reuse_write, - Optional filter_fn = NullOpt); + Optional filter_fn = std::nullopt); /*! * \brief Extension of MultiLevelTiling for auto-tensorization with a single intrinsic. @@ -170,13 +169,13 @@ class ScheduleRule : public runtime::ObjectRef { * - 'SSRSRS' on CPU * - 'SSSRRSRS' on GPU * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: - * - NullOpt on CPU + * - std::nullopt on CPU * - [blockIdx.x, vthread.x, threadIdx.x] on GPU - * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. - * NullOpt means disable vectorization - * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. - * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. + * \param max_innermost_factor The maximum size of the innermost factor. std::nullopt means no + * limit \param vector_load_lens The length of vector lane in vectorized cooperative fetching. + * std::nullopt means disable vectorization + * \param reuse_read Data reuse configuration for reading. std::nullopt means no reuse. + * \param reuse_write Data reuse configuration for writing. std::nullopt means no reuse. * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin( @@ -196,11 +195,11 @@ class ScheduleRule : public runtime::ObjectRef { * - 'SSSRRSRS' on GPU * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU - * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. - * NullOpt means disable vectorization - * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. - * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. + * \param max_innermost_factor The maximum size of the innermost factor. std::nullopt means no + * limit \param vector_load_lens The length of vector lane in vectorized cooperative fetching. + * std::nullopt means disable vectorization + * \param reuse_read Data reuse configuration for reading. std::nullopt means no reuse. + * \param reuse_write Data reuse configuration for writing. std::nullopt means no reuse. * \param use_software_pipeline Whether use the software pipeline. * \return The schedule rule created */ @@ -216,9 +215,9 @@ class ScheduleRule : public runtime::ObjectRef { * maximum vector length. * \param structure The tiling structure. 'SSRSRS' is recommended. * \param vector_length_in_bits The length of a vector register in bits. - * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. - * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. + * \param max_innermost_factor The maximum size of the innermost factor. std::nullopt means no + * limit \param reuse_read Data reuse configuration for reading. std::nullopt means no reuse. + * \param reuse_write Data reuse configuration for writing. std::nullopt means no reuse. * \return The schedule rule created */ TVM_DLL static ScheduleRule MultiLevelTilingWideVector( @@ -230,8 +229,8 @@ class ScheduleRule : public runtime::ObjectRef { * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the * uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable * parallelism. - * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \return The schedule rule created + * \param max_innermost_factor The maximum size of the innermost factor. std::nullopt means no + * limit \return The schedule rule created */ TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, // Optional max_innermost_factor); diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index ca7ee7ec8407..c0b4677f84b5 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -19,16 +19,16 @@ #ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ #define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ +#include +#include +#include #include #include #include #include #include #include -#include -#include #include -#include #include namespace tvm { diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index b626b3e7739f..4ba3c0b089fc 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -19,14 +19,14 @@ #ifndef TVM_META_SCHEDULE_SPACE_GENERATOR_H_ #define TVM_META_SCHEDULE_SPACE_GENERATOR_H_ +#include +#include #include #include #include #include #include -#include #include -#include #include #include diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index e75059116dc2..7bf36873b3ce 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -19,16 +19,16 @@ #ifndef TVM_META_SCHEDULE_TASK_SCHEDULER_H_ #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_ +#include +#include +#include #include #include #include #include #include #include -#include -#include #include -#include #include #include @@ -54,11 +54,11 @@ class TaskRecordNode : public runtime::Object { /*! \brief The latency of each run, in milliseconds. */ std::vector latency_ms = {}; /*! \brief The measure candidates. */ - Optional> measure_candidates = NullOpt; + Optional> measure_candidates = std::nullopt; /*! \brief The building results. */ - Optional> builder_results = NullOpt; + Optional> builder_results = std::nullopt; /*! \brief Packed functions to fetch the runner results asynchronously. */ - Optional> runner_futures = NullOpt; + Optional> runner_futures = std::nullopt; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ctx", &ctx); diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 9eacf499f405..9045d4188ac1 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -19,6 +19,11 @@ #ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ +#include +#include +#include +#include +#include #include #include #include @@ -26,12 +31,7 @@ #include #include #include -#include -#include -#include -#include #include -#include #include #include diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 4e075c7e56da..37dc710ac161 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -23,7 +23,7 @@ #ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ #define TVM_NODE_ATTR_REGISTRY_MAP_H_ -#include +#include #include #include diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 12598cb156c2..8a9e763fecbf 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -34,12 +34,12 @@ #ifndef TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_ +#include #include #include #include #include -#include -#include +#include #include #include @@ -60,7 +60,6 @@ using ffi::PackedArgs; using ffi::TypeIndex; using runtime::Downcast; using runtime::GetRef; -using runtime::make_object; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h index b2bfa4f27379..9c17487a1d64 100644 --- a/include/tvm/node/object_path.h +++ b/include/tvm/node/object_path.h @@ -26,8 +26,8 @@ #ifndef TVM_NODE_OBJECT_PATH_H_ #define TVM_NODE_OBJECT_PATH_H_ -#include -#include +#include +#include #include #include @@ -122,7 +122,7 @@ class ObjectPathNode : public Object { class ObjectPath : public ObjectRef { public: /*! \brief Create a path that represents the root object itself. */ - static ObjectPath Root(Optional name = NullOpt); + static ObjectPath Root(Optional name = std::nullopt); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); }; @@ -137,7 +137,7 @@ class RootPathNode final : public ObjectPathNode { public: Optional name; - explicit RootPathNode(Optional name = NullOpt); + explicit RootPathNode(Optional name = std::nullopt); static constexpr const char* _type_key = "RootPath"; TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 0938f2c56ad2..e56639570e37 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -23,15 +23,15 @@ #ifndef TVM_NODE_REFLECTION_H_ #define TVM_NODE_REFLECTION_H_ +#include +#include +#include #include #include -#include -#include +#include #include -#include #include #include -#include #include #include @@ -164,7 +164,7 @@ class ReflectionVTable { * \param kwargs The field arguments. * \return The created object. */ - TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map& kwargs); + TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map& kwargs); /*! * \brief Get an field object by the attr name. * \param self The pointer to the object. @@ -246,12 +246,12 @@ class ReflectionVTable::Registry { * struct StringObjTrait { * static constexpr const std::nullptr_t VisitAttrs = nullptr; * - * static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { - * hash_reduce->SHashReduceHashedValue(runtime::String::StableHashBytes(key->data, key->size)); + * static void SHashReduce(const StringObj* key, SHashReducer hash_reduce) { + * hash_reduce->SHashReduceHashedValue(String::StableHashBytes(key->data, key->size)); * } * - * static bool SEqualReduce(const runtime::StringObj* lhs, - * const runtime::StringObj* rhs, + * static bool SEqualReduce(const StringObj* lhs, + * const StringObj* rhs, * SEqualReducer equal) { * if (lhs == rhs) return true; * if (lhs->size != rhs->size) return false; @@ -260,7 +260,7 @@ class ReflectionVTable::Registry { * } * }; * - * TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait); + * TVM_REGISTER_REFLECTION_VTABLE(StringObj, StringObjTrait); * * \endcode * @@ -280,7 +280,7 @@ class ReflectionVTable::Registry { TVM_REGISTER_OBJECT_TYPE(TypeName); \ TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait) \ .set_creator([](const std::string&) -> ObjectPtr { \ - return ::tvm::runtime::make_object(); \ + return ::tvm::ffi::make_object(); \ }) // Implementation details diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index 9d2fa1023e92..721ae0932cdd 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -23,12 +23,13 @@ #ifndef TVM_NODE_SCRIPT_PRINTER_H_ #define TVM_NODE_SCRIPT_PRINTER_H_ +#include +#include +#include +#include #include #include #include -#include -#include -#include #include #include @@ -151,7 +152,7 @@ class PrinterConfigNode : public Object { class PrinterConfig : public ObjectRef { public: - explicit PrinterConfig(Map config_dict = Map()); + explicit PrinterConfig(Map config_dict = Map()); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrinterConfig, runtime::ObjectRef, PrinterConfigNode); @@ -168,7 +169,7 @@ class TVMScriptPrinter { }; #define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ - std::string Script(const Optional& config = NullOpt) const { \ + std::string Script(const Optional& config = std::nullopt) const { \ return TVMScriptPrinter::Script(GetRef(this), config.value_or(PrinterConfig())); \ } diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h index c99d0f7f73fb..5a8e098cfd6e 100644 --- a/include/tvm/node/serialization.h +++ b/include/tvm/node/serialization.h @@ -24,7 +24,7 @@ #ifndef TVM_NODE_SERIALIZATION_H_ #define TVM_NODE_SERIALIZATION_H_ -#include +#include #include #include diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 7f56fd6ca961..46087f0bda40 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -23,9 +23,9 @@ #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ #define TVM_NODE_STRUCTURAL_EQUAL_H_ +#include #include #include -#include #include #include @@ -253,24 +253,26 @@ class SEqualReducer { * \return the immediate check result. */ bool operator()(const double& lhs, const double& rhs, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; bool operator()(const int64_t& lhs, const int64_t& rhs, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; bool operator()(const uint64_t& lhs, const uint64_t& rhs, - Optional paths = NullOpt) const; - bool operator()(const int& lhs, const int& rhs, Optional paths = NullOpt) const; - bool operator()(const bool& lhs, const bool& rhs, Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; + bool operator()(const int& lhs, const int& rhs, + Optional paths = std::nullopt) const; + bool operator()(const bool& lhs, const bool& rhs, + Optional paths = std::nullopt) const; bool operator()(const std::string& lhs, const std::string& rhs, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; bool operator()(const DataType& lhs, const DataType& rhs, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; bool operator()(const Optional& lhs, const Optional& rhs, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; bool operator()(const Optional& lhs, const Optional& rhs, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; template ::value>::type> bool operator()(const ENum& lhs, const ENum& rhs, - Optional paths = NullOpt) const { + Optional paths = std::nullopt) const { using Underlying = typename std::underlying_type::type; static_assert(std::is_same::value, "Enum must have `int` as the underlying type"); @@ -327,7 +329,7 @@ class SEqualReducer { * \return the immediate check result. */ bool AnyEqual(const ffi::Any& lhs, const ffi::Any& rhs, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; /*! * \brief Reduce condition to comparison of two definitions, @@ -355,7 +357,11 @@ class SEqualReducer { // depth as array comparison is pretty common. if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) { - if (!(operator()(lhs[i], rhs[i]))) return false; + if constexpr (std::is_same_v) { + if (!(AnyEqual(lhs[i], rhs[i]))) return false; + } else { + if (!(operator()(lhs[i], rhs[i]))) return false; + } } return true; } @@ -401,7 +407,7 @@ class SEqualReducer { private: bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address, - Optional paths = NullOpt) const; + Optional paths = std::nullopt) const; bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const ObjectPathPair* paths) const; @@ -413,7 +419,7 @@ class SEqualReducer { template static bool CompareAttributeValues(const T& lhs, const T& rhs, const PathTracingData* tracing_data, - Optional paths = NullOpt); + Optional paths = std::nullopt); /*! \brief Internal class pointer. */ Handler* handler_ = nullptr; diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 85a7d0eee356..267eb1b66eeb 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -536,7 +536,7 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); * the caller can pass the function's name so recursive calls * can be ignored in the check (must be a Var or GlobalVar). * \return The impure expression, if one exists within the given - * expression. Otherwise, NullOpt. + * expression. Otherwise, std::nullopt. * \note Relies on StructInfo annotations, so ensure that the module has been normalized first. * Also, an impure call in a *nested* function does *not* mean that the outer expression contains * an impure call--it only does if the nested function is *later called*. diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 3a5e3951af44..f8a6ddfe0aa2 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -229,6 +229,15 @@ struct ScatterNDAttrs : public tvm::AttrsNode { } }; // struct ScatterNDAttrs +/*! \brief Attributes used in slice_scatter operator */ +struct SliceScatterAttrs : public tvm::AttrsNode { + int axis; + + TVM_DECLARE_ATTRS(SliceScatterAttrs, "relax.attrs.SliceScatterAttrs") { + TVM_ATTR_FIELD(axis).set_default(0).describe("the dimension to insert the slice into "); + } +}; // struct SliceScatterAttrs + /*! \brief Attributes used in one_hot operator */ struct OneHotAttrs : public tvm::AttrsNode { int depth; diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 5efe91a5e437..c33d99b5f91f 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -126,7 +126,7 @@ class BlockBuilderNode : public Object { * \brief Lookup the binding value that var binds to in the current emitted sequences. * \param var The input var. * \return The Expr bound to the input \p var. - * \note For function parameters, this function returns NullOpt. + * \note For function parameters, this function returns std::nullopt. */ virtual Optional LookupBinding(const Var& var) = 0; diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index fd2fa72a2410..80359135c200 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -24,9 +24,9 @@ #ifndef TVM_RELAX_DATAFLOW_MATCHER_H_ #define TVM_RELAX_DATAFLOW_MATCHER_H_ +#include +#include #include -#include -#include #include @@ -44,11 +44,11 @@ namespace relax { * \return true if matched * \return false if unmatched */ -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings = NullOpt); +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings = std::nullopt); /* \brief Similar to above, but return pairs of a matching pattern and an expression. */ -Optional> ExtractMatchedExpr( - DFPattern pattern, Expr expr, Optional> bindings = NullOpt); +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings = std::nullopt); /** * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 36fac906c4de..987fe16207dc 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -24,11 +24,11 @@ #ifndef TVM_RELAX_DATAFLOW_PATTERN_H_ #define TVM_RELAX_DATAFLOW_PATTERN_H_ +#include +#include #include #include #include -#include -#include #include #include @@ -188,7 +188,7 @@ class DFConstraintNode : public Object { * \param match_state A function that can be called to check the * current state of the match. The function takes as argument a * pattern on which the constraint depends, and returns the relax - * variable matched by that pattern, or NullOpt if the pattern + * variable matched by that pattern, or std::nullopt if the pattern * has not yet been matched. * * \return A tuple of `PrimExpr` and `bool`. The first element is a @@ -946,11 +946,11 @@ ExprPattern IsExpr(const Expr& expr); ExprPattern IsOp(const String& op_name); /*! \brief Syntatic Sugar for call_tir (return a tensor) */ // Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo -CallPattern IsCallTIR(const String& name, Optional args = NullOpt); +CallPattern IsCallTIR(const String& name, Optional args = std::nullopt); /*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */ CallPattern IsCallTIR(const String& name, TuplePattern var_args); /*! \brief Syntatic Sugar for call_dps_packed (return a tensor) */ -CallPattern IsCallDPSPacked(const String& name, Optional args = NullOpt); +CallPattern IsCallDPSPacked(const String& name, Optional args = std::nullopt); /*! \brief Syntatic Sugar for call_dps_packed (return a tuple of tensor) */ CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args); /*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */ diff --git a/include/tvm/relax/distributed/global_info.h b/include/tvm/relax/distributed/global_info.h index 5ab5a0263591..67aeccd2970a 100644 --- a/include/tvm/relax/distributed/global_info.h +++ b/include/tvm/relax/distributed/global_info.h @@ -37,7 +37,7 @@ namespace distributed { class DeviceMeshNode : public GlobalInfoNode { public: /*! \brief logical shape of the mesh*/ - ShapeTuple shape; + ffi::Shape shape; /*! \brief device ids in the mesh*/ Array device_ids; @@ -80,8 +80,8 @@ class DeviceMeshNode : public GlobalInfoNode { */ class DeviceMesh : public GlobalInfo { public: - TVM_DLL DeviceMesh(ShapeTuple shape, Array device_ids); - TVM_DLL DeviceMesh(ShapeTuple shape, Range device_range); + TVM_DLL DeviceMesh(ffi::Shape shape, Array device_ids); + TVM_DLL DeviceMesh(ffi::Shape shape, Range device_range); TVM_DEFINE_OBJECT_REF_METHODS(DeviceMesh, GlobalInfo, DeviceMeshNode); }; diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h index 2cee3bca631b..81d6d4eb379e 100644 --- a/include/tvm/relax/exec_builder.h +++ b/include/tvm/relax/exec_builder.h @@ -23,11 +23,11 @@ #ifndef TVM_RELAX_EXEC_BUILDER_H_ #define TVM_RELAX_EXEC_BUILDER_H_ +#include #include #include #include #include -#include #include #include diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 330ff7e8dab0..08cbd1de538e 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -19,13 +19,13 @@ #ifndef TVM_RELAX_EXPR_H_ #define TVM_RELAX_EXPR_H_ +#include +#include #include #include #include #include #include -#include -#include #include #include #include @@ -520,7 +520,7 @@ class Constant : public LeafExpr { * \param span The source span of the expression. */ TVM_DLL explicit Constant(runtime::NDArray data, - Optional struct_info_annotation = NullOpt, + Optional struct_info_annotation = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, LeafExpr, ConstantNode); @@ -1006,7 +1006,7 @@ class Function : public BaseFunc { * SeqExpr. * * \param ret_struct_info The StructInfo returned by the function. - * If NullOpt, will be inferred from the StructInfo of the + * If std::nullopt, will be inferred from the StructInfo of the * function's body. * * \param is_pure The purity of the function. diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 96b5b20d1ef8..c77383bdbf3d 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -421,7 +421,7 @@ class ExprMutator : public ExprMutatorBase { public: using ExprMutatorBase::VisitExpr_; - ExprMutator(Optional mod = NullOpt) { builder_ = BlockBuilder::Create(mod); } + ExprMutator(Optional mod = std::nullopt) { builder_ = BlockBuilder::Create(mod); } Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const DataflowVarNode* op) override; @@ -502,7 +502,7 @@ class ExprMutator : public ExprMutatorBase { * * \note The body_expr must be an SeqExpr in the normal form. */ - Expr VisitWithNewScope(const Expr& body_expr, Optional> params = NullOpt); + Expr VisitWithNewScope(const Expr& body_expr, Optional> params = std::nullopt); /*! * \brief Rewrite the expr with a new scope, used in the branches of If. @@ -524,7 +524,7 @@ class ExprMutator : public ExprMutatorBase { * \brief Look up the value bound to a variable. * \param var The var to be looked up. * \return The value bound to the input \p var. - * \note For function parameters, this function returns NullOpt. + * \note For function parameters, this function returns std::nullopt. */ Optional LookupBinding(const Var& var); diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 0ddcb271ab83..af2db582d604 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -28,10 +28,10 @@ #ifndef TVM_RELAX_NESTED_MSG_H_ #define TVM_RELAX_NESTED_MSG_H_ +#include +#include #include #include -#include -#include #include #include @@ -46,7 +46,7 @@ namespace relax { * message state in pass analysis so we can robustly handle message * passing with the presence of nested tuple types. * - * Under the hood, NestedMsg[T] = Union[T, NullOpt, Array[NestedMsg[T]]]. + * Under the hood, NestedMsg[T] = Union[T, std::nullopt, Array[NestedMsg[T]]]. * Each nested message corresponds to the same nesting structure as * the nested tuple types when we encounter them in analysis. * @@ -176,7 +176,7 @@ class NestedMsg : public ObjectRef { bool IsNull() const { return data_ == nullptr; } /*! \return Whether the nested message is nested */ - bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } + bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } /*! * \return The underlying leaf value. @@ -359,7 +359,7 @@ NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { template TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { if (msg.IsNull()) { - return fmapleaf(NullOpt); + return fmapleaf(std::nullopt); } else if (msg.IsLeaf()) { return fmapleaf(msg.LeafValue()); } else { diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index 7c13fcc531a3..5c4e646351d3 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -164,7 +164,7 @@ class TensorStructInfoNode : public StructInfoNode { public: /*! * \brief optionally store the shape expression of the tensor. - * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var. + * \note shape must be normalized: it can only be std::nullopt or ShapeExpr or Var. */ Optional shape; /*! \brief The virtual device, indicates where the tensor @@ -231,7 +231,7 @@ class TensorStructInfo : public StructInfo { * * \note shape must already be normalized. */ - TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Optional vdevice = NullOpt, + TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Optional vdevice = std::nullopt, Span span = Span()); /*! @@ -241,7 +241,7 @@ class TensorStructInfo : public StructInfo { * \param vdevice The virtual device. * \param span The span of the AST. */ - TVM_DLL TensorStructInfo(DataType dtype, int ndim, Optional vdevice = NullOpt, + TVM_DLL TensorStructInfo(DataType dtype, int ndim, Optional vdevice = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode); @@ -304,7 +304,7 @@ class FuncStructInfoNode : public StructInfoNode { public: /*! * \brief The parameter struct info of the function. - * \note When params is NullOpt means the function can take arbitrary number of arguments. + * \note When params is std::nullopt means the function can take arbitrary number of arguments. * We define such functions as Opaque function. */ Optional> params; @@ -314,7 +314,7 @@ class FuncStructInfoNode : public StructInfoNode { StructInfo ret; /*! * \brief Derivation function of opaque functions that may take any number of parameters. - * \note When derive_func is not empty, then params should be NullOpt, + * \note When derive_func is not empty, then params should be std::nullopt, * ret should be ObjectStructInfo() */ Optional derive_func; @@ -418,7 +418,7 @@ inline Optional MatchStructInfo(const Expr& expr) { if (const TNode* ptr = expr->struct_info_.as()) { return GetRef(ptr); } else { - return NullOpt; + return std::nullopt; } } diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 98aa2673c23b..b8ff0fa59dfe 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -206,13 +206,13 @@ TVM_DLL Pass BindParams(String func_name, Map params); * symbolic variable in each function where it is used. * * \param func_name The name of the function in which to bind shape - * values. If NullOpt, all functions in the module will be + * values. If std::nullopt, all functions in the module will be * updated. * * \return The Pass. */ TVM_DLL Pass BindSymbolicVars(Map binding_map, - Optional func_name = NullOpt); + Optional func_name = std::nullopt); /*! * \brief Fold constant expressions within dataflow blocks. @@ -378,7 +378,7 @@ class FusionPatternNode : public Object { /*! * \brief The function to determine whether the match result is accepted. This can be - * NullOpt if check function is not necessary for this pattern. + * std::nullopt if check function is not necessary for this pattern. * * It should have signature * bool(const PatternCheckContext& context) @@ -411,7 +411,7 @@ class FusionPattern : public ObjectRef { Optional check, Optional attrs_getter); FusionPattern(String name, DFPattern pattern) - : FusionPattern(name, pattern, {}, NullOpt, NullOpt) {} + : FusionPattern(name, pattern, {}, std::nullopt, std::nullopt) {} TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FusionPattern, ObjectRef, FusionPatternNode); }; @@ -497,7 +497,7 @@ class PatternCheckContext : public ObjectRef { * * \note ConvertToDataflow may need to be called first to provide dataflow blocks. */ -TVM_DLL Pass Gradient(String func_name, Optional> require_grads = NullOpt, +TVM_DLL Pass Gradient(String func_name, Optional> require_grads = std::nullopt, int target_index = 0); /*! @@ -548,7 +548,7 @@ TVM_DLL Pass FuseTIR(); * \return The Pass. */ TVM_DLL Pass RunCodegen(Optional>> target_options, - Array entry_functions); + Array entry_functions); /*! * \brief Decompose composite operators during inference. For example, The result of batch norm (a @@ -622,7 +622,7 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); * * \return The Pass. */ -TVM_DLL Pass DeadCodeElimination(Array entry_functions = {}); +TVM_DLL Pass DeadCodeElimination(Array entry_functions = {}); /*! * \brief Pass that changes calls to operators that can be done in-place @@ -646,7 +646,7 @@ TVM_DLL Pass DataflowUseInplaceCalls(); * \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first. */ TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype, - Optional> fp16_input_names = NullOpt); + Optional> fp16_input_names = std::nullopt); /*! * \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 7e4149fe5548..bd75197bfe21 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -24,10 +24,10 @@ #ifndef TVM_RELAX_TYPE_H_ #define TVM_RELAX_TYPE_H_ +#include #include #include #include -#include #include #include diff --git a/golang/src/gotvm.go b/include/tvm/runtime/base.h similarity index 50% rename from golang/src/gotvm.go rename to include/tvm/runtime/base.h index 072d9cce4619..c704decb63e9 100644 --- a/golang/src/gotvm.go +++ b/include/tvm/runtime/base.h @@ -17,26 +17,43 @@ * under the License. */ -/*! - * \brief gotvm package - * \file gotvm.go +/* + * \file tvm/runtime/base.h + * \brief base macros */ +#ifndef TVM_RUNTIME_BASE_H_ +#define TVM_RUNTIME_BASE_H_ + +// TVM runtime fully relies on TVM FFI C API +// we will avoid defining extra C APIs here +#include +// TVM version +#define TVM_VERSION "0.21.dev0" -// Package gotvm is TVM runtime interface definition for golang. -// -// Application need to import this package to access the c_runtime_api exposed by TVM. -package gotvm +// define extra macros for TVM DLL exprt +#ifdef __EMSCRIPTEN__ +#include +#define TVM_DLL EMSCRIPTEN_KEEPALIVE +#endif -//#include "gotvm.h" -import "C" +// helper macro to suppress unused warning +#if defined(__GNUC__) +#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define TVM_ATTRIBUTE_UNUSED +#endif -// DLPackVersion is the dlpack version of tvm runtime. -var DLPackVersion = int(C.DLPACK_VERSION) -// TVMVersion is the TVM runtime version. -var TVMVersion = getTVMVersion() +#ifndef TVM_DLL +#ifdef _WIN32 +#ifdef TVM_EXPORTS +#define TVM_DLL __declspec(dllexport) +#else +#define TVM_DLL __declspec(dllimport) +#endif +#else +#define TVM_DLL __attribute__((visibility("default"))) +#endif +#endif -func getTVMVersion() (retStr string) { - retStr = C.GoString(C._TVM_VERSION()) - return -} +#endif // TVM_RUNTIME_BASE_H_ diff --git a/include/tvm/runtime/builtin_fp16.h b/include/tvm/runtime/builtin_fp16.h index 5b54583da4ff..3ea670017d3d 100644 --- a/include/tvm/runtime/builtin_fp16.h +++ b/include/tvm/runtime/builtin_fp16.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_BUILTIN_FP16_H_ #define TVM_RUNTIME_BUILTIN_FP16_H_ -#include +#include #include diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index eb8d7270b137..0d84b55fe318 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -28,28 +28,12 @@ #ifndef TVM_RUNTIME_C_BACKEND_API_H_ #define TVM_RUNTIME_C_BACKEND_API_H_ -#include +#include #ifdef __cplusplus extern "C" { #endif -/*! - * \brief Signature for backend functions exported as DLL. - * - * \param args The arguments - * \param type_codes The type codes of the arguments - * \param num_args Number of arguments. - * \param out_ret_value The output value of the return value. - * \param out_ret_tcode The output type code of the return value. - * \param resource_handle Pointer to associated resource. - * - * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. - */ -typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args, - TVMValue* out_ret_value, int* out_ret_tcode, - void* resource_handle); - /*! * \brief Backend function for modules to get function * from its environment mod_node (its imports and global function). @@ -60,7 +44,8 @@ typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_ar * \param out The result function. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out); +TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, + TVMFFIObjectHandle* out); /*! * \brief Backend function to register system-wide library symbol. @@ -100,19 +85,6 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t */ TVM_DLL int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr); -/*! - * \brief Backend function to register execution environment(e.g. python) - * specific C APIs. - * - * \note We only register the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). In most cases we should use the ffi::Function FFI. - * - * \param name The name of the symbol - * \param ptr The symbol address. - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_DLL int TVMBackendRegisterEnvCAPI(const char* name, void* ptr); - /*! * \brief Environment for TVM parallel task. */ diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h deleted file mode 100644 index b802dbc22839..000000000000 --- a/include/tvm/runtime/c_runtime_api.h +++ /dev/null @@ -1,732 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/runtime/c_runtime_api.h - * \brief TVM runtime library. - * - * The philosophy of TVM project is to customize the compilation - * stage to generate code that can used by other projects transparently. - * So this is a minimum runtime code gluing, and some limited - * memory management code to enable quick testing. - * - * The runtime API is independent from TVM compilation stack and can - * be linked via libtvm_runtime. - * - * The common flow is: - * - Use TVMFuncListGlobalNames to get global function name - * - Use TVMFuncCall to call these functions. - * - * Possible return values of the API functions: - * * 0: success - * * -1: the error can be retrieved through TVMGetLastError. - * * -2: a frontend error occurred and recorded in the frontend. - */ -#ifndef TVM_RUNTIME_C_RUNTIME_API_H_ -#define TVM_RUNTIME_C_RUNTIME_API_H_ - -// Macros to do weak linking -#ifdef _MSC_VER -#define TVM_WEAK __declspec(selectany) -#else -#define TVM_WEAK __attribute__((weak)) -#endif - -#ifdef __EMSCRIPTEN__ -#include -#define TVM_DLL EMSCRIPTEN_KEEPALIVE -#endif - -// helper macro to suppress unused warning -#if defined(__GNUC__) -#define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define TVM_ATTRIBUTE_UNUSED -#endif - -#ifndef TVM_DLL -#ifdef _WIN32 -#ifdef TVM_EXPORTS -#define TVM_DLL __declspec(dllexport) -#else -#define TVM_DLL __declspec(dllimport) -#endif -#else -#define TVM_DLL __attribute__((visibility("default"))) -#endif -#endif - -// TVM version -#define TVM_VERSION "0.21.dev0" - -// TVM Runtime is DLPack compatible. -#include - -#ifdef __cplusplus -extern "C" { -#endif -#include -#include -#include - -/*! \brief type of array index. */ -typedef int64_t tvm_index_t; - -/*! \brief Extension device types in TVM - * - * Additional enumerators to supplement those provided by - * DLPack's `DLDeviceType` enumeration. - * - * MAINTAINERS NOTE #1: We need to ensure that the two devices - * are identified by the same integer. - * Currently this requires manual verification. - * Discussed here: https://github.com/dmlc/dlpack/issues/111 - * As of DLPack v0.7, the highest-valued enumerator in - * `DLDeviceType` is kDLHexagon = 16. - * - * MAINTAINERS NOTE #2: As of DLPack v0.7, the definition for - * `DLDeviceType` specifies an underlying storage type of - * `int32_t`. That guarantees a variable of type - * `DLDeviceType` is capable of holding any integers provided - * by *either* of these enumerations. - * - * However, the `int32_t` specification only applies when the - * header file is compiled as C++, and this header file is also - * meant to work as C code. So the unspecified storage type - * could be a latent bug when compiled as C. - */ -#ifdef __cplusplus -typedef enum : int32_t { -#else -typedef enum { -#endif - // To help avoid accidental conflicts between `DLDeviceType` - // and this enumeration, start numbering the new enumerators - // a little higher than (currently) seems necessary. - TVMDeviceExtType_End = 36, // sentinel value -} TVMDeviceExtType; - -#ifdef __cplusplus -// Some other parts of TVM hardcode the integer identifier for -// some DLPack / TVM devices, rather then using the symbolic -// enumerator. E.g., `2` rather than `kDLCUDA`. -// These asserts should alert us when that mapping breaks. -#define TVM_HARCODED_INTEGER_CHANGED_MSG \ - "Change in compile-time integer. Make sure hardcoded uses of this integer throughout TVM are " \ - "updated." -static_assert(kDLCPU == 1, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLCUDA == 2, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLCUDAHost == 3, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLOpenCL == 4, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLVulkan == 7, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLMetal == 8, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLVPI == 9, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLROCM == 10, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLROCMHost == 11, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLExtDev == 12, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLCUDAManaged == 13, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLOneAPI == 14, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLWebGPU == 15, TVM_HARCODED_INTEGER_CHANGED_MSG); -static_assert(kDLHexagon == 16, TVM_HARCODED_INTEGER_CHANGED_MSG); - -#undef TVM_HARCODED_INTEGER_CHANGED_MSG -#endif - -/*! - * \brief The type code in used and only used in TVM FFI for argument passing. - * - * DLPack consistency: - * 1) kTVMArgInt is compatible with kDLInt - * 2) kTVMArgFloat is compatible with kDLFloat - * 3) kDLUInt is not in ArgTypeCode, but has a spared slot - * - * Downstream consistency: - * The kDLInt, kDLUInt, kDLFloat are kept consistent with the original ArgType code - * - * It is only used in argument passing, and should not be confused with - * DataType::TypeCode, which is DLPack-compatible. - * - * \sa tvm::runtime::DataType::TypeCode - */ -typedef enum { - kTVMArgInt = kDLInt, - kTVMArgFloat = kDLFloat, - kTVMOpaqueHandle = 3U, - kTVMNullptr = 4U, - kTVMDataType = 5U, - kDLDevice = 6U, - kTVMDLTensorHandle = 7U, - kTVMObjectHandle = 8U, - kTVMModuleHandle = 9U, - kTVMPackedFuncHandle = 10U, - kTVMStr = 11U, - kTVMBytes = 12U, - kTVMNDArrayHandle = 13U, - kTVMObjectRValueRefArg = 14U, - kTVMArgBool = 15U, - // Extension codes for other frameworks to integrate TVM ffi::Function. - // To make sure each framework's id do not conflict, use first and - // last sections to mark ranges. - // Open an issue at the repo if you need a section of code. - kTVMExtBegin = 16U, - kTVMNNVMFirst = 16U, - kTVMNNVMLast = 20U, - // The following section of code is used for non-reserved types. - kTVMExtReserveEnd = 64U, - kTVMExtEnd = 128U, -} TVMArgTypeCode; - -/*! \brief the array handle */ -typedef DLTensor* TVMArrayHandle; - -/*! - * \brief Union type of values - * being passed through API and function calls. - */ -typedef union { - int64_t v_int64; - double v_float64; - void* v_handle; - const char* v_str; - DLDataType v_type; - DLDevice v_device; -} TVMValue; - -/*! - * \brief Byte array type used to pass in byte array - * When kTVMBytes is used as data type. - */ -typedef struct { - const char* data; - size_t size; -} TVMByteArray; - -/*! \brief Handle to TVM runtime modules. */ -typedef void* TVMModuleHandle; -/*! \brief Handle to packed function handle. */ -typedef void* TVMFunctionHandle; -/*! \brief Handle to hold return value. */ -typedef void* TVMRetValueHandle; -/*! - * \brief The stream that is specific to device - * can be NULL, which indicates the default one. - */ -typedef void* TVMStreamHandle; -/*! \brief Handle to Object. */ -typedef void* TVMObjectHandle; - -/*! - * \brief Used for implementing C API function. - * Set last error message before return. - * \param msg The error message to be set. - */ -TVM_DLL void TVMAPISetLastError(const char* msg); - -/*! - * \brief Used for implementing C API function. - * Set last exception before return. - * \param py_object The python exception to be set - */ -TVM_DLL void TVMAPISetLastPythonError(void* py_object); - -/*! \brief Return the previous python error, if any. - * - * Used to propagate the original Python exception to a python - * try/except, when there are C++ stack frames between the location thro - * - * \return The previous argument passed during the most recent call to - * TVMAPISetLastPythonError. If TVMAPISetLastPythonError has not - * been called, or if TVMDropLastPythonError has been called since - * the most recent to TVMAPISetLastPythonError, returns nullptr. - */ -TVM_DLL void* TVMGetLastPythonError(); - -/*! - * \brief return str message of the last error - * all function in this file will return 0 when success - * and nonzero when an error occurred, - * TVMGetLastError can be called to retrieve the error - * - * this function is threadsafe and can be called by different thread - * \return error info - */ -TVM_DLL const char* TVMGetLastError(void); - -/*! - * \brief Return the backtrace of the most recent error - * - * Returns the backtrace of the most recent error, if an error exists, - * and the error contains a backtrace. If no error exists or the - * error does not contain a backtrace, returns nullptr. - * - * \return The backtrace of the most recent error - */ -TVM_DLL const char* TVMGetLastBacktrace(); - -/*! - * \brief Remove the propagated python error, if any - * - * Removes the TVM-held reference to a thrown python exception object. - * Because these objects contain references to the stack frames from - * which the exception was thrown, maintaining a reference to an - * exception object prevents any local python variables from being - * garbage-collected. After retrieving the object using - * TVMGetLastPythonError, the Python FFI interface uses this method to - * clear the TVM-held reference to the exception, to allow garbage - * collection to continue. - */ -TVM_DLL void TVMDropLastPythonError(); - -/*! \brief Re-throw the most recent error. - * - * If an error was previously set using TVMAPISetLastError or - * TVMAPISetLastPythonError, re-throw the error. This is similar to - * `LOG(FATAL) << TVMGetLastError()`, but includes handling to - * propagate a python exception across C++ stack frames, or to append - * a stack trace to an error message. - */ -TVM_DLL void TVMThrowLastError(); - -/*! - * \brief Load module from file. - * \param file_name The file name to load the module from. - * \param format The format of the module. - * \param out The result module - * - * \return 0 when success, nonzero when failure happens - * \note The resulting module do not contain import relation. - * It can be reconstructed by TVMModImport. - */ -TVM_DLL int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out); - -/*! - * \brief Add dep to mod's dependency. - * This allows functions in this module to use modules. - * - * \param mod The module handle. - * \param dep The dependent module to be imported. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep); - -/*! - * \brief Get function from the module. - * \param mod The module handle. - * \param func_name The name of the function. - * \param query_imports Whether to query imported modules - * \param out The result function, can be NULL if it is not available. - * \return 0 when no error is thrown, nonzero when failure happens - */ -TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, - TVMFunctionHandle* out); - -/*! - * \brief Free the Module - * \param mod The module to be freed. - * - * \note This may not free up the module's resources. - * If there is active TVMFunctionHandle uses the module - * Or if this module is imported by another active module. - * - * The all functions remains valid until TVMFuncFree is called. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMModFree(TVMModuleHandle mod); - -/*! - * \brief Free the function when it is no longer needed. - * \param func The function handle - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMFuncFree(TVMFunctionHandle func); - -/*! - * \brief Call a Packed TVM Function. - * - * \param func node handle of the function. - * \param arg_values The arguments - * \param type_codes The type codes of the arguments - * \param num_args Number of arguments. - * - * \param ret_val The return value. - * \param ret_type_code the type code of return value. - * - * \return 0 when success, nonzero when failure happens - * \note TVM calls always exchanges with type bits=64, lanes=1 - * - * \note API calls always exchanges with type bits=64, lanes=1 - * If API call returns container handles (e.g. FunctionHandle) - * these handles should be managed by the front-end. - * The front-end need to call free function (e.g. TVMFuncFree) - * to free these handles. - */ -TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args, - TVMValue* ret_val, int* ret_type_code); - -/*! - * \brief Set the return value of TVMPackedCFunc. - * - * This function is called by TVMPackedCFunc to set the return value. - * When this function is not called, the function returns null by default. - * - * \param ret The return value handle, pass by ret in TVMPackedCFunc - * \param value The value to be returned. - * \param type_code The type of the value to be returned. - * \param num_ret Number of return values, for now only 1 is supported. - */ -TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret); - -/*! - * \brief Inplace translate callback argument value to return value. - * This is only needed for non-POD arguments. - * - * \param value The value to be translated. - * \param code The type code to be translated. - * \note This function will do a shallow copy when necessary. - * - * \return 0 when success, nonzero when failure happens. - */ -TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code); - -/*! - * \brief C type of packed function. - * - * \param args The arguments - * \param type_codes The type codes of the arguments - * \param num_args Number of arguments. - * \param ret The return value handle. - * \param resource_handle The handle additional resouce handle from front-end. - * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. - * \sa TVMCFuncSetReturn - */ -typedef int (*TVMPackedCFunc)(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, - void* resource_handle); - -/*! - * \brief C callback to free the resource handle in C packed function. - * \param resource_handle The handle additional resouce handle from front-end. - */ -typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle); - -/*! - * \brief Signature for extension function declarer. - * - * TVM call this function to get the extension functions - * The declarer will call register_func to register function and their name. - * - * \param register_func_handle The register function - * \return 0 if success, -1 if failure happens - */ -typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); - -/*! - * \brief Wrap a TVMPackedCFunc to become a FunctionHandle. - * - * The resource_handle will be managed by TVM API, until the function is no longer used. - * - * \param func The packed C function. - * \param resource_handle The resource handle from front-end, can be NULL. - * \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL - * \param out the result function handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, - TVMPackedCFuncFinalizer fin, TVMFunctionHandle* out); - -/*! - * \brief Register the function to runtime's global table. - * - * The registered function then can be pulled by the backend by the name. - * - * \param name The name of the function. - * \param f The function to be registered. - * \param override Whether allow override already registered function. - */ -TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override); - -/*! - * \brief Get a global function. - * - * \param name The name of the function. - * \param out the result function pointer, NULL if it does not exist. - * - * \note The function handle of global function is managed by TVM runtime, - * So TVMFuncFree is should not be called when it get deleted. - */ -TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); - -/*! - * \brief List all the globally registered function name - * \param out_size The number of functions - * \param out_array The array of function names. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array); - -/*! - * \brief Remove a global function. - * \param name The name of the function. - */ -TVM_DLL int TVMFuncRemoveGlobal(const char* name); - -// Array related apis for quick proptyping -/*! - * \brief Allocate a nd-array's memory, - * including space of shape, of given spec. - * - * \param shape The shape of the array, the data content will be copied to out - * \param ndim The number of dimension of the array. - * \param dtype_code The type code of the dtype - * \param dtype_bits The number of bits of dtype - * \param dtype_lanes The number of lanes in the dtype. - * \param device_type The device type. - * \param device_id The device id. - * \param out The output handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, - int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out); - -/*! - * \brief Free the TVM Array. - * \param handle The array handle to be freed. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayFree(TVMArrayHandle handle); - -/*! - * \brief Copy array data from CPU byte array. - * \param handle The array handle. - * \param data the data pointer - * \param nbytes The number of bytes to copy. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes); - -/*! - * \brief Copy array data to CPU byte array. - * \param handle The array handle. - * \param data the data pointer - * \param nbytes The number of bytes to copy. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes); - -/*! - * \brief Copy the array, both from and to must be valid during the copy. - * \param from The array to be copied from. - * \param to The target space. - * \param stream The stream where the copy happens, can be NULL. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream); - -/*! - * \brief Produce an array from the DLManagedTensor that shares data memory - * with the DLManagedTensor. - * \param from The source DLManagedTensor. - * \param out The output array handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out); - -/*! - * \brief Produce a DLMangedTensor from the array that shares data memory with - * the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out); - -/*! - * \brief Delete (free) a DLManagedTensor's data. - * \param dltensor Pointer to the DLManagedTensor. - */ -TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor); - -/*! - * \brief Create a new runtime stream. - * - * \param device_type The device type. - * \param device_id The device id. - * \param out The new stream handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out); - -/*! - * \brief Free a created stream handle. - * - * \param device_type The device type. - * \param device_id The device id. - * \param stream The stream to be freed. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream); - -/*! - * \brief Set the runtime stream of current thread to be stream. - * The subsequent calls to the same device_type - * will use the setted stream handle. - * The specific type of stream is runtime device dependent. - * - * \param device_type The device type. - * \param device_id The device id. - * \param handle The stream handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle); - -/*! - * \brief Wait until all computations on stream completes. - * - * \param device_type The device type. - * \param device_id The device id. - * \param stream The stream to be synchronized. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); - -/*! - * \brief Synchronize two streams of execution. - * - * \param device_type The device type. - * \param device_id The device id. - * \param src The source stream to synchronize. - * \param dst The destination stream to synchronize. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, - TVMStreamHandle dst); - -/*! - * \brief Get the type_index from an object. - * - * \param obj The object handle. - * \param out_tindex the output type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_tindex the corresponding type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); - -/*! - * \brief Convert type index to type key. - * \param tindex The type index. - * \param out_type_key The output type key. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); - -/*! - * \brief Increase the reference count of an object. - * - * \param obj The object handle. - * \note Internally we increase the reference counter of the object. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectRetain(TVMObjectHandle obj); - -/*! - * \brief Free the object. - * - * \param obj The object handle. - * \note Internally we decrease the reference counter of the object. - * The object will be freed when every reference to the object are removed. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMObjectFree(TVMObjectHandle obj); - -/*! - * \brief Free a TVMByteArray returned from TVMFuncCall, and associated memory. - * \param arr The TVMByteArray instance. - * \return 0 on success, -1 on failure. - */ -TVM_DLL int TVMByteArrayFree(TVMByteArray* arr); - -/*! - * \brief Allocate a data space on device. - * \param dev The device to perform operation. - * \param nbytes The number of bytes in memory. - * \param alignment The alignment of the memory. - * \param type_hint The type of elements. Only needed by certain backends such - * as nbytes & alignment are sufficient for most backends. - * \param out_data The allocated device pointer. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, - DLDataType type_hint, void** out_data); - -/*! - * \brief Allocate a data space on device with special memory scope. - * \note The memory could use a special multi-dimensional memory layout. - * That is why we pass shape and dtype instead of raw number of bytes. - * \param dev The device to perform operation. - * \param ndim The number of dimension of the tensor. - * \param shape The shape of the tensor. - * \param dtype The type of elements. - * \param mem_scope The memory scope of the tensor, - * can be nullptr, which indicate the default global DRAM - * \param out_data The allocated device pointer. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shape, - DLDataType dtype, const char* mem_scope, - void** out_data); - -/*! - * \brief Free a data space on device. - * \param dev The device to perform operation. - * \param ptr The data space. - * \return 0 when success, nonzero when failure happens - */ -TVM_DLL int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr); - -/*! - * \brief Copy data from one place to another. - * \note This API is designed to support special memory with shape dependent layout. - * We pass in DLTensor* with shape information to support these cases. - * \param from The source tensor. - * \param to The target tensor. - * \param stream Optional stream object. - * \return 0 when success, nonzero when failure happens. - */ -TVM_DLL int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream); - -/*! - * \brief Check that an object is derived from another. - * \param child_type_index The type index of the derived type. - * \param parent_type_index The type index of the parent type. - * \param is_derived A boolean representing whether this predicate holds. - * \return 0 when success, nonzero when failure happens. - */ -TVM_DLL int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, - int* is_derived); - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif -#endif // TVM_RUNTIME_C_RUNTIME_API_H_ diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h deleted file mode 100644 index 7d8de1c23423..000000000000 --- a/include/tvm/runtime/container/array.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/string.h - * \brief Runtime String container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_ARRAY_H_ -#define TVM_RUNTIME_CONTAINER_ARRAY_H_ - -#include - -namespace tvm { -namespace runtime { - -using tvm::ffi::Array; -using tvm::ffi::ArrayObj; - -} // namespace runtime - -// expose class to root namespace -using tvm::ffi::Array; -using tvm::ffi::ArrayObj; -} // namespace tvm -#endif // TVM_RUNTIME_CONTAINER_ARRAY_H_ diff --git a/include/tvm/runtime/container/base.h b/include/tvm/runtime/container/base.h deleted file mode 100644 index b0295761f6a3..000000000000 --- a/include/tvm/runtime/container/base.h +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/base.h - * \brief Base utilities for common POD(plain old data) container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_BASE_H_ -#define TVM_RUNTIME_CONTAINER_BASE_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief Base template for classes with array like memory layout. - * - * It provides general methods to access the memory. The memory - * layout is ArrayType + [ElemType]. The alignment of ArrayType - * and ElemType is handled by the memory allocator. - * - * \tparam ArrayType The array header type, contains object specific metadata. - * \tparam ElemType The type of objects stored in the array right after - * ArrayType. - * - * \code - * // Example usage of the template to define a simple array wrapper - * class ArrayObj : public InplaceArrayBase { - * public: - * // Wrap EmplaceInit to initialize the elements - * template - * void Init(Iterator begin, Iterator end) { - * size_t num_elems = std::distance(begin, end); - * auto it = begin; - * this->size = 0; - * for (size_t i = 0; i < num_elems; ++i) { - * InplaceArrayBase::EmplaceInit(i, *it++); - * this->size++; - * } - * } - * } - * - * void test_function() { - * vector fields; - * auto ptr = make_inplace_array_object(fields.size()); - * ptr->Init(fields.begin(), fields.end()); - * - * // Access the 0th element in the array. - * assert(ptr->operator[](0) == fields[0]); - * } - * - * \endcode - */ -template -class InplaceArrayBase { - public: - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Const reference to ElemType at the index. - */ - const ElemType& operator[](size_t idx) const { - size_t size = Self()->GetSize(); - ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Reference to ElemType at the index. - */ - ElemType& operator[](size_t idx) { - size_t size = Self()->GetSize(); - ICHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n"; - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Destroy the Inplace Array Base object - */ - ~InplaceArrayBase() { - if (!(std::is_standard_layout::value && std::is_trivial::value)) { - size_t size = Self()->GetSize(); - for (size_t i = 0; i < size; ++i) { - ElemType* fp = reinterpret_cast(AddressOf(i)); - fp->ElemType::~ElemType(); - } - } - } - - protected: - /*! - * \brief Construct a value in place with the arguments. - * - * \tparam Args Type parameters of the arguments. - * \param idx Index of the element. - * \param args Arguments to construct the new value. - * - * \note Please make sure ArrayType::GetSize returns 0 before first call of - * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. - */ - template - void EmplaceInit(size_t idx, Args&&... args) { - void* field_ptr = AddressOf(idx); - new (field_ptr) ElemType(std::forward(args)...); - } - - /*! - * \brief Return the self object for the array. - * - * \return Pointer to ArrayType. - */ - inline ArrayType* Self() const { - return static_cast(const_cast(this)); - } - - /*! - * \brief Return the raw pointer to the element at idx. - * - * \param idx The index of the element. - * \return Raw pointer to the element. - */ - void* AddressOf(size_t idx) const { - static_assert( - alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); - - size_t kDataStart = sizeof(ArrayType); - ArrayType* self = Self(); - char* data_start = reinterpret_cast(self) + kDataStart; - return data_start + idx * sizeof(ElemType); - } -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - IterAdapter& operator++() { - ++iter_; - return *this; - } - IterAdapter& operator--() { - --iter_; - return *this; - } - IterAdapter operator++(int) { - IterAdapter copy = *this; - ++iter_; - return copy; - } - IterAdapter operator--(int) { - IterAdapter copy = *this; - --iter_; - return copy; - } - - IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - - IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const IterAdapter& rhs) const { - return iter_ - rhs.iter_; - } - - bool operator==(IterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(IterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class ReverseIterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} - ReverseIterAdapter& operator++() { - --iter_; - return *this; - } - ReverseIterAdapter& operator--() { - ++iter_; - return *this; - } - ReverseIterAdapter operator++(int) { - ReverseIterAdapter copy = *this; - --iter_; - return copy; - } - ReverseIterAdapter operator--(int) { - ReverseIterAdapter copy = *this; - ++iter_; - return copy; - } - ReverseIterAdapter operator+(difference_type offset) const { - return ReverseIterAdapter(iter_ - offset); - } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const ReverseIterAdapter& rhs) const { - return rhs.iter_ - iter_; - } - - bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } - const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -} // namespace runtime - -// expose the functions to the root namespace. -using runtime::Downcast; -using runtime::IterAdapter; -using runtime::make_object; -using runtime::Object; -using runtime::ObjectPtr; -using runtime::ObjectPtrEqual; -using runtime::ObjectPtrHash; -using runtime::ObjectRef; -} // namespace tvm - -#endif // TVM_RUNTIME_CONTAINER_BASE_H_ diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h deleted file mode 100644 index cd63cc94ada0..000000000000 --- a/include/tvm/runtime/container/map.h +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/map.h - * \brief Runtime Map container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_MAP_H_ -#define TVM_RUNTIME_CONTAINER_MAP_H_ - -#include - -namespace tvm { -namespace runtime { - -using tvm::ffi::Map; - -} // namespace runtime - -// expose the functions to the root namespace. -using tvm::ffi::Map; -using tvm::ffi::MapObj; -} // namespace tvm -#endif // TVM_RUNTIME_CONTAINER_MAP_H_ diff --git a/include/tvm/runtime/container/optional.h b/include/tvm/runtime/container/optional.h deleted file mode 100644 index 4dc3b680de7a..000000000000 --- a/include/tvm/runtime/container/optional.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/string.h - * \brief Runtime String container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_OPTIONAL_H_ -#define TVM_RUNTIME_CONTAINER_OPTIONAL_H_ - -#include - -namespace tvm { -namespace runtime { - -using tvm::ffi::Optional; -} // namespace runtime - -// expose class to root namespace -using tvm::ffi::Optional; -constexpr inline auto NullOpt = std::nullopt; -} // namespace tvm -#endif // TVM_RUNTIME_CONTAINER_OPTIONAL_H_ diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h deleted file mode 100644 index c7a96b6623a6..000000000000 --- a/include/tvm/runtime/container/shape_tuple.h +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/shape_tuple.h - * \brief Runtime ShapeTuple container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ -#define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ - -#include - -#include -#include -#include - -#include "./base.h" - -namespace tvm { -namespace runtime { - -using Shape = tvm::ffi::Shape; -using ShapeTuple = tvm::ffi::Shape; -using ShapeTupleObj = tvm::ffi::ShapeObj; -using IntTuple = ShapeTuple; -using IntTupleObj = ShapeTupleObj; - -} // namespace runtime - -// expose the functions to the root namespace. -using runtime::IntTuple; -using runtime::IntTupleObj; -using runtime::ShapeTuple; -using runtime::ShapeTupleObj; -} // namespace tvm - -#endif // TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h deleted file mode 100644 index d55d9dbd7960..000000000000 --- a/include/tvm/runtime/container/string.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/string.h - * \brief Runtime String container types. - */ -#ifndef TVM_RUNTIME_CONTAINER_STRING_H_ -#define TVM_RUNTIME_CONTAINER_STRING_H_ - -#include - -namespace tvm { -namespace runtime { - -using tvm::ffi::String; -using tvm::ffi::StringObj; - -} // namespace runtime - -using tvm::ffi::String; -using tvm::ffi::StringObj; - -} // namespace tvm -#endif // TVM_RUNTIME_CONTAINER_STRING_H_ diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h deleted file mode 100644 index 9ba9a987115b..000000000000 --- a/include/tvm/runtime/container/variant.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/container/variant.h - * \brief Runtime variant container. - */ -#ifndef TVM_RUNTIME_CONTAINER_VARIANT_H_ -#define TVM_RUNTIME_CONTAINER_VARIANT_H_ - -#include - -namespace tvm { -namespace runtime { - -using tvm::ffi::Variant; - -} // namespace runtime - -// expose class to root namespace -using tvm::ffi::Variant; -} // namespace tvm -#endif // TVM_RUNTIME_CONTAINER_VARIANT_H_ diff --git a/include/tvm/runtime/contrib/papi.h b/include/tvm/runtime/contrib/papi.h index ff2d75c483eb..93c1aa274bfd 100644 --- a/include/tvm/runtime/contrib/papi.h +++ b/include/tvm/runtime/contrib/papi.h @@ -22,8 +22,8 @@ #ifndef TVM_RUNTIME_CONTRIB_PAPI_H_ #define TVM_RUNTIME_CONTRIB_PAPI_H_ -#include -#include +#include +#include #include namespace tvm { diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 9418a0c902d4..d5f3c6ee3d7f 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -24,8 +24,9 @@ #ifndef TVM_RUNTIME_DATA_TYPE_H_ #define TVM_RUNTIME_DATA_TYPE_H_ +#include #include -#include +#include #include #include @@ -35,6 +36,8 @@ namespace tvm { namespace runtime { +using tvm_index_t = ffi::Shape::index_type; + /*! * \brief Runtime primitive data type. * @@ -375,21 +378,20 @@ struct TypeTraits : public TypeTraitsBase { result->v_dtype = src; } - static TVM_FFI_INLINE std::optional TryConvertFromAnyView( - const TVMFFIAny* src) { - auto opt_dtype = TypeTraits::TryConvertFromAnyView(src); + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { + auto opt_dtype = TypeTraits::TryCastFromAnyView(src); if (opt_dtype) { return runtime::DataType(opt_dtype.value()); } return std::nullopt; } - static TVM_FFI_INLINE bool CheckAnyStorage(const TVMFFIAny* src) { - return TypeTraits::CheckAnyStorage(src); + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { + return TypeTraits::CheckAnyStrict(src); } - static TVM_FFI_INLINE runtime::DataType CopyFromAnyStorageAfterCheck(const TVMFFIAny* src) { - return runtime::DataType(TypeTraits::CopyFromAnyStorageAfterCheck(src)); + static TVM_FFI_INLINE runtime::DataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return runtime::DataType(TypeTraits::CopyFromAnyViewAfterCheck(src)); } static TVM_FFI_INLINE std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index a4b53eb79734..7366b9895d5e 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -26,10 +26,15 @@ #include #include -#include +#include #include #include +/*! + * \brief The stream that is specific to device + * can be NULL, which indicates the default one. + */ +typedef void* TVMStreamHandle; namespace tvm { @@ -37,6 +42,41 @@ namespace tvm { using Device = DLDevice; namespace runtime { + +/*! \brief Extension device types in TVM + * + * Additional enumerators to supplement those provided by + * DLPack's `DLDeviceType` enumeration. + * + * MAINTAINERS NOTE #1: We need to ensure that the two devices + * are identified by the same integer. + * Currently this requires manual verification. + * Discussed here: https://github.com/dmlc/dlpack/issues/111 + * As of DLPack v0.7, the highest-valued enumerator in + * `DLDeviceType` is kDLHexagon = 16. + * + * MAINTAINERS NOTE #2: As of DLPack v0.7, the definition for + * `DLDeviceType` specifies an underlying storage type of + * `int32_t`. That guarantees a variable of type + * `DLDeviceType` is capable of holding any integers provided + * by *either* of these enumerations. + * + * However, the `int32_t` specification only applies when the + * header file is compiled as C++, and this header file is also + * meant to work as C code. So the unspecified storage type + * could be a latent bug when compiled as C. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + // To help avoid accidental conflicts between `DLDeviceType` + // and this enumeration, start numbering the new enumerators + // a little higher than (currently) seems necessary. + TVMDeviceExtType_End = 36, // sentinel value +} TVMDeviceExtType; + /*! * \brief the query type into GetAttr */ diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index 4453d9737f89..93bf0084db87 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -70,7 +70,7 @@ TVM_DLL Module LoadVMModule(std::string path, Device device); * \param device The device the NDArray is created on. If None, use the thread local default device * \return The NDArray created */ -TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device); +TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device); /*! * \brief Perform an allreduce operation using the underlying communication library * \param send The array send to perform allreduce on diff --git a/include/tvm/runtime/disco/cuda_ipc_memory.h b/include/tvm/runtime/disco/cuda_ipc_memory.h index 120e6a543179..ea272052626f 100644 --- a/include/tvm/runtime/disco/cuda_ipc_memory.h +++ b/include/tvm/runtime/disco/cuda_ipc_memory.h @@ -20,7 +20,7 @@ #ifndef TVM_RUNTIME_DISCO_CUDA_IPC_MEMORY_H_ #define TVM_RUNTIME_DISCO_CUDA_IPC_MEMORY_H_ -#include +#include #include #include diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h index c7aeb4e284ad..078c061b7b82 100644 --- a/include/tvm/runtime/disco/disco_worker.h +++ b/include/tvm/runtime/disco/disco_worker.h @@ -25,8 +25,8 @@ #ifndef TVM_RUNTIME_DISCO_DISCO_WORKER_H_ #define TVM_RUNTIME_DISCO_DISCO_WORKER_H_ +#include #include -#include #include diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index fb21f79882ad..0c1ed7ca0aaf 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -72,9 +72,10 @@ #ifndef TVM_RUNTIME_DISCO_SESSION_H_ #define TVM_RUNTIME_DISCO_SESSION_H_ -#include +#include +#include +#include #include -#include #include #include @@ -243,7 +244,7 @@ class SessionObj : public Object { * \param value The value to be set. * \param worker_id The id of the worker to be set. */ - TVM_DLL virtual void DebugSetRegister(int64_t reg_id, AnyView value, int worker_id) = 0; + TVM_DLL virtual void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) = 0; struct FFI; friend struct SessionObj::FFI; @@ -338,7 +339,7 @@ template DRef SessionObj::CallPacked(const DRef& func, Args&&... args) { constexpr int offset = 3; constexpr int kNumArgs = offset + sizeof...(Args); - AnyView packed_args[kNumArgs]; + ffi::AnyView packed_args[kNumArgs]; ffi::PackedArgs::Fill(packed_args, /*.0=*/static_cast(DiscoAction::kCallPacked), // action /*.1=*/0, // reg_id, which will be updated by this->CallWithPacked diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/int_tuple.h similarity index 73% rename from include/tvm/runtime/memory.h rename to include/tvm/runtime/int_tuple.h index 96ca0c8d696a..032e8b9a328e 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/int_tuple.h @@ -16,22 +16,23 @@ * specific language governing permissions and limitations * under the License. */ + /*! - * \file tvm/runtime/object.h - * \brief A managed object in the TVM runtime. + * \file int_tuple.h + * \brief Defines tuple of integers. */ -#ifndef TVM_RUNTIME_MEMORY_H_ -#define TVM_RUNTIME_MEMORY_H_ +#ifndef TVM_RUNTIME_INT_TUPLE_H_ +#define TVM_RUNTIME_INT_TUPLE_H_ -#include +#include namespace tvm { namespace runtime { -using tvm::ffi::FObjectDeleter; -using tvm::ffi::make_inplace_array_object; -using tvm::ffi::make_object; +// We simply redirects to ffi::Shape, and ffi::ShapeObj +using IntTuple = ffi::Shape; +using IntTupleObj = ffi::ShapeObj; } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_MEMORY_H_ +#endif // TVM_RUNTIME_INT_TUPLE_H_ diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index 807c9dbf30bc..da715848e09a 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/runtime/memory/memory_manager.h b/include/tvm/runtime/memory/memory_manager.h index 537beeb8fa9d..f103c6f30ac8 100644 --- a/include/tvm/runtime/memory/memory_manager.h +++ b/include/tvm/runtime/memory/memory_manager.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_ #define TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_ -#include +#include #include #include @@ -66,8 +66,8 @@ class Allocator { * \param mem_scope The device memory scope hint. * \return The empty NDArray. */ - TVM_DLL NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev, - Optional mem_scope = NullOpt); + TVM_DLL NDArray Empty(ffi::Shape shape, DLDataType dtype, Device dev, + Optional mem_scope = std::nullopt); /*! \brief Return the allocator type. */ inline AllocatorType type() const { return type_; } /*! \brief Allocate a buffer given a size, alignment and type. @@ -86,7 +86,7 @@ class Allocator { * \param mem_scope A memory scope of the buffer. * \return A sized allocation in the form of a buffer. */ - TVM_DLL virtual Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint, + TVM_DLL virtual Buffer Alloc(Device dev, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope = ""); /*! \brief Create a view for the buffer given a shape, type and scope. @@ -96,7 +96,7 @@ class Allocator { * \param mem_scope A memory scope of the view. * \return A device pointer to the created view. */ - TVM_DLL virtual void* CreateView(const Buffer& buffer, ShapeTuple shape, DLDataType type_hint, + TVM_DLL virtual void* CreateView(const Buffer& buffer, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope = "global") { return buffer.data; } @@ -164,10 +164,10 @@ class StorageObj : public Object { Allocator* allocator = nullptr; /*! \brief Allocate an NDArray from a given piece of storage. */ - TVM_DLL NDArray AllocNDArray(int64_t offset, ShapeTuple shape, DLDataType dtype); + TVM_DLL NDArray AllocNDArray(int64_t offset, ffi::Shape shape, DLDataType dtype); /*! \brief Allocate an NDArray with memory scope from a given piece of storage. */ - TVM_DLL NDArray AllocNDArrayScoped(int64_t offset, ShapeTuple shape, DLDataType dtype, + TVM_DLL NDArray AllocNDArrayScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, String scope = "global"); ~StorageObj() { diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 37ab906dd422..705fb276d9e7 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -28,15 +28,16 @@ #include #include -#include -#include -#include +#include +#include +#include #include #include #include #include #include +#include #include namespace tvm { @@ -133,7 +134,7 @@ class Module : public ObjectRef { * // instace of MyModuleNode. * Module CreateMyModule() { * ObjectPtr n = - * tvm::runtime::make_object(); + * tvm::ffi::make_object(); * return Module(n); * } * @@ -325,8 +326,80 @@ inline std::ostream& operator<<(std::ostream& out, const Module& module) { return out; } +namespace details { + +template +struct ModuleVTableEntryHelper {}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = R (T::*)(Args...) const; + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { + auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); + } +}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = R (T::*)(Args...); + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { + auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); + } +}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = void (T::*)(Args...) const; + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { + auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); + } +}; + +template +struct ModuleVTableEntryHelper { + using MemFnType = void (T::*)(Args...); + static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { + auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; + ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, + args.data(), args.size(), rv); + } +}; +} // namespace details } // namespace runtime } // namespace tvm -#include // NOLINT(*) -#endif // TVM_RUNTIME_MODULE_H_ +#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ + const char* type_key() const final { return TypeKey; } \ + ffi::Function GetFunction(const String& _name, const ObjectPtr& _self) override { \ + using SelfPtr = std::remove_cv_t; +#define TVM_MODULE_VTABLE_END() \ + return ffi::Function(nullptr); \ + } +#define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \ + { \ + auto f = (MemFunc); \ + return (this->*f)(_name); \ + } \ + } // NOLINT(*) +#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \ + if (_name == Name) { \ + return ffi::Function::FromPacked([_self](ffi::PackedArgs args, ffi::Any* rv) -> void { \ + using Helper = ::tvm::runtime::details::ModuleVTableEntryHelper; \ + SelfPtr self = static_cast(_self.get()); \ + Helper::Call(rv, self, MemFunc, args); \ + }); \ + } +#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, MemFunc) \ + if (_name == Name) { \ + return ffi::Function([_self](ffi::PackedArgs args, ffi::Any* rv) -> void { \ + (static_cast(_self.get())->*(MemFunc))(args, rv); \ + }); \ + } + +#endif // TVM_RUNTIME_MODULE_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 82c9b229ab90..6eebe49ff135 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -25,10 +25,10 @@ #define TVM_RUNTIME_NDARRAY_H_ #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include @@ -62,7 +62,7 @@ class NDArray : public tvm::ffi::NDArray { NDArray(ffi::NDArray&& other) : tvm::ffi::NDArray(std::move(other)) {} // NOLINT(*) NDArray(const ffi::NDArray& other) : tvm::ffi::NDArray(other) {} // NOLINT(*) - ShapeTuple Shape() const { return this->shape(); } + ffi::Shape Shape() const { return this->shape(); } runtime::DataType DataType() const { return runtime::DataType(this->dtype()); } // DLPack handling @@ -112,7 +112,7 @@ class NDArray : public tvm::ffi::NDArray { * \return The array under another device. * \note The copy always triggers a TVMSynchronize. */ - TVM_DLL NDArray CopyTo(const Device& dev, Optional mem_scope = NullOpt) const; + TVM_DLL NDArray CopyTo(const Device& dev, Optional mem_scope = std::nullopt) const; /*! * \brief Load NDArray from stream * \param stream The input data stream @@ -145,7 +145,7 @@ class NDArray : public tvm::ffi::NDArray { * outside the bounds of the current array, this function will * raise an exception. */ - TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype, + TVM_DLL NDArray CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_byte_offset = 0) const; /*! * \brief Create an empty NDArray. @@ -155,8 +155,8 @@ class NDArray : public tvm::ffi::NDArray { * \param mem_scope The memory scope of the array. * \return The created Array */ - TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev, - Optional mem_scope = NullOpt); + TVM_DLL static NDArray Empty(ffi::Shape shape, DLDataType dtype, Device dev, + Optional mem_scope = std::nullopt); /*! * \brief Function to copy data from one array to another. * \param from The source array. @@ -166,20 +166,15 @@ class NDArray : public tvm::ffi::NDArray { TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); - struct Internal; - - protected: - /*! - * \brief DecRef resource managed by an FFI array handle. - * \param handle The array handle. - */ - inline static void FFIDecRef(TVMArrayHandle handle); /*! - * \brief Get FFI Array handle from ndarray. - * \param nd The object with ndarray type. - * \return The result array handle. + * \brief Function to copy data from one array to a byte buffer. + * \param from The source array. + * \param to The target byte buffer. + * \param nbytes The size of the data buffer. + * \param stream The stream used in copy. */ - inline static TVMArrayHandle FFIGetHandle(const ObjectRef& nd); + TVM_DLL static void CopyToBytes(const DLTensor* from, void* to, size_t nbytes, + TVMStreamHandle stream = nullptr); }; /*! @@ -211,28 +206,6 @@ inline void NDArray::CopyTo(const NDArray& other) const { CopyFromTo(get_mutable(), other.get_mutable()); } -inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { - // NOTE: it is necessary to cast to container then to base - // so that the FFI handle uses the ContainerBase address. - auto ptr = reinterpret_cast( - TVMFFINDArrayGetDLTensorPtr(static_cast(const_cast(nd.get())))); - return ptr; -} - -inline TVMArrayHandle ObjectHandleToTVMArrayHandle(Object* handle) { - return reinterpret_cast( - TVMFFINDArrayGetDLTensorPtr(static_cast(handle))); -} - -inline Object* TVMArrayHandleToObjectHandle(void* handle) { - // NOTE: legacy patch here for TFM FFI - return reinterpret_cast(reinterpret_cast(handle) - sizeof(TVMFFIObject)); -} - -inline void NDArray::FFIDecRef(TVMArrayHandle handle) { - ffi::details::ObjectUnsafe::DecRefObjectHandle(TVMArrayHandleToObjectHandle(handle)); -} - /*! \brief Magic number for NDArray file */ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; @@ -271,10 +244,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { strm->Write(tensor->data, data_byte_size); } else { std::vector bytes(data_byte_size); - ICHECK_EQ( - TVMArrayCopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), - 0) - << TVMGetLastError(); + NDArray::CopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size); if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); } @@ -301,7 +271,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) { if (ndim != 0) { ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } - NDArray ret = NDArray::Empty(ShapeTuple(shape), dtype, dev); + NDArray ret = NDArray::Empty(ffi::Shape(shape), dtype, dev); int64_t num_elems = 1; int elem_bytes = (ret->dtype.bits + 7) / 8; for (int i = 0; i < ret->ndim; ++i) { diff --git a/include/tvm/runtime/nvtx.h b/include/tvm/runtime/nvtx.h index db99154b0b7c..289837c1fda1 100644 --- a/include/tvm/runtime/nvtx.h +++ b/include/tvm/runtime/nvtx.h @@ -19,7 +19,7 @@ #ifndef TVM_RUNTIME_NVTX_H_ #define TVM_RUNTIME_NVTX_H_ -#include +#include #include namespace tvm { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 1c4dfb39e247..6ce95eea1e83 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -25,8 +25,8 @@ #include #include -#include -#include +#include +#include #include @@ -55,10 +55,11 @@ enum TypeIndex : int32_t { kRuntimeModule = TVMFFITypeIndex::kTVMFFIModule, /*! \brief runtime::NDArray. */ kRuntimeNDArray = TVMFFITypeIndex::kTVMFFINDArray, - /*! \brief runtime::ShapeTuple. */ - kRuntimeShapeTuple = TVMFFITypeIndex::kTVMFFIShape, + /*! \brief runtime::Shape. */ + kRuntimeShape = TVMFFITypeIndex::kTVMFFIShape, // Extra builtin static index here - kCustomStaticIndex = TVMFFITypeIndex::kTVMFFIStaticObjectEnd, + // We reserve 16 extra static indices for custom types + kCustomStaticIndex = TVMFFITypeIndex::kTVMFFIDynObjectBegin - 16, /*! \brief ffi::Function. */ kRuntimePackedFunc = kCustomStaticIndex + 1, /*! \brief runtime::DRef for disco distributed runtime */ @@ -73,6 +74,10 @@ enum TypeIndex : int32_t { kStaticIndexEnd, }; +static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= + static_cast(TVMFFITypeIndex::kTVMFFIStaticObjectEnd), + "Static slot overflows to custom indices"); + /* * \brief Define the default copy/move constructor and assign operator * \param TypeName The class typename. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 235bdcf3e32f..3f8ec66bc1d3 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,425 +26,16 @@ #include #include -#include -#include -#include -#include - -#include -#include -#include namespace tvm { namespace runtime { -using ffi::Any; -using ffi::AnyView; - -/*! - * \brief Utility function to convert legacy ffi::AnyView to AnyView - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline TVMFFIAny LegacyTVMArgValueToFFIAny(TVMValue value, int type_code) { - TVMFFIAny res; - // clear first to ensure consistent hash - res.v_uint64 = 0; - switch (type_code) { - case kTVMArgInt: { - res.type_index = ffi::TypeIndex::kTVMFFIInt; - res.v_int64 = value.v_int64; - return res; - } - case kTVMArgFloat: { - res.type_index = ffi::TypeIndex::kTVMFFIFloat; - res.v_float64 = value.v_float64; - return res; - } - case kTVMOpaqueHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIOpaquePtr; - res.v_ptr = value.v_handle; - return res; - } - case kTVMNullptr: { - res.type_index = ffi::TypeIndex::kTVMFFINone; - return res; - } - case kTVMDataType: { - res.type_index = ffi::TypeIndex::kTVMFFIDataType; - res.v_dtype = value.v_type; - return res; - } - case kDLDevice: { - res.type_index = ffi::TypeIndex::kTVMFFIDevice; - res.v_device = value.v_device; - return res; - } - case kTVMDLTensorHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIDLTensorPtr; - res.v_ptr = value.v_handle; - return res; - } - case kTVMObjectHandle: { - res.v_obj = static_cast(value.v_handle); - res.type_index = res.v_obj->type_index; - return res; - } - case kTVMModuleHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIModule; - res.v_obj = static_cast(value.v_handle); - return res; - } - case kTVMPackedFuncHandle: { - res.type_index = ffi::TypeIndex::kTVMFFIFunction; - res.v_obj = static_cast(value.v_handle); - return res; - } - case kTVMStr: { - res.type_index = ffi::TypeIndex::kTVMFFIRawStr; - res.v_c_str = value.v_str; - return res; - } - case kTVMBytes: { - res.type_index = ffi::TypeIndex::kTVMFFIByteArrayPtr; - res.v_ptr = value.v_handle; - return res; - } - case kTVMNDArrayHandle: { - res.type_index = ffi::TypeIndex::kTVMFFINDArray; - res.v_obj = reinterpret_cast(TVMArrayHandleToObjectHandle(value.v_handle)); - return res; - } - case kTVMArgBool: { - res.type_index = ffi::TypeIndex::kTVMFFIBool; - res.v_int64 = value.v_int64; - return res; - } - case kTVMObjectRValueRefArg: { - res.type_index = ffi::TypeIndex::kTVMFFIObjectRValueRef; - res.v_ptr = value.v_handle; - return res; - } - default: { - LOG(FATAL) << "Unsupported type code: " << type_code; - TVM_FFI_UNREACHABLE(); - } - } -} - -/*! - * \brief Utility function to convert legacy ffi::AnyView to AnyView - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline AnyView LegacyTVMArgValueToAnyView(TVMValue value, int type_code) { - return AnyView::CopyFromTVMFFIAny(LegacyTVMArgValueToFFIAny(value, type_code)); -} - -/*! - * \brief Utility function to convert legacy ffi::AnyView to Any - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline Any MoveLegacyTVMArgValueToAny(TVMValue value, int type_code) { - return ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(LegacyTVMArgValueToFFIAny(value, type_code)); -} - -/* - * \brief Convert AnyView to legacy TVMValue and type_code - * \param src The AnyView to convert - * \param value The TVMValue to store the result - * \param type_code The type code to store the result - * \note This routine is not fastest, but serves purpose to do transition of ABI. - */ -inline void AnyViewToLegacyTVMArgValue(TVMFFIAny src, TVMValue* value, int* type_code) { - switch (src.type_index) { - case ffi::TypeIndex::kTVMFFIBool: { - type_code[0] = kTVMArgBool; - value[0].v_int64 = src.v_int64; - break; - } - case ffi::TypeIndex::kTVMFFIInt: { - type_code[0] = kDLInt; - value[0].v_int64 = src.v_int64; - break; - } - case ffi::TypeIndex::kTVMFFIFloat: { - type_code[0] = kDLFloat; - value[0].v_float64 = src.v_float64; - break; - } - case ffi::TypeIndex::kTVMFFIOpaquePtr: { - type_code[0] = kTVMOpaqueHandle; - value[0].v_handle = src.v_ptr; - break; - } - case ffi::TypeIndex::kTVMFFINone: { - type_code[0] = kTVMNullptr; - break; - } - case ffi::TypeIndex::kTVMFFIDataType: { - type_code[0] = kTVMDataType; - value[0].v_type = src.v_dtype; - break; - } - case ffi::TypeIndex::kTVMFFIDevice: { - type_code[0] = kDLDevice; - value[0].v_device = src.v_device; - break; - } - case ffi::TypeIndex::kTVMFFIDLTensorPtr: { - type_code[0] = kTVMDLTensorHandle; - value[0].v_handle = src.v_ptr; - break; - } - case ffi::TypeIndex::kTVMFFIRawStr: { - type_code[0] = kTVMStr; - value[0].v_str = src.v_c_str; - break; - } - case ffi::TypeIndex::kTVMFFIByteArrayPtr: { - type_code[0] = kTVMBytes; - value[0].v_handle = src.v_ptr; - break; - } - case ffi::TypeIndex::kTVMFFINDArray: { - type_code[0] = kTVMNDArrayHandle; - value[0].v_handle = ObjectHandleToTVMArrayHandle(reinterpret_cast(src.v_obj)); - break; - } - case ffi::TypeIndex::kTVMFFIModule: { - type_code[0] = kTVMModuleHandle; - value[0].v_handle = src.v_obj; - break; - } - case ffi::TypeIndex::kTVMFFIFunction: { - type_code[0] = kTVMPackedFuncHandle; - value[0].v_handle = src.v_obj; - break; - } - case ffi::TypeIndex::kTVMFFIObjectRValueRef: { - type_code[0] = kTVMObjectRValueRefArg; - value[0].v_handle = src.v_ptr; - break; - } - default: { - if (src.type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - type_code[0] = kTVMObjectHandle; - value[0].v_handle = src.v_obj; - break; - } - LOG(FATAL) << "Unsupported type index: " << src.type_index; - } - } -} - -/* - * \brief Move Any to legacy TVMValue and type_code - * \param src The Any to move - * \param value The TVMValue to store the result - * \param type_code The type code to store the result - */ -inline void MoveAnyToLegacyTVMValue(Any&& src, TVMValue* value, int* type_code) { - TVMFFIAny val = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src)); - // NOTE: conversion rule is the same as AnyViewToLegacyTVMArgValue - AnyViewToLegacyTVMArgValue(val, value, type_code); -} - -/*! - * \brief Translate legacy ffi::PackedArgs to PackedArgs - * \param value The TVMValue array - * \param type_code The type code array - * \param num_args The number of arguments - * \param dst The destination AnyView array - */ -inline void LegacyTVMArgsToPackedArgs(const TVMValue* value, const int* type_code, int num_args, - AnyView* dst) { - for (int i = 0; i < num_args; ++i) { - dst[i] = LegacyTVMArgValueToAnyView(value[i], type_code[i]); - } -} - -/*! - * \brief Translate legacy ffi::PackedArgs to PackedArgs - * \param args The AnyView array - * \param num_args The number of arguments - * \param value The TVMValue array - * \param type_code The type code array - */ -inline void PackedArgsToLegacyTVMArgs(const AnyView* args, int num_args, TVMValue* value, - int* type_code) { - for (int i = 0; i < num_args; ++i) { - AnyViewToLegacyTVMArgValue(args[i].CopyToTVMFFIAny(), value + i, type_code + i); - } -} +#define TVM_DLL_EXPORT_TYPED_FUNC TVM_FFI_DLL_EXPORT_TYPED_FUNC -/*! - * \brief Convert argument type code to string. - * \param type_code The input type code. - * \return The corresponding string repr. - */ -inline const char* ArgTypeCode2Str(int type_code) { - switch (type_code) { - case kDLInt: - return "int"; - case kTVMArgBool: - return "bool"; - case kDLUInt: - return "uint"; - case kDLFloat: - return "float"; - case kTVMStr: - return "str"; - case kTVMBytes: - return "bytes"; - case kTVMOpaqueHandle: - return "handle"; - case kTVMNullptr: - return "NULL"; - case kTVMDLTensorHandle: - return "ArrayHandle"; - case kTVMDataType: - return "DLDataType"; - case kDLDevice: - return "DLDevice"; - case kTVMPackedFuncHandle: - return "FunctionHandle"; - case kTVMModuleHandle: - return "ModuleHandle"; - case kTVMNDArrayHandle: - return "NDArrayContainer"; - case kTVMObjectHandle: - return "Object"; - case kTVMObjectRValueRefArg: - return "ObjectRValueRefArg"; - default: - LOG(FATAL) << "unknown type_code=" << static_cast(type_code); - } - throw; -} - -namespace details { - -template -struct ModuleVTableEntryHelper {}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = R (T::*)(Args...) const; - static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { - auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; - ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, - args.data(), args.size(), rv); - } -}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = R (T::*)(Args...); - static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { - auto wrapped = [self, f](Args... args) -> R { return (self->*f)(std::forward(args)...); }; - ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, - args.data(), args.size(), rv); - } -}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = void (T::*)(Args...) const; - static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { - auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; - ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, - args.data(), args.size(), rv); - } -}; - -template -struct ModuleVTableEntryHelper { - using MemFnType = void (T::*)(Args...); - static TVM_ALWAYS_INLINE void Call(ffi::Any* rv, T* self, MemFnType f, ffi::PackedArgs args) { - auto wrapped = [self, f](Args... args) -> void { (self->*f)(std::forward(args)...); }; - ffi::details::unpack_call(std::make_index_sequence{}, nullptr, wrapped, - args.data(), args.size(), rv); - } -}; -} // namespace details - -#define TVM_MODULE_VTABLE_BEGIN(TypeKey) \ - const char* type_key() const final { return TypeKey; } \ - ffi::Function GetFunction(const String& _name, const ObjectPtr& _self) override { \ - using SelfPtr = std::remove_cv_t; -#define TVM_MODULE_VTABLE_END() \ - return ffi::Function(nullptr); \ - } -#define TVM_MODULE_VTABLE_END_WITH_DEFAULT(MemFunc) \ - { \ - auto f = (MemFunc); \ - return (this->*f)(_name); \ - } \ - } // NOLINT(*) -#define TVM_MODULE_VTABLE_ENTRY(Name, MemFunc) \ - if (_name == Name) { \ - return ffi::Function::FromPacked([_self](ffi::PackedArgs args, Any* rv) -> void { \ - using Helper = ::tvm::runtime::details::ModuleVTableEntryHelper; \ - SelfPtr self = static_cast(_self.get()); \ - Helper::Call(rv, self, MemFunc, args); \ - }); \ - } -#define TVM_MODULE_VTABLE_ENTRY_PACKED(Name, MemFunc) \ - if (_name == Name) { \ - return ffi::Function([_self](ffi::PackedArgs args, Any* rv) -> void { \ - (static_cast(_self.get())->*(MemFunc))(args, rv); \ - }); \ - } - -/*! - * \brief Export typed function as a ffi::Function - * that can be loaded by LibraryModule. - * - * \param ExportName The symbol name to be exported. - * \param Function The typed function. - * \note ExportName and Function must be different, - * see code examples below. - * - * \sa ffi::TypedFunction - * - * \code - * - * int AddOne_(int x) { - * return x + 1; - * } - * - * // Expose the function as "AddOne" - * TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); - * - * // Expose the function as "SubOne" - * TVM_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { - * return x - 1; - * }); - * - * // The following code will cause compilation error. - * // Because the same Function and ExportName - * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_); - * - * // The following code is OK, assuming the macro - * // is in a different namespace from xyz - * // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_); - * - * \endcode - */ -#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(void* self, TVMFFIAny* args, int32_t num_args, TVMFFIAny* result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ - static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call( \ - std::make_index_sequence{}, &name, Function, \ - reinterpret_cast(args), num_args, \ - reinterpret_cast<::tvm::ffi::Any*>(result)); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ - } -} // namespace runtime // NOLINT(*) using ffi::Any; using ffi::AnyView; + +} // namespace runtime } // namespace tvm + #endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 9d6623d9ad95..2a6ecc0e4d43 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -24,13 +24,14 @@ #ifndef TVM_RUNTIME_PROFILING_H_ #define TVM_RUNTIME_PROFILING_H_ -#include -#include -#include +#include +#include +#include +#include #include +#include +#include #include -#include -#include #include #include @@ -133,7 +134,7 @@ class Timer : public ObjectRef { * }; * TVM_REGISTER_OBJECT_TYPE(CPUTimerNode); * - * TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { + * TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { * return Timer(make_object()); * }); * \endcode diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h deleted file mode 100644 index c5c124dd6fb8..000000000000 --- a/include/tvm/runtime/registry.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/registry.h - * \brief This file defines the TVM global function registry. - */ -#ifndef TVM_RUNTIME_REGISTRY_H_ -#define TVM_RUNTIME_REGISTRY_H_ - -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! \brief A class that wraps a Python object and preserves its ownership. - - * This class is used to wrap a PyObject* from the Python API and preserve its ownership. - * Allows for the creation of strong references to Python objects, which prevent them from being - * garbage-collected as long as the wrapper object exists. - */ -class WrappedPythonObject { - public: - /*! \brief Construct a wrapper that doesn't own anything */ - WrappedPythonObject() : python_obj_(nullptr) {} - - /*! \brief Conversion constructor from nullptr */ - explicit WrappedPythonObject(std::nullptr_t) : python_obj_(nullptr) {} - - /*! \brief Take ownership of a python object - * - * A new strong reference is created for the underlying python - * object. - * - * \param python_obj A PyObject* from the Python.h API. A new - * strong reference is created using Py_IncRef. - */ - explicit WrappedPythonObject(void* python_obj); - - /*! \brief Drop ownership of a python object - * - * Removes the strong reference held by the wrapper. - */ - ~WrappedPythonObject(); - - WrappedPythonObject(WrappedPythonObject&&); - WrappedPythonObject& operator=(WrappedPythonObject&&); - - WrappedPythonObject(const WrappedPythonObject&); - WrappedPythonObject& operator=(const WrappedPythonObject&); - WrappedPythonObject& operator=(std::nullptr_t); - - operator bool() { return python_obj_; } - - void* raw_pointer() { return python_obj_; } - - private: - void* python_obj_ = nullptr; -}; - -/*! - * \brief Register a function globally. - * \code - * TVM_REGISTER_GLOBAL("MyPrint") - * .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - * }); - * \endcode - */ -#define TVM_REGISTER_GLOBAL TVM_FFI_REGISTER_GLOBAL - -#define TVM_STRINGIZE_DETAIL(x) #x -#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x) -#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__)) -/*! - * \brief Macro to include current line as string - */ -#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_REGISTRY_H_ diff --git a/include/tvm/runtime/relax_vm/executable.h b/include/tvm/runtime/relax_vm/executable.h index dc9d87025382..8a9fa024cab2 100644 --- a/include/tvm/runtime/relax_vm/executable.h +++ b/include/tvm/runtime/relax_vm/executable.h @@ -23,9 +23,9 @@ #ifndef TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ #define TVM_RUNTIME_RELAX_VM_EXECUTABLE_H_ +#include +#include #include -#include -#include #include #include diff --git a/include/tvm/runtime/relax_vm/ndarray_cache_support.h b/include/tvm/runtime/relax_vm/ndarray_cache_support.h index cf428b62a1f3..579fbf306f68 100644 --- a/include/tvm/runtime/relax_vm/ndarray_cache_support.h +++ b/include/tvm/runtime/relax_vm/ndarray_cache_support.h @@ -19,9 +19,9 @@ #ifndef TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_ #define TVM_RUNTIME_RELAX_VM_NDARRAY_CACHE_SUPPORT_H_ -#include +#include +#include #include -#include #include #include @@ -52,7 +52,7 @@ struct NDArrayCacheMetadata { /*! \brief Name of the parameter */ std::string name; /*! \brief Shape of the parameter */ - ShapeTuple shape; + ffi::Shape shape; /*! \brief Data type of the parameter */ DataType dtype; /*! \brief Format of the parameter */ diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index ce69548d7016..884a8d0f4375 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -155,7 +155,7 @@ class VirtualMachine : public runtime::ModuleNode { * \param rv The return value. */ virtual void InvokeClosurePacked(const ObjectRef& closure_or_packedfunc, ffi::PackedArgs args, - Any* rv) = 0; + ffi::Any* rv) = 0; /*! * \brief Set an instrumentation function. * diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index b35cad368832..2cfd1de44dde 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -27,7 +27,7 @@ #include #include -#include +#include #include namespace dmlc { diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index e50c67a37664..85d6dcce5e1b 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -154,7 +154,7 @@ class IRBuilderFrame : public runtime::ObjectRef { class IRBuilderNode : public runtime::Object { public: /*! \brief A stack of context frames in the IRBuilder */ - runtime::Array frames; + Array frames; /*! \brief The outcome of IR construction */ Optional result; @@ -170,7 +170,7 @@ class IRBuilderNode : public runtime::Object { /*! * \brief Find a frame of the given type in the stack `this->frames` from top to bottom. * \tparam T The type of the frame to find. - * \return The frame if found, otherwise NullOpt. + * \return The frame if found, otherwise std::nullopt. */ template inline Optional FindFrame() const; @@ -178,7 +178,7 @@ class IRBuilderNode : public runtime::Object { * \brief Get the frame on top of the stack `this->frames` if its type is `TFrame`. * \tparam TFrame The assumed type of the last frame on stack. * \return The frame if the stack is non-empty and the top of the stack is of type `TFrame`. - * Otherwise NullOpt. + * Otherwise std::nullopt. */ template inline Optional GetLastFrame() const; @@ -274,7 +274,7 @@ inline Optional IRBuilderNode::FindFrame() const { return GetRef(p); } } - return NullOpt; + return std::nullopt; } template @@ -283,7 +283,7 @@ inline Optional IRBuilderNode::GetLastFrame() const { if (!frames.empty() && frames.back()->IsInstance()) { return Downcast(frames.back()); } - return NullOpt; + return std::nullopt; } template diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 47c058d76a83..98a51fcb7829 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -54,7 +54,7 @@ class SeqExprFrameNode : public RelaxFrameNode { public: /*! \brief The binding blocks inside the frame. */ Array binding_blocks; - /*! \brief The frame output expr. `NullOpt` when undefined. */ + /*! \brief The frame output expr. `std::nullopt` when undefined. */ Optional output; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index c931a09d1f72..49bc1a2851d3 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -101,7 +101,7 @@ TVM_DLL void DataflowBlockOutput(const Array& vars); */ TVM_DLL tvm::relax::Var Emit( const tvm::relax::Expr& value, - const Optional& annotate_struct_info = NullOpt); + const Optional& annotate_struct_info = std::nullopt); /*! * \brief Emit a match_cast binding to the last binding block frame. diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 5927a0a284f8..d3eb8ac435d5 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -109,7 +109,7 @@ Type FuncRet(Type ret_type); * \return The matched buffer. */ Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype = DataType::Float(32), - Optional data = NullOpt, Array strides = {}, + Optional data = std::nullopt, Array strides = {}, PrimExpr elem_offset = PrimExpr(), String storage_scope = "global", int align = -1, int offset_factor = 0, String buffer_type = "default", Optional> axis_separators = std::nullopt); @@ -167,7 +167,7 @@ void BlockAttrs(Map attrs); * \return The allocated buffer. */ Buffer AllocBuffer(Array shape, DataType dtype = DataType::Float(32), - Optional data = NullOpt, Array strides = {}, + Optional data = std::nullopt, Array strides = {}, PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1, int offset_factor = 0, String buffer_type = "default", Optional> axis_separators = std::nullopt); @@ -227,7 +227,8 @@ Array Remap(String kinds, Array bindings, DataType dtype = DataTy * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame Serial(PrimExpr start, PrimExpr stop, Optional> annotations = NullOpt); +ForFrame Serial(PrimExpr start, PrimExpr stop, + Optional> annotations = std::nullopt); /*! * \brief The parallel For statement. * \param start The minimum value of iteration. @@ -235,7 +236,8 @@ ForFrame Serial(PrimExpr start, PrimExpr stop, Optional> annota * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame Parallel(PrimExpr start, PrimExpr stop, Optional> annotations = NullOpt); +ForFrame Parallel(PrimExpr start, PrimExpr stop, + Optional> annotations = std::nullopt); /*! * \brief The vectorized For statement. * \param start The minimum value of iteration. @@ -244,7 +246,7 @@ ForFrame Parallel(PrimExpr start, PrimExpr stop, Optional> anno * \return The ForFrame. */ ForFrame Vectorized(PrimExpr start, PrimExpr stop, - Optional> annotations = NullOpt); + Optional> annotations = std::nullopt); /*! * \brief The unrolled For statement. * \param start The minimum value of iteration. @@ -252,7 +254,8 @@ ForFrame Vectorized(PrimExpr start, PrimExpr stop, * \param annotations The optional annotations of the For statement. * \return The ForFrame. */ -ForFrame Unroll(PrimExpr start, PrimExpr stop, Optional> annotations = NullOpt); +ForFrame Unroll(PrimExpr start, PrimExpr stop, + Optional> annotations = std::nullopt); /*! * \brief The thread-binding For statement. * \param start The minimum value of iteration. @@ -262,7 +265,7 @@ ForFrame Unroll(PrimExpr start, PrimExpr stop, Optional> annota * \return The ForFrame. */ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, - Optional> annotations = NullOpt); + Optional> annotations = std::nullopt); /*! * \brief The grid For statement. * \param extents The extents of the iteration. @@ -287,8 +290,8 @@ AssertFrame Assert(PrimExpr condition, String message); * \param var The variable to be bound. If not specified, a new variable will be created. * \return The created LetFrame. */ -LetFrame LetStmt(PrimExpr value, Optional type_annotation = NullOpt, - Optional var = NullOpt); +LetFrame LetStmt(PrimExpr value, Optional type_annotation = std::nullopt, + Optional var = std::nullopt); /*! * \brief The realization. @@ -309,8 +312,8 @@ RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, * \return The created AllocateFrame. */ AllocateFrame Allocate(Array extents, DataType dtype, String storage_scope = "", - Optional condition = NullOpt, - Optional> annotations = NullOpt); + Optional condition = std::nullopt, + Optional> annotations = std::nullopt); /*! * \brief The allocate constant node. @@ -321,7 +324,7 @@ AllocateFrame Allocate(Array extents, DataType dtype, String storage_s * \return The created AllocateConstFrame. */ AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array extents, - Optional> annotations = NullOpt); + Optional> annotations = std::nullopt); /*! * \brief Create an attribute. @@ -456,12 +459,12 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation); } -#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ - inline PrimExpr FuncName(Optional expr = NullOpt, bool is_size_var = false) { \ - DataType dtype = DType; \ - return expr.defined() \ - ? tvm::cast(dtype, expr.value()) \ - : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(Optional expr = std::nullopt, bool is_size_var = false) { \ + DataType dtype = DType; \ + return expr.defined() \ + ? tvm::cast(dtype, expr.value()) \ + : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ } #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 18d7a8194efc..b9a21b126c4f 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -161,7 +161,7 @@ class StmtDocNode : public DocNode { * line as the statement, or the line above, or inside the statement * if it spans over multiple lines. * */ - mutable Optional comment{NullOpt}; + mutable Optional comment{std::nullopt}; void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); @@ -634,7 +634,7 @@ class TupleDoc : public ExprDoc { /*! * \brief Create an empty TupleDoc */ - TupleDoc() : TupleDoc(runtime::make_object()) {} + TupleDoc() : TupleDoc(ffi::make_object()) {} /*! * \brief Constructor of TupleDoc * \param elements Elements of tuple. @@ -672,7 +672,7 @@ class ListDoc : public ExprDoc { /*! * \brief Create an empty ListDoc */ - ListDoc() : ListDoc(runtime::make_object()) {} + ListDoc() : ListDoc(ffi::make_object()) {} /*! * \brief Constructor of ListDoc * \param elements Elements of list. @@ -718,7 +718,7 @@ class DictDoc : public ExprDoc { /*! * \brief Create an empty dictionary */ - DictDoc() : DictDoc(runtime::make_object()) {} + DictDoc() : DictDoc(ffi::make_object()) {} /*! * \brief Constructor of DictDoc * \param keys Keys of dictionary. @@ -957,7 +957,7 @@ class ForDoc : public StmtDoc { class ScopeDocNode : public StmtDocNode { public: /*! \brief The name of the scoped variable. */ - Optional lhs{NullOpt}; + Optional lhs{std::nullopt}; /*! \brief The value of the scoped variable. */ ExprDoc rhs{nullptr}; /*! \brief The body of the scope doc. */ @@ -1043,7 +1043,7 @@ class AssertDocNode : public StmtDocNode { /*! \brief The expression to test. */ ExprDoc test{nullptr}; /*! \brief The optional error message when assertion failed. */ - Optional msg{NullOpt}; + Optional msg{std::nullopt}; void VisitAttrs(AttrVisitor* v) { StmtDocNode::VisitAttrs(v); @@ -1067,7 +1067,7 @@ class AssertDoc : public StmtDoc { * \param test The expression to test. * \param msg The optional error message when assertion failed. */ - explicit AssertDoc(ExprDoc test, Optional msg = NullOpt); + explicit AssertDoc(ExprDoc test, Optional msg = std::nullopt); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssertDoc, StmtDoc, AssertDocNode); }; @@ -1125,7 +1125,7 @@ class FunctionDocNode : public StmtDocNode { /*! \brief Decorators of function. */ Array decorators; /*! \brief The return type of function. */ - Optional return_type{NullOpt}; + Optional return_type{std::nullopt}; /*! \brief The body of function. */ Array body; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index db94064d538c..9211064526ed 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -203,7 +203,7 @@ class IRDocsifierNode : public Object { * \brief Get the doc for variable. * \param obj The variable object. * - * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. + * \return The doc for variable, if it exists in the table. Otherwise it returns std::nullopt. */ Optional GetVarDoc(const ObjectRef& obj) const; /*! \brief Add a TVM object to the metadata section*/ diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 40bb245e72f3..62133ef2c9da 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -19,9 +19,9 @@ #ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ #define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_ +#include #include #include -#include #include #include diff --git a/include/tvm/support/parallel_for.h b/include/tvm/support/parallel_for.h index 8bd2e6b825ab..aa9da30d8f1c 100644 --- a/include/tvm/support/parallel_for.h +++ b/include/tvm/support/parallel_for.h @@ -24,7 +24,7 @@ #ifndef TVM_SUPPORT_PARALLEL_FOR_H_ #define TVM_SUPPORT_PARALLEL_FOR_H_ -#include +#include #include #include diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h deleted file mode 100644 index 768252f77ce9..000000000000 --- a/include/tvm/support/span.h +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * - * \file tvm/support/span.h - * \brief Reimplementation of part of C++-20 style span. - */ -#ifndef TVM_SUPPORT_SPAN_H_ -#define TVM_SUPPORT_SPAN_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace support { - -/*! - * \brief A partial implementation of the C++20 std::span. - * - * At the time of writing, TVM must compile against C++17. - */ -template -class Span { - public: - using value_type = W; - using const_W = typename std::add_const::type; - - template - class iterator_base { - public: - using iterator_category = std::input_iterator_tag; - using value_type = W; - using difference_type = std::ptrdiff_t; - using pointer = const W*; - using reference = const W&; - - inline iterator_base(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); } - - inline W1 operator*() { return W1(*ptr_); } - - inline iterator_base& operator++() { - if (ptr_ != end_) ptr_++; - return *this; - } - - inline bool operator==(iterator_base other) { - return ptr_ == other.ptr_ && end_ == other.end_; - } - - inline bool operator!=(iterator_base other) { return !(*this == other); } - - template ::value>> - inline operator iterator_base() const { - return iterator_base(ptr_, end_); - } - - private: - T* ptr_; - T* end_; - }; - - using iterator = iterator_base; - using const_iterator = iterator_base; - - inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {} - inline Span(T* begin, T* end) : begin_{begin}, end_{end} {} - - inline iterator begin() const { return iterator(begin_, end_); } - - inline iterator end() const { return iterator(end_, end_); } - - size_t size() const { return end_ - begin_; } - - inline W operator[](int i) { - T* to_return = begin_ + i; - ICHECK_LT(to_return, end_) << "Span access out of bounds: " << i; - return W(*to_return); - } - - inline operator std::vector() { return std::vector(begin(), end()); } - - protected: - T* begin_; - T* end_; -}; - -} // namespace support -} // namespace tvm - -#endif // TVM_SUPPORT_SPAN_H_ diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index faa58f84870a..54f09a081b93 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -25,7 +25,7 @@ #define TVM_TARGET_CODEGEN_H_ #include -#include +#include #include #include diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index b851d1a8780d..86e90a7ce2db 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -103,7 +103,7 @@ class TargetNode : public Object { * \tparam TObjectRef Type of the attribute * \param attr_key The name of the attribute key * \param default_value The value returned if the key is not present - * \return An optional, NullOpt if not found, otherwise the value found + * \return An optional, std::nullopt if not found, otherwise the value found */ template Optional GetAttr( @@ -121,7 +121,7 @@ class TargetNode : public Object { * \tparam TObjectRef Type of the attribute * \param attr_key The name of the attribute key * \param default_value The value returned if the key is not present - * \return An optional, NullOpt if not found, otherwise the value found + * \return An optional, std::nullopt if not found, otherwise the value found */ template Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 32b3231781f4..a21112b7d6f6 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -314,7 +314,7 @@ TVM_DLL tvm::Map> CalculateAllocatedBytes(cons * The LCA may be a For loop or a Block. * \param func The PrimFunc to be detected. * \return The Map from buffer to the LCA of all access to it. The lca is function root if the - * return stmt is NullOpt. + * return stmt is std::nullopt. */ TVM_DLL Map> DetectBufferAccessLCA(const PrimFunc& func); @@ -421,7 +421,7 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); * \returns The pass. * \sa tvm::tir::CalculateAllocatedBytes */ -TVM_DLL Pass VerifyVTCMLimit(Optional target = NullOpt); +TVM_DLL Pass VerifyVTCMLimit(Optional target = std::nullopt); /*! * \brief Statically check TIR code for out of bounds array access. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 65bb88082a49..e0a197d41ff8 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -24,10 +24,10 @@ #ifndef TVM_TIR_BUFFER_H_ #define TVM_TIR_BUFFER_H_ +#include +#include #include #include -#include -#include #include #include @@ -204,7 +204,7 @@ class Buffer : public ObjectRef { */ TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0), - Optional input_extent = NullOpt) const; + Optional input_extent = std::nullopt) const; /*! * \brief Create an Expr that does a vector load at begin index. * \param begin The beginning index @@ -213,7 +213,7 @@ class Buffer : public ObjectRef { * loaded. The number lanes of the mask must be equal to the number of lanes in being loaded. */ TVM_DLL PrimExpr vload(Array begin, DataType dtype, - Optional predicate = NullOpt) const; + Optional predicate = std::nullopt) const; /*! * \brief Create a Stmt that does a vector store at begin index. * \param begin The beginning index @@ -222,7 +222,7 @@ class Buffer : public ObjectRef { * stored. The number lanes of the mask must be equal to the number of lanes in value. */ TVM_DLL Stmt vstore(Array begin, PrimExpr value, - Optional predicate = NullOpt) const; + Optional predicate = std::nullopt) const; /*! * \brief Get a flattened version of the buffer diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 2035a511c1bb..c057422a0266 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -415,17 +415,15 @@ TVM_DLL const Op& tvm_thread_invariant(); * type codes are explicitly allocated. * * return_type tvm_call_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, + * TVMFFIAny* args_stack, * int begin, * int end) { * ModuleNode* env = GetCurrentEnv(); * const ffi::Function* f = env->GetFuncFromEnv(name); - * f->CallPacked(ffi::PackedArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * ffi::Any(value_stack + end, tcode_stack + end)); + * f->CallPacked(ffi::PackedArgs(args_stack[begin:end]), + * ffi::Any(args_stack + end)); * // return type can be int, float, handle. - * return cast(return_type, load_return_from(tcode_stack + end)) + * return cast(return_type, load_return_from(args_stack + end)) * } */ TVM_DLL const Op& tvm_call_packed_lowered(); @@ -451,17 +449,15 @@ TVM_DLL const Op& tvm_call_cpacked_lowered(); * (end - 1) value on the stack. * * return_type tvm_call_trace_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, + * TVMFFIAny* args_stack, * int begin, * int end) { * ModuleNode* env = GetCurrentEnv(); * const ffi::Function* f = env->GetFuncFromEnv(name); - * f->CallPacked(ffi::PackedArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * ffi::Any(value_stack + end, tcode_stack + end)); + * f->CallPacked(ffi::PackedArgs(args_stack[begin:end]), + * ffi::Any(args_stack + end)); * // return type can be int, float, handle. - * return cast(return_type, load_return_from(tcode_stack + end)) + * return cast(return_type, load_return_from(args_stack + end)) * } */ TVM_DLL const Op& tvm_call_trace_packed_lowered(); diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index d2b674a80139..5f058f7d5e4c 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -25,13 +25,13 @@ #ifndef TVM_TIR_EXPR_H_ #define TVM_TIR_EXPR_H_ +#include +#include +#include #include #include #include -#include -#include -#include -#include +#include #include #include #include @@ -681,7 +681,7 @@ class BufferLoadNode : public PrimExprNode { class BufferLoad : public PrimExpr { public: TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, - Optional predicate = NullOpt, Span span = Span()); + Optional predicate = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index f85c0ed706ef..edb88bcafe55 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -24,9 +24,9 @@ #ifndef TVM_TIR_FUNCTION_H_ #define TVM_TIR_FUNCTION_H_ +#include +#include #include -#include -#include #include #include #include diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index 319b5193e8f6..1a5bdd8e4018 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -26,8 +26,8 @@ #ifndef TVM_TIR_INDEX_MAP_H_ #define TVM_TIR_INDEX_MAP_H_ +#include #include -#include #include #include @@ -182,7 +182,7 @@ class IndexMap : public ObjectRef { * \param inverse_index_map The optional pre-defined inverse index map */ IndexMap(Array initial_indices, Array final_indices, - Optional inverse_index_map = NullOpt); + Optional inverse_index_map = std::nullopt); /*! * \brief Create an index map from a packed function @@ -192,7 +192,7 @@ class IndexMap : public ObjectRef { * \return The created index map */ static IndexMap FromFunc(int ndim, ffi::TypedFunction(Array)> func, - Optional inverse_index_map = NullOpt); + Optional inverse_index_map = std::nullopt); /*! \brief Generate the inverse mapping. * diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 59d4cbbcd507..883477dd645e 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -28,9 +28,9 @@ #ifndef TVM_TIR_OP_ATTR_TYPES_H_ #define TVM_TIR_OP_ATTR_TYPES_H_ +#include +#include #include -#include -#include #include diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index f869a4840ce9..1a24644b5202 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -225,7 +225,7 @@ class ScheduleNode : public runtime::Object { * \return The random variable sampled from candidates */ virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) = 0; + Optional decision = std::nullopt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop * \param loop_rv The loop to be tiled @@ -235,7 +235,7 @@ class ScheduleNode : public runtime::Object { * \return A list of length `n`, the random perfect tile sizes sampled */ virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = NullOpt) = 0; + Optional> decision = std::nullopt) = 0; /*! * \brief Sample the factors to a partitioned tile for a specific loop * @@ -253,7 +253,7 @@ class ScheduleNode : public runtime::Object { */ virtual Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, - Optional> decision = NullOpt) = 0; + Optional> decision = std::nullopt) = 0; /*! * \brief Sample a compute-at location of the given block * \param block_rv The block whose compute-at location is to be sampled @@ -261,7 +261,7 @@ class ScheduleNode : public runtime::Object { * \return The sampled loop where the input block is to be computed at */ virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = NullOpt) = 0; + Optional decision = std::nullopt) = 0; /******** Schedule: Get blocks & loops ********/ /*! @@ -278,7 +278,8 @@ class ScheduleNode : public runtime::Object { * * \sa WorkOn */ - virtual BlockRV GetBlock(const String& name, const Optional& func_name = NullOpt) = 0; + virtual BlockRV GetBlock(const String& name, + const Optional& func_name = std::nullopt) = 0; /*! * \brief Get the parent loops of the block in its scope, from outer to inner * \param block_rv The query block @@ -347,14 +348,12 @@ class ScheduleNode : public runtime::Object { * 1) The loop can't have annotation or thread binding. * 2) The loop must start with 0. * \param loop_rv The loop to be split - * \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means - * that factor is inferred. - * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings. - * \param disable_predication If enabled, don't create a predicate for guarding the - * loop. This can be useful when splitting with scalable factors that the schedule writer - * knows are divisible by the loop bound. - * Warning: enabling this feature may result in incorrect code generation if not used carefully. - * \return The new loops after split. + * \param factors The positive tiling factors, and at most one of which is `std::nullopt`, which + * means that factor is inferred. \param preserve_unit_iters Whether or not to preserve unit + * iterators in block bindings. \param disable_predication If enabled, don't create a predicate + * for guarding the loop. This can be useful when splitting with scalable factors that the + * schedule writer knows are divisible by the loop bound. Warning: enabling this feature may + * result in incorrect code generation if not used carefully. \return The new loops after split. */ virtual Array Split(const LoopRV& loop_rv, const Array>& factors, bool preserve_unit_iters = true, @@ -363,7 +362,7 @@ class ScheduleNode : public runtime::Object { * \brief Partition the loops into sequence of multiple loops * 1) The loop can't have annotation or thread binding. * \param loop_rv The loop to be partition - * \param factors The positive integers, and at most one of which is `NullOpt`, which means + * \param factors The positive integers, and at most one of which is `std::nullopt`, which means * that factor is inferred. * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings * \return The new loops after partition @@ -761,7 +760,7 @@ class ScheduleNode : public runtime::Object { */ virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, const IndexMap& index_map, - const Optional& pad_value = NullOpt, + const Optional& pad_value = std::nullopt, bool assume_injective_transform = false) = 0; /*! diff --git a/include/tvm/tir/schedule/trace.h b/include/tvm/tir/schedule/trace.h index 79a2f8b2a08e..fca5966e198b 100644 --- a/include/tvm/tir/schedule/trace.h +++ b/include/tvm/tir/schedule/trace.h @@ -92,7 +92,7 @@ class TraceNode : public runtime::Object { void Append(Instruction inst, Any decision); /*! * \brief Remove the last instruction, along with the decision made on that instruction, if any - * \return The instruction removed; NullOpt if the trace is empty + * \return The instruction removed; std::nullopt if the trace is empty */ Optional Pop(); /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index efcf4a47d61c..6d93a3a153ad 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -265,7 +265,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate = NullOpt, Span span = Span()); + Optional predicate = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); @@ -818,16 +818,16 @@ class SeqStmt : public Stmt { return t; } if constexpr (!std::is_base_of_v) { - return NullOpt; + return std::nullopt; } if constexpr (std::is_base_of_v) { if (const SeqStmtNode* ptr = t.template as()) { return GetRef(ptr); } else { - return NullOpt; + return std::nullopt; } } - return NullOpt; + return std::nullopt; } template @@ -925,7 +925,7 @@ class IfThenElseNode : public StmtNode { */ class IfThenElse : public Stmt { public: - TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case = NullOpt, + TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case = std::nullopt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); @@ -1035,7 +1035,7 @@ class ForNode : public StmtNode { class For : public Stmt { public: TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, - Optional thread_binding = NullOpt, + Optional thread_binding = std::nullopt, Map annotations = Map(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); @@ -1280,7 +1280,7 @@ class BlockNode : public StmtNode { * reduction block. The optional init field allows us to represent initialization and * reduction update in a single block and transform them collectively. * We also provide primitives to decompose the init into a separate block during scheduling. - * Init field is `NullOpt` if there is no reduction iter_vars + * Init field is `std::nullopt` if there is no reduction iter_vars */ Optional init; /*! \brief The buffer allocated in the block. */ @@ -1334,7 +1334,7 @@ class Block : public Stmt { public: TVM_DLL explicit Block(Array iter_vars, Array reads, Array writes, String name_hint, Stmt body, - Optional init = NullOpt, + Optional init = std::nullopt, Array alloc_buffers = Array(), Array match_buffers = Array(), Map annotations = Map(), diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 9ce49610ee34..141fe710b371 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -226,7 +226,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { } else { // Make a new copy of the node. // need to rely on the default copy constructor - return runtime::make_object(*node); + return ffi::make_object(*node); } } /*! @@ -331,13 +331,13 @@ class StmtExprMutator : public StmtMutator, public ExprMutator { * won't do further recursion. * \param postorder The function called after recursive mutation. * The recursive mutation result is passed to postorder for further mutation. - * \param only_enable List of runtime::String. + * \param only_enable List of String. * If it is null, all IRNode will call preorder/postorder * If it is not null, preorder/postorder will only be called * when the IRNode's type key is in the list. */ TVM_DLL Stmt IRTransform(Stmt stmt, const ffi::Function& preorder, const ffi::Function& postorder, - Optional> only_enable = NullOpt); + Optional> only_enable = std::nullopt); /*! * \brief Recursively visit the ir in post DFS order node, apply fvisit @@ -418,7 +418,7 @@ auto Substitute(Obj&& obj, const Map& vmap) { if (auto opt = vmap.Get(var)) { return opt.value(); } else { - return NullOpt; + return std::nullopt; } }; return Substitute(std::forward(obj), func); @@ -440,7 +440,7 @@ auto Substitute(Obj&& obj, const std::unordered_map& vmap) if (auto it = vmap.find(var.get()); it != vmap.end()) { return it->second; } else { - return NullOpt; + return std::nullopt; } }; return Substitute(std::forward(obj), func); @@ -462,7 +462,7 @@ auto Substitute(Obj&& obj, const std::unordered_mapsecond; } else { - return NullOpt; + return std::nullopt; } }; return Substitute(std::forward(obj), func); @@ -489,7 +489,7 @@ auto Substitute(Obj&& obj, const std::unordered_map& iter_vmap) { if (auto it = vmap.find(var.get()); it != vmap.end()) { return it->second; } else { - return NullOpt; + return std::nullopt; } }; return Substitute(std::forward(obj), func); diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index b80d4456c0be..eb64d87f9518 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -166,9 +166,9 @@ TVM_DLL Pass InstrumentBoundCheckers(); * f() * * if num_packed_args is not zero: - * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, + * f(void *, TVMFFIAny* packed_args, int num_packed_args, * api_arg_k, api_arg_k+1, ... api_arg_n, - * TVMValue* out_ret_val, int* out_ret_tcode) + * TVMFFIAny* out_ret_val) * * where n == len(api_args), k == num_packed_args * diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 6e2edb3f304a..8e13ae49afdf 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -352,7 +352,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ Map attrs; if (pool_type == kMaxPool) { - attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_max")); + attrs.Set("schedule_rule", tvm::String("meta_schedule.adaptive_pool_max")); return tvm::te::compute( out_shape, [&](const Array& output) { @@ -363,7 +363,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ }, "adaptive_pool_max", "adaptive_pool_max", attrs); } else if (pool_type == kAvgPool) { - attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_avg")); + attrs.Set("schedule_rule", tvm::String("meta_schedule.adaptive_pool_avg")); auto pool_sum = tvm::te::compute( out_shape, [&](const Array& output) { @@ -383,7 +383,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ PrimExpr divide_factor = tvm::cast(x->dtype, 1); for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); + divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent); } return div(pool_sum(indices), divide_factor); @@ -566,7 +566,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, Map attrs; if (pool_type == kMaxPool) { auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_max")); + attrs.Set("schedule_rule", tvm::String("meta_schedule.pool_max")); return tvm::te::compute( out_shape, [&](const Array& output) { @@ -581,7 +581,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, }, "pool_max", "pool_max", attrs); } else if (pool_type == kAvgPool) { - attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_avg")); + attrs.Set("schedule_rule", tvm::String("meta_schedule.pool_avg")); // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; diff --git a/include/tvm/topi/utils.h b/include/tvm/topi/utils.h index 368674c0b6cb..b5f2d6c38d61 100644 --- a/include/tvm/topi/utils.h +++ b/include/tvm/topi/utils.h @@ -24,8 +24,8 @@ #ifndef TVM_TOPI_UTILS_H_ #define TVM_TOPI_UTILS_H_ +#include #include -#include namespace tvm { namespace topi { @@ -37,7 +37,7 @@ inline Optional> ArrayOrInt(AnyView arg) { if (arg == nullptr) { return std::nullopt; } - if (auto opt_int = arg.as()) { + if (auto opt_int = arg.try_cast()) { Array result; result.push_back(opt_int.value()); return result; diff --git a/jvm/README.md b/jvm/README.md index 0f53f4e561a2..71c737a4d00a 100644 --- a/jvm/README.md +++ b/jvm/README.md @@ -39,7 +39,7 @@ TVM4J contains three modules: - core * It contains all the Java interfaces. - native - * The JNI native library is compiled in this module. It does not link TVM runtime library (libtvm\_runtime.so for Linux and libtvm\_runtime.dylib for OSX). Instead, you have to specify `libtvm.so.path` which contains the TVM runtime library as Java system property. + * The JNI native library is compiled in this module. Need to expose libtvm_runtime to LD_LIBRARY_PATH - assembly * It assembles Java interfaces (core), JNI library (native) and TVM runtime library together. The simplest way to integrate tvm4j in your project is to rely on this module. It automatically extracts the native library to a tempfile and load it. diff --git a/jvm/core/src/main/java/org/apache/tvm/Base.java b/jvm/core/src/main/java/org/apache/tvm/Base.java index f5e677a2e0b3..97ae274a565c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Base.java +++ b/jvm/core/src/main/java/org/apache/tvm/Base.java @@ -87,37 +87,8 @@ public RefTVMValue() { } System.err.println("libtvm4j loads successfully."); - - if (loadNativeRuntimeLib) { - String tvmLibFilename = System.getProperty("libtvm.so.path"); - if (tvmLibFilename == null || !new File(tvmLibFilename).isFile() - || _LIB.nativeLibInit(tvmLibFilename) != 0) { - try { - String runtimeLibname; - String os = System.getProperty("os.name"); - // ref: http://lopica.sourceforge.net/os.html - if (os.startsWith("Linux")) { - runtimeLibname = "libtvm_runtime.so"; - } else if (os.startsWith("Mac")) { - runtimeLibname = "libtvm_runtime.dylib"; - } else { - // TODO(yizhi) support windows later - throw new UnsatisfiedLinkError(os + " not supported currently"); - } - NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() { - @Override public void invoke(File target) { - System.err.println("Loading tvm runtime from " + target.getPath()); - checkCall(_LIB.nativeLibInit(target.getPath())); - } - }); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - } else { - _LIB.nativeLibInit(null); - } - + // always use linked lib + _LIB.nativeLibInit(null); Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { _LIB.shutdown(); @@ -170,7 +141,7 @@ private static void tryLoadLibraryXPU(String libname, String arch) throws Unsati */ public static void checkCall(int ret) throws TVMError { if (ret != 0) { - throw new TVMError(_LIB.tvmGetLastError()); + throw new TVMError(_LIB.tvmFFIGetLastError()); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/Device.java b/jvm/core/src/main/java/org/apache/tvm/Device.java index 70fe13cec906..2396df94fbf0 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Device.java +++ b/jvm/core/src/main/java/org/apache/tvm/Device.java @@ -17,18 +17,30 @@ package org.apache.tvm; +import org.apache.tvm.rpc.RPC; + import java.util.HashMap; import java.util.Map; -import org.apache.tvm.rpc.RPC; public class Device { /** * Provides the same information as the C++ enums DLDeviceType and * TVMDeviceExtType. */ - static final int kDLCPU = 1, kDLCUDA = 2, kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7, - kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, kDLROCMHost = 11, kDLExtDev = 12, - kDLCUDAManaged = 13, kDLOneAPI = 14, kDLWebGPU = 15, kDLHexagon = 16; + static final int kDLCPU = 1; + static final int kDLCUDA = 2; + static final int kDLCUDAHost = 3; + static final int kDLOpenCL = 4; + static final int kDLVulkan = 7; + static final int kDLMetal = 8; + static final int kDLVPI = 9; + static final int kDLROCM = 10; + static final int kDLROCMHost = 11; + static final int kDLExtDev = 12; + static final int kDLCUDAManaged = 13; + static final int kDLOneAPI = 14; + static final int kDLWebGPU = 15; + static final int kDLHexagon = 16; private static final Map DEVICE_TYPE_TO_NAME = new HashMap(); private static final Map DEVICE_NAME_TO_TYPE = new HashMap(); @@ -161,7 +173,8 @@ public Device(String deviceType, int deviceId) { */ public boolean exist() { TVMValue ret = - APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke(); + APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType) + .pushArg(deviceId).pushArg(0).invoke(); return ((TVMValueLong) ret).value != 0; } @@ -171,7 +184,8 @@ public boolean exist() { */ public long maxThreadsPerBlock() { TVMValue ret = - APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke(); + APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType) + .pushArg(deviceId).pushArg(1).invoke(); return ((TVMValueLong) ret).value; } @@ -181,8 +195,9 @@ public long maxThreadsPerBlock() { */ public long warpSize() { TVMValue ret = - APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke(); - return ((TVMValueLong) ret).value; + APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType) + .pushArg(deviceId).pushArg(2).invoke(); + return ret.asLong(); } /** diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index 594b35b0af68..ee6b8e8cf5c5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -24,23 +24,14 @@ /** * TVM Packed Function. */ -public class Function extends TVMValue { - final long handle; - public final boolean isResident; - private boolean isReleased = false; - +public class Function extends TVMObject { /** * Get registered function. * @param name full function name. * @return TVM function. */ public static Function getFunction(final String name) { - for (String fullName : listGlobalFuncNames()) { - if (fullName.equals(name)) { - return getGlobalFunc(fullName, true, false); - } - } - return null; + return getGlobalFunc(name, true); } /** @@ -49,22 +40,21 @@ public static Function getFunction(final String name) { */ private static List listGlobalFuncNames() { List names = new ArrayList(); - Base.checkCall(Base._LIB.tvmFuncListGlobalNames(names)); + Base.checkCall(Base._LIB.tvmFFIFunctionListGlobalNames(names)); return Collections.unmodifiableList(names); } /** * Get a global function by name. * @param name The name of the function. - * @param isResident Whether it is a global 'resident' function. * @param allowMissing Whether allow missing function or raise an error. * @return The function to be returned, None if function is missing. */ - private static Function getGlobalFunc(String name, boolean isResident, boolean allowMissing) { + private static Function getGlobalFunc(String name, boolean allowMissing) { Base.RefLong handle = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmFuncGetGlobal(name, handle)); + Base.checkCall(Base._LIB.tvmFFIFunctionGetGlobal(name, handle)); if (handle.value != 0) { - return new Function(handle.value, isResident); + return new Function(handle.value); } else { if (allowMissing) { return null; @@ -74,24 +64,8 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a } } - /** - * Initialize the function with handle. - * @param handle the handle to the underlying function. - * @param isResident Whether this is a resident function in jvm - */ - Function(long handle, boolean isResident) { - super(ArgTypeCode.FUNC_HANDLE); - this.handle = handle; - this.isResident = isResident; - } - Function(long handle) { - this(handle, false); - } - - @Override protected void finalize() throws Throwable { - release(); - super.finalize(); + super(handle, TypeIndex.kTVMFFIFunction); } /** @@ -102,32 +76,13 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a return this; } - @Override long asHandle() { - return handle; - } - - /** - * Release the Function. - *

- * We highly recommend you to do this manually since the GC strategy is lazy. - *

- */ - @Override public void release() { - if (!isReleased) { - if (!isResident) { - Base.checkCall(Base._LIB.tvmFuncFree(handle)); - isReleased = true; - } - } - } - /** * Invoke the function. * @return the result. */ public TVMValue invoke() { Base.RefTVMValue ret = new Base.RefTVMValue(); - Base.checkCall(Base._LIB.tvmFuncCall(handle, ret)); + Base.checkCall(Base._LIB.tvmFFIFunctionCall(handle, ret)); return ret.value; } @@ -137,7 +92,7 @@ public TVMValue invoke() { * @return this */ public Function pushArg(int arg) { - Base._LIB.tvmFuncPushArgLong(arg); + Base._LIB.tvmFFIFunctionPushArgLong(arg); return this; } @@ -147,7 +102,7 @@ public Function pushArg(int arg) { * @return this */ public Function pushArg(long arg) { - Base._LIB.tvmFuncPushArgLong(arg); + Base._LIB.tvmFFIFunctionPushArgLong(arg); return this; } @@ -157,7 +112,7 @@ public Function pushArg(long arg) { * @return this */ public Function pushArg(float arg) { - Base._LIB.tvmFuncPushArgDouble(arg); + Base._LIB.tvmFFIFunctionPushArgDouble(arg); return this; } @@ -167,7 +122,7 @@ public Function pushArg(float arg) { * @return this */ public Function pushArg(double arg) { - Base._LIB.tvmFuncPushArgDouble(arg); + Base._LIB.tvmFFIFunctionPushArgDouble(arg); return this; } @@ -177,7 +132,7 @@ public Function pushArg(double arg) { * @return this */ public Function pushArg(String arg) { - Base._LIB.tvmFuncPushArgString(arg); + Base._LIB.tvmFFIFunctionPushArgString(arg); return this; } @@ -187,8 +142,11 @@ public Function pushArg(String arg) { * @return this */ public Function pushArg(NDArrayBase arg) { - int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; - Base._LIB.tvmFuncPushArgHandle(arg.handle, id); + if (arg instanceof NDArray) { + Base._LIB.tvmFFIFunctionPushArgHandle(((NDArray) arg).handle, TypeIndex.kTVMFFINDArray); + } else { + Base._LIB.tvmFFIFunctionPushArgHandle(arg.dltensorHandle, TypeIndex.kTVMFFIDLTensorPtr); + } return this; } @@ -198,7 +156,7 @@ public Function pushArg(NDArrayBase arg) { * @return this */ public Function pushArg(Module arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFFIFunctionPushArgHandle(arg.handle, TypeIndex.kTVMFFIModule); return this; } @@ -208,7 +166,7 @@ public Function pushArg(Module arg) { * @return this */ public Function pushArg(Function arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFFIFunctionPushArgHandle(arg.handle, TypeIndex.kTVMFFIFunction); return this; } @@ -218,7 +176,7 @@ public Function pushArg(Function arg) { * @return this */ public Function pushArg(byte[] arg) { - Base._LIB.tvmFuncPushArgBytes(arg); + Base._LIB.tvmFFIFunctionPushArgBytes(arg); return this; } @@ -228,7 +186,7 @@ public Function pushArg(byte[] arg) { * @return this */ public Function pushArg(Device arg) { - Base._LIB.tvmFuncPushArgDevice(arg); + Base._LIB.tvmFFIFunctionPushArgDevice(arg); return this; } @@ -245,53 +203,44 @@ public TVMValue call(Object... args) { } private static void pushArgToStack(Object arg) { - if (arg instanceof Integer) { - Base._LIB.tvmFuncPushArgLong((Integer) arg); + if (arg instanceof NDArrayBase) { + NDArrayBase nd = (NDArrayBase) arg; + if (nd instanceof NDArray) { + Base._LIB.tvmFFIFunctionPushArgHandle(((NDArray) nd).handle, TypeIndex.kTVMFFINDArray); + } else { + Base._LIB.tvmFFIFunctionPushArgHandle(nd.dltensorHandle, TypeIndex.kTVMFFIDLTensorPtr); + } + } else if (arg instanceof TVMObject) { + TVMObject obj = (TVMObject) arg; + Base._LIB.tvmFFIFunctionPushArgHandle(obj.handle, obj.typeIndex); + } else if (arg instanceof Integer) { + Base._LIB.tvmFFIFunctionPushArgLong((Integer) arg); } else if (arg instanceof Long) { - Base._LIB.tvmFuncPushArgLong((Long) arg); + Base._LIB.tvmFFIFunctionPushArgLong((Long) arg); } else if (arg instanceof Float) { - Base._LIB.tvmFuncPushArgDouble((Float) arg); + Base._LIB.tvmFFIFunctionPushArgDouble((Float) arg); } else if (arg instanceof Double) { - Base._LIB.tvmFuncPushArgDouble((Double) arg); + Base._LIB.tvmFFIFunctionPushArgDouble((Double) arg); } else if (arg instanceof String) { - Base._LIB.tvmFuncPushArgString((String) arg); + Base._LIB.tvmFFIFunctionPushArgString((String) arg); } else if (arg instanceof byte[]) { - Base._LIB.tvmFuncPushArgBytes((byte[]) arg); - } else if (arg instanceof NDArrayBase) { - NDArrayBase nd = (NDArrayBase) arg; - int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; - Base._LIB.tvmFuncPushArgHandle(nd.handle, id); - } else if (arg instanceof Module) { - Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); - } else if (arg instanceof Function) { - Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFFIFunctionPushArgBytes((byte[]) arg); } else if (arg instanceof Device) { - Base._LIB.tvmFuncPushArgDevice((Device) arg); - } else if (arg instanceof TVMValue) { - TVMValue tvmArg = (TVMValue) arg; - switch (tvmArg.typeCode) { - case UINT: - case INT: - Base._LIB.tvmFuncPushArgLong(tvmArg.asLong()); - break; - case FLOAT: - Base._LIB.tvmFuncPushArgDouble(tvmArg.asDouble()); - break; - case STR: - Base._LIB.tvmFuncPushArgString(tvmArg.asString()); - break; - case BYTES: - Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes()); - break; - case HANDLE: - case ARRAY_HANDLE: - case MODULE_HANDLE: - case FUNC_HANDLE: - Base._LIB.tvmFuncPushArgHandle(tvmArg.asHandle(), tvmArg.typeCode.id); - break; - default: - throw new IllegalArgumentException("Invalid argument: " + arg); - } + Base._LIB.tvmFFIFunctionPushArgDevice((Device) arg); + } else if (arg instanceof TVMValueBytes) { + byte[] bytes = ((TVMValueBytes) arg).value; + Base._LIB.tvmFFIFunctionPushArgBytes(bytes); + } else if (arg instanceof TVMValueString) { + String str = ((TVMValueString) arg).value; + Base._LIB.tvmFFIFunctionPushArgString(str); + } else if (arg instanceof TVMValueDouble) { + double value = ((TVMValueDouble) arg).value; + Base._LIB.tvmFFIFunctionPushArgDouble(value); + } else if (arg instanceof TVMValueLong) { + long value = ((TVMValueLong) arg).value; + Base._LIB.tvmFFIFunctionPushArgLong(value); + } else if (arg instanceof TVMValueNull) { + Base._LIB.tvmFFIFunctionPushArgHandle(0, TypeIndex.kTVMFFINone); } else { throw new IllegalArgumentException("Invalid argument: " + arg); } @@ -309,9 +258,9 @@ public static interface Callback { */ public static void register(String name, Callback function, boolean override) { Base.RefLong createdFuncHandleRef = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef)); + Base.checkCall(Base._LIB.tvmFFIFunctionCreateFromCallback(function, createdFuncHandleRef)); int ioverride = override ? 1 : 0; - Base.checkCall(Base._LIB.tvmFuncRegisterGlobal(name, createdFuncHandleRef.value, ioverride)); + Base.checkCall(Base._LIB.tvmFFIFunctionSetGlobal(name, createdFuncHandleRef.value, ioverride)); } /** @@ -330,7 +279,7 @@ public static void register(String name, Callback function) { */ public static Function convertFunc(Callback function) { Base.RefLong createdFuncHandleRef = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef)); + Base.checkCall(Base._LIB.tvmFFIFunctionCreateFromCallback(function, createdFuncHandleRef)); return new Function(createdFuncHandleRef.value); } diff --git a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java index aede9be334c8..f471883ca5bc 100644 --- a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java +++ b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java @@ -24,55 +24,50 @@ class LibInfo { native int shutdown(); - native String tvmGetLastError(); + native String tvmFFIGetLastError(); - // Function - native void tvmFuncPushArgLong(long arg); - - native void tvmFuncPushArgDouble(double arg); - - native void tvmFuncPushArgString(String arg); - - native void tvmFuncPushArgBytes(byte[] arg); + // Object + native int tvmFFIObjectFree(long handle); - native void tvmFuncPushArgHandle(long arg, int argType); + // Function + native void tvmFFIFunctionPushArgLong(long arg); - native void tvmFuncPushArgDevice(Device device); + native void tvmFFIFunctionPushArgDouble(double arg); - native int tvmFuncListGlobalNames(List funcNames); + native void tvmFFIFunctionPushArgString(String arg); - native int tvmFuncFree(long handle); + native void tvmFFIFunctionPushArgBytes(byte[] arg); - native int tvmFuncGetGlobal(String name, Base.RefLong handle); + native void tvmFFIFunctionPushArgHandle(long arg, int argTypeIndex); - native int tvmFuncCall(long handle, Base.RefTVMValue retVal); + native void tvmFFIFunctionPushArgDevice(Device device); - native int tvmFuncCreateFromCFunc(Function.Callback function, Base.RefLong handle); + native int tvmFFIFunctionListGlobalNames(List funcNames); - native int tvmFuncRegisterGlobal(String name, long handle, int override); + native int tvmFFIFunctionGetGlobal(String name, Base.RefLong handle); - // Module - native int tvmModFree(long handle); + native int tvmFFIFunctionSetGlobal(String name, long handle, int override); - native int tvmModGetFunction(long handle, String name, - int queryImports, Base.RefLong retHandle); + native int tvmFFIFunctionCall(long handle, Base.RefTVMValue retVal); - native int tvmModImport(long mod, long dep); + native int tvmFFIFunctionCreateFromCallback(Function.Callback function, Base.RefLong handle); // NDArray - native int tvmArrayFree(long handle); - - native int tvmArrayAlloc(long[] shape, int dtypeCode, int dtypeBits, int dtypeLanes, - int deviceType, int deviceId, Base.RefLong refHandle); + native int tvmFFIDLTensorGetShape(long handle, List shape); - native int tvmArrayGetShape(long handle, List shape); + native int tvmFFIDLTensorCopyFromTo(long from, long to); - native int tvmArrayCopyFromTo(long from, long to); + native int tvmFFIDLTensorCopyFromJArray(byte[] fromRaw, long to); - native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to); - - native int tvmArrayCopyToJArray(long from, byte[] to); + native int tvmFFIDLTensorCopyToJArray(long from, byte[] to); + // the following functions are binded to keep things simpler + // One possibility is to enhance FFI to support shape directly + // so we do not need to run this binding through JNI // Device native int tvmSynchronize(int deviceType, int deviceId); + + native int tvmNDArrayEmpty(long[] shape, int dtypeCode, int dtypeBits, + int dtypeLanes, int deviceType, int deviceId, + Base.RefLong handle); } diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 0682a6595a3e..5e78e26ae739 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -23,10 +23,7 @@ /** * Container of compiled functions of TVM. */ -public class Module extends TVMValue { - public final long handle; - private boolean isReleased = false; - +public class Module extends TVMObject { private static ThreadLocal> apiFuncs = new ThreadLocal>() { @Override @@ -45,17 +42,12 @@ private static Function getApi(String name) { } Module(long handle) { - super(ArgTypeCode.MODULE_HANDLE); - this.handle = handle; + super(handle, TypeIndex.kTVMFFIModule); } private Function entry = null; private final String entryName = "__tvm_main__"; - @Override protected void finalize() throws Throwable { - release(); - super.finalize(); - } /** * Easy for user to get the instance from returned TVMValue. @@ -65,23 +57,6 @@ private static Function getApi(String name) { return this; } - @Override long asHandle() { - return handle; - } - - /** - * Release the Module. - *

- * We highly recommend you to do this manually since the GC strategy is lazy. - *

- */ - @Override public void release() { - if (!isReleased) { - Base.checkCall(Base._LIB.tvmModFree(handle)); - isReleased = true; - } - } - /** * Get the entry function. * @return The entry function if exist @@ -100,13 +75,9 @@ public Function entryFunc() { * @return The result function. */ public Function getFunction(String name, boolean queryImports) { - Base.RefLong retHandle = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmModGetFunction( - handle, name, queryImports ? 1 : 0, retHandle)); - if (retHandle.value == 0) { - throw new IllegalArgumentException("Module has no function " + name); - } - return new Function(retHandle.value, false); + TVMValue ret = getApi("ModuleGetFunction") + .pushArg(this).pushArg(name).pushArg(queryImports ? 1 : 0).invoke(); + return ret.asFunction(); } public Function getFunction(String name) { @@ -118,7 +89,8 @@ public Function getFunction(String name) { * @param module The other module. */ public void importModule(Module module) { - Base.checkCall(Base._LIB.tvmModImport(handle, module.handle)); + getApi("ModuleImport") + .pushArg(this).pushArg(module).invoke(); } /** @@ -138,7 +110,6 @@ public String typeKey() { */ public static Module load(String path, String fmt) { TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); - assert ret.typeCode == ArgTypeCode.MODULE_HANDLE; return ret.asModule(); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArray.java b/jvm/core/src/main/java/org/apache/tvm/NDArray.java index 68020db03999..6b151d7bf9d2 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArray.java +++ b/jvm/core/src/main/java/org/apache/tvm/NDArray.java @@ -35,11 +35,6 @@ public class NDArray extends NDArrayBase { this.device = dev; } - @Override - protected void finalize() throws Throwable { - super.finalize(); - } - /** * Copy from a native array. * The NDArray type must by float64 @@ -54,9 +49,7 @@ public void copyFrom(double[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putDouble(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -73,9 +66,7 @@ public void copyFrom(float[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putFloat(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -92,9 +83,7 @@ public void copyFrom(long[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putLong(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -111,9 +100,7 @@ public void copyFrom(int[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putInt(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -130,9 +117,7 @@ public void copyFrom(short[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putShort(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } /** @@ -162,9 +147,7 @@ public void copyFrom(char[] sourceArray) { for (int i = 0; i < sourceArray.length; ++i) { wrapBytes(nativeArr, i * dtype.numOfBytes, dtype.numOfBytes).putChar(sourceArray[i]); } - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(nativeArr, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(nativeArr, this.dltensorHandle)); } private void checkCopySize(int sourceLength) { @@ -180,9 +163,7 @@ private void checkCopySize(int sourceLength) { * @param sourceArray the source data */ public void copyFromRaw(byte[] sourceArray) { - NDArray tmpArr = empty(shape(), this.dtype); - Base.checkCall(Base._LIB.tvmArrayCopyFromJArray(sourceArray, tmpArr.handle, handle)); - tmpArr.release(); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromJArray(sourceArray, this.dltensorHandle)); } /** @@ -191,7 +172,7 @@ public void copyFromRaw(byte[] sourceArray) { */ public long[] shape() { List data = new ArrayList(); - Base.checkCall(Base._LIB.tvmArrayGetShape(handle, data)); + Base.checkCall(Base._LIB.tvmFFIDLTensorGetShape(this.dltensorHandle, data)); long[] shapeArr = new long[data.size()]; for (int i = 0; i < shapeArr.length; ++i) { shapeArr[i] = data.get(i); @@ -343,7 +324,7 @@ public byte[] internal() { int arrLength = dtype.numOfBytes * (int) size(); byte[] arr = new byte[arrLength]; - Base.checkCall(Base._LIB.tvmArrayCopyToJArray(tmp.handle, arr)); + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyToJArray(this.dltensorHandle, arr)); return arr; } @@ -380,8 +361,9 @@ public Device device() { */ public static NDArray empty(long[] shape, TVMType dtype, Device dev) { Base.RefLong refHandle = new Base.RefLong(); - Base.checkCall(Base._LIB.tvmArrayAlloc( - shape, dtype.typeCode, dtype.bits, dtype.lanes, dev.deviceType, dev.deviceId, refHandle)); + Base.checkCall(Base._LIB.tvmNDArrayEmpty( + shape, dtype.typeCode, dtype.bits, + dtype.lanes, dev.deviceType, dev.deviceId, refHandle)); return new NDArray(refHandle.value, false, dtype, dev); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java index 26bb735e1a5b..534dcb38d4a9 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java +++ b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java @@ -22,50 +22,27 @@ * Only deep-copy supported. */ public class NDArrayBase extends TVMValue { - protected final long handle; - protected final boolean isView; - private boolean isReleased = false; + protected long handle; + public final boolean isView; + protected final long dltensorHandle; NDArrayBase(long handle, boolean isView) { - super(ArgTypeCode.ARRAY_HANDLE); - this.handle = handle; + this.dltensorHandle = isView ? handle : handle + 8 * 2; + this.handle = isView ? 0 : handle; this.isView = isView; } - NDArrayBase(long handle) { - this(handle, true); - } - @Override public NDArrayBase asNDArray() { return this; } - @Override long asHandle() { - return handle; - } - - /** - * Copy array to target. - * @param target The target array to be copied, must have same shape as this array. - * @return target - */ - public NDArrayBase copyTo(NDArrayBase target) { - Base.checkCall(Base._LIB.tvmArrayCopyFromTo(handle, target.handle)); - return target; - } - /** - * Release the NDArray memory. - *

- * We highly recommend you to do this manually since the GC strategy is lazy. - *

+ * Release the NDArray. */ public void release() { - if (!isReleased) { - if (!isView) { - Base.checkCall(Base._LIB.tvmArrayFree(handle)); - isReleased = true; - } + if (this.handle != 0) { + Base.checkCall(Base._LIB.tvmFFIObjectFree(this.handle)); + this.handle = 0; } } @@ -73,4 +50,14 @@ public void release() { release(); super.finalize(); } + + /** + * Copy array to target. + * @param target The target array to be copied, must have same shape as this array. + * @return target + */ + public NDArrayBase copyTo(NDArrayBase target) { + Base.checkCall(Base._LIB.tvmFFIDLTensorCopyFromTo(this.dltensorHandle, target.dltensorHandle)); + return target; + } } diff --git a/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java b/jvm/core/src/main/java/org/apache/tvm/TVMObject.java similarity index 65% rename from jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java rename to jvm/core/src/main/java/org/apache/tvm/TVMObject.java index ed6d0f1a0e12..c2b3f0eb497f 100644 --- a/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMObject.java @@ -17,20 +17,25 @@ package org.apache.tvm; -// Type code used in API calls -public enum ArgTypeCode { - INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5), - DLDEVICE(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9), - FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13); +/** + * Base class of all TVM Objects. + */ +public class TVMObject extends TVMValue { + protected long handle; + public final int typeIndex; - public final int id; + public TVMObject(long handle, int typeIndex) { + this.handle = handle; + this.typeIndex = typeIndex; + } - private ArgTypeCode(int id) { - this.id = id; + public void release() { + Base.checkCall(Base._LIB.tvmFFIObjectFree(this.handle)); + this.handle = 0; } - @Override - public String toString() { - return String.valueOf(id); + @Override protected void finalize() throws Throwable { + release(); + super.finalize(); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java index d30cfcc4f30a..45aef808f44c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java @@ -18,10 +18,8 @@ package org.apache.tvm; public class TVMValue { - public final ArgTypeCode typeCode; + protected TVMValue() { - public TVMValue(ArgTypeCode tc) { - typeCode = tc; } public void release() { diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java index 132d88f7622b..253dcbe66c87 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java @@ -21,7 +21,6 @@ public class TVMValueBytes extends TVMValue { public final byte[] value; public TVMValueBytes(byte[] value) { - super(ArgTypeCode.BYTES); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java index 9db4c3bb0e8c..16351b3244ea 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java @@ -21,7 +21,6 @@ public class TVMValueDouble extends TVMValue { public final double value; public TVMValueDouble(double value) { - super(ArgTypeCode.FLOAT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java index b91f55e2f59b..849510ec3078 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java @@ -24,7 +24,6 @@ public class TVMValueHandle extends TVMValue { public final long value; public TVMValueHandle(long value) { - super(ArgTypeCode.HANDLE); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java index 8a9b157d3961..0c232adf42b8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java @@ -21,7 +21,6 @@ public class TVMValueLong extends TVMValue { public final long value; public TVMValueLong(long value) { - super(ArgTypeCode.INT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java index 8c49ee5b3df5..45e85a160728 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java @@ -19,6 +19,5 @@ public class TVMValueNull extends TVMValue { public TVMValueNull() { - super(ArgTypeCode.NULL); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java index 46926e7d3fc6..c93a5600931e 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java @@ -21,7 +21,6 @@ public class TVMValueString extends TVMValue { public final String value; public TVMValueString(String value) { - super(ArgTypeCode.STR); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java b/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java new file mode 100644 index 000000000000..97169bb6c58c --- /dev/null +++ b/jvm/core/src/main/java/org/apache/tvm/TypeIndex.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tvm; + +// Type code used in API calls +public class TypeIndex { + public static final int kTVMFFINone = 0; + public static final int kTVMFFIInt = 1; + public static final int kTVMFFIBool = 2; + public static final int kTVMFFIFloat = 3; + public static final int kTVMFFIOpaquePtr = 4; + public static final int kTVMFFIDataType = 5; + public static final int kTVMFFIDevice = 6; + public static final int kTVMFFIDLTensorPtr = 7; + public static final int kTVMFFIRawStr = 8; + public static final int kTVMFFIByteArrayPtr = 9; + public static final int kTVMFFIObjectRValueRef = 10; + public static final int kTVMFFIStaticObjectBegin = 64; + public static final int kTVMFFIObject = 64; + public static final int kTVMFFIStr = 65; + public static final int kTVMFFIBytes = 66; + public static final int kTVMFFIError = 67; + public static final int kTVMFFIFunction = 68; + public static final int kTVMFFIArray = 69; + public static final int kTVMFFIMap = 70; + public static final int kTVMFFIShape = 71; + public static final int kTVMFFINDArray = 72; + public static final int kTVMFFIModule = 73; +} diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java index 69321c3b51c8..4b20362c2a47 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java @@ -20,6 +20,9 @@ import org.apache.tvm.Function; import org.apache.tvm.TVMValue; +/** + * RPC Client. + */ public class Client { /** * Connect to RPC Server. diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java index 07278f07b8c2..f3cf95f931cb 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java @@ -17,15 +17,16 @@ package org.apache.tvm.rpc; +import org.apache.tvm.Device; +import org.apache.tvm.Function; +import org.apache.tvm.Module; + import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.HashMap; import java.util.Map; -import org.apache.tvm.Device; -import org.apache.tvm.Function; -import org.apache.tvm.Module; /** * RPC Client session module. diff --git a/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java b/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java index 9ffcc5ab65ea..c2a1f78fa432 100644 --- a/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/FunctionTest.java @@ -43,6 +43,8 @@ public void test_reg_sum_number() { @Test public void test_add_string() { + System.err.println("[TEST] test_add_string"); + Function func = Function.convertFunc(new Function.Callback() { @Override public Object invoke(TVMValue... args) { String res = ""; diff --git a/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java b/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java index b9538ca96b5d..888cd18923be 100644 --- a/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/ModuleTest.java @@ -71,8 +71,6 @@ public void test_load_add_func_cuda() { } Module fadd = Module.load(loadingDir + File.separator + "add_cuda.so"); - Module faddDev = Module.load(loadingDir + File.separator + "add_cuda.ptx"); - fadd.importModule(faddDev); final int dim = 100; long[] shape = new long[]{dim}; @@ -93,7 +91,6 @@ public void test_load_add_func_cuda() { arr.release(); res.release(); - faddDev.release(); fadd.release(); } } diff --git a/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java b/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java index 641633def8a0..ca24c123da8e 100644 --- a/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java +++ b/jvm/core/src/test/java/org/apache/tvm/rpc/RPCTest.java @@ -31,6 +31,7 @@ public class RPCTest { private final Logger logger = LoggerFactory.getLogger(RPCTest.class); + @Ignore("RPC test is not enabled") @Test public void test_addone() { if (!Module.enabled("rpc")) { @@ -57,6 +58,7 @@ public void test_addone() { } } + @Ignore("RPC test is not enabled") @Test public void test_strcat() { if (!Module.enabled("rpc")) { diff --git a/jvm/core/src/test/scripts/prepare_test_libs.py b/jvm/core/src/test/scripts/prepare_test_libs.py new file mode 100644 index 000000000000..550082adb816 --- /dev/null +++ b/jvm/core/src/test/scripts/prepare_test_libs.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# Prepare test library for standalone wasm runtime test. + +import sys +import os +import tvm +from tvm import te +from tvm import relax +from tvm.script import relax as R + + +def prepare_relax_lib(base_path): + pipeline = relax.get_pipeline() + + @tvm.script.ir_module + class Mod: + @R.function + def main(x: R.Tensor(["n"], "float32"), y: R.Tensor(["n"], "float32")): + lv0 = R.add(x, y) + return lv0 + + target = tvm.target.Target("llvm") + + mod = pipeline(Mod) + ex = relax.build(mod, target) + relax_path = os.path.join(base_path, "add_relax.so") + ex.export_library(relax_path) + + +def prepare_cpu_lib(base_path): + target = "llvm" + if not tvm.runtime.enabled(target): + raise RuntimeError("Target %s is not enbaled" % target) + n = te.var("n") + A = te.placeholder((n,), name="A") + B = te.placeholder((n,), name="B") + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") + mod = tvm.IRModule.from_expr(te.create_prim_func([A, B, C]).with_attr("global_symbol", "myadd")) + fadd = tvm.build(mod, target) + lib_path = os.path.join(base_path, "add_cpu.so") + fadd.export_library(lib_path) + + +def prepare_gpu_lib(base_path): + if not tvm.cuda().exist: + print("CUDA is not enabled, skip the generation") + return + n = te.var("n") + A = te.placeholder((n,), name="A") + B = te.placeholder((n,), name="B") + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") + mod = tvm.IRModule.from_expr(te.create_prim_func([A, B, C]).with_attr("global_symbol", "myadd")) + sch = tvm.tir.Schedule(mod) + sch.work_on("myadd") + (i,) = sch.get_loops(block=sch.get_block("C")) + i0, i1 = sch.split(i, [None, 32]) + sch.bind(i0, "blockIdx.x") + sch.bind(i1, "threadIdx.x") + fadd = tvm.build(sch.mod, "cuda") + lib_path = os.path.join(base_path, "add_cuda.so") + fadd.export_library(lib_path) + + +if __name__ == "__main__": + base_path = sys.argv[1] + prepare_cpu_lib(base_path) + prepare_gpu_lib(base_path) + prepare_relax_lib(base_path) diff --git a/jvm/native/linux-x86_64/pom.xml b/jvm/native/linux-x86_64/pom.xml index 10d9a1bbfe3c..c21a3d2ae5af 100644 --- a/jvm/native/linux-x86_64/pom.xml +++ b/jvm/native/linux-x86_64/pom.xml @@ -127,6 +127,8 @@ under the License. -shared + -L${project.parent.basedir}/../../build/ + -ltvm_runtime ${ldflags} diff --git a/jvm/native/osx-x86_64/pom.xml b/jvm/native/osx-x86_64/pom.xml index ef28537b98f2..e2bd0fd7ae9d 100644 --- a/jvm/native/osx-x86_64/pom.xml +++ b/jvm/native/osx-x86_64/pom.xml @@ -134,6 +134,8 @@ under the License. -Wl,-x + -L${project.parent.basedir}/../../build/ + -ltvm_runtime ${ldflags} diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index 3e44f757392d..76520d43f7a9 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -113,8 +113,8 @@ jobject newTVMValueDouble(JNIEnv* env, jdouble value) { return object; } -jobject newTVMValueString(JNIEnv* env, const char* value) { - jstring jvalue = env->NewStringUTF(value); +jobject newTVMValueString(JNIEnv* env, const TVMFFIByteArray* value) { + jstring jvalue = env->NewStringUTF(value->data); jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jmethodID constructor = env->GetMethodID(cls, "", "(Ljava/lang/String;)V"); jobject object = env->NewObject(cls, constructor, jvalue); @@ -123,7 +123,7 @@ jobject newTVMValueString(JNIEnv* env, const char* value) { return object; } -jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) { +jobject newTVMValueBytes(JNIEnv* env, const TVMFFIByteArray* arr) { jbyteArray jarr = env->NewByteArray(arr->size); env->SetByteArrayRegion(jarr, 0, arr->size, reinterpret_cast(const_cast(arr->data))); @@ -159,14 +159,22 @@ jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) { return object; } -jobject newObject(JNIEnv* env, const char* clsname) { - jclass cls = env->FindClass(clsname); +jobject newTVMNull(JNIEnv* env) { + jclass cls = env->FindClass("org/apache/tvm/TVMValueNull"); jmethodID constructor = env->GetMethodID(cls, "", "()V"); jobject object = env->NewObject(cls, constructor); env->DeleteLocalRef(cls); return object; } +jobject newTVMObject(JNIEnv* env, jlong handle, jint type_index) { + jclass cls = env->FindClass("org/apache/tvm/TVMObject"); + jmethodID constructor = env->GetMethodID(cls, "", "(JI)V"); + jobject object = env->NewObject(cls, constructor, handle, type_index); + env->DeleteLocalRef(cls); + return object; +} + void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) { jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType"); dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I"))); @@ -184,55 +192,56 @@ void fromJavaDevice(JNIEnv* env, jobject jdev, DLDevice* dev) { env->DeleteLocalRef(deviceClass); } -jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { - switch (tcode) { - case kDLUInt: - case kDLInt: - case kTVMArgBool: +jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) { + using tvm::ffi::TypeIndex; + switch (value.type_index) { + case TypeIndex::kTVMFFINone: { + return newTVMNull(env); + } + case TypeIndex::kTVMFFIBool: { + // use long for now to represent bool + return newTVMValueLong(env, static_cast(value.v_int64)); + } + case TypeIndex::kTVMFFIInt: { return newTVMValueLong(env, static_cast(value.v_int64)); - case kDLFloat: + } + case TypeIndex::kTVMFFIFloat: { return newTVMValueDouble(env, static_cast(value.v_float64)); - case kTVMOpaqueHandle: - return newTVMValueHandle(env, reinterpret_cast(value.v_handle)); - case kTVMModuleHandle: - return newModule(env, reinterpret_cast(value.v_handle)); - case kTVMPackedFuncHandle: - return newFunction(env, reinterpret_cast(value.v_handle)); - case kTVMDLTensorHandle: - return newNDArray(env, reinterpret_cast(value.v_handle), true); - case kTVMNDArrayHandle: - return newNDArray(env, reinterpret_cast(value.v_handle), false); - case kTVMStr: - return newTVMValueString(env, value.v_str); - case kTVMBytes: - return newTVMValueBytes(env, reinterpret_cast(value.v_handle)); - case kTVMNullptr: - return newObject(env, "org/apache/tvm/TVMValueNull"); - default: - LOG(FATAL) << "Do NOT know how to handle return type code " << tcode; - } - return NULL; -} - -// Helper function to pack two int32_t values into an int64_t -inline int64_t deviceToInt64(const int32_t device_type, const int32_t device_id) { - int64_t result; - int32_t* parts = reinterpret_cast(&result); - - // Lambda function to check endianness - const auto isLittleEndian = []() -> bool { - uint32_t i = 1; - return *reinterpret_cast(&i) == 1; - }; - - if (isLittleEndian()) { - parts[0] = device_type; - parts[1] = device_id; - } else { - parts[1] = device_type; - parts[0] = device_id; + } + case TypeIndex::kTVMFFIOpaquePtr: { + return newTVMValueHandle(env, reinterpret_cast(value.v_ptr)); + } + case TypeIndex::kTVMFFIModule: { + return newModule(env, reinterpret_cast(value.v_obj)); + } + case TypeIndex::kTVMFFIFunction: { + return newFunction(env, reinterpret_cast(value.v_obj)); + } + case TypeIndex::kTVMFFIDLTensorPtr: { + return newNDArray(env, reinterpret_cast(value.v_ptr), true); + } + case TypeIndex::kTVMFFINDArray: { + return newNDArray(env, reinterpret_cast(value.v_obj), false); + } + case TypeIndex::kTVMFFIStr: { + jobject ret = newTVMValueString(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); + TVMFFIObjectFree(value.v_obj); + return ret; + } + case TypeIndex::kTVMFFIBytes: { + jobject ret = newTVMValueBytes(env, TVMFFIBytesGetByteArrayPtr(value.v_obj)); + TVMFFIObjectFree(value.v_obj); + return ret; + } + default: { + if (value.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + return newTVMObject(env, reinterpret_cast(value.v_obj), value.type_index); + } + TVM_FFI_THROW(RuntimeError) << "Do NOT know how to handle return type_index " + << value.type_index; + TVM_FFI_UNREACHABLE(); + } } - return result; } #endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 77bc8d636098..a5481dd9ac54 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -25,12 +25,14 @@ #include "tvm_runtime.h" #else #include -#include -#include -#include +#include +#include +#include +#include #endif #include #include +#include #include #include @@ -38,14 +40,18 @@ JavaVM* _jvm; void* _tvmHandle = nullptr; -struct TVMFuncArgsThreadLocalEntry { - std::vector tvmFuncArgValues; - std::vector tvmFuncArgTypes; + +struct TVMFFIJVMStack { + std::vector packed_args; // for later release - std::vector> tvmFuncArgPushedStrs; - std::vector> tvmFuncArgPushedBytes; + std::vector> str_args; + std::vector>> byte_args; + + static TVMFFIJVMStack* ThreadLocal() { + static thread_local TVMFFIJVMStack stack; + return &stack; + } }; -typedef dmlc::ThreadLocalStore TVMFuncArgsThreadLocalStore; JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit(JNIEnv* env, jobject obj, jstring jtvmLibFile) { @@ -68,172 +74,132 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv* env, jobject return 0; } -JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv* env, jobject obj) { - return env->NewStringUTF(TVMGetLastError()); +JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmFFIGetLastError(JNIEnv* env, jobject obj) { + std::string err_msg = ::tvm::ffi::details::MoveFromSafeCallRaised().what(); + return env->NewStringUTF(err_msg.c_str()); } // Function -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong(JNIEnv* env, jobject obj, - jlong arg) { - TVMValue value; - value.v_int64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kDLInt); +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgLong(JNIEnv* env, + jobject obj, + jlong arg) { + TVMFFIJVMStack::ThreadLocal()->packed_args.emplace_back(static_cast(arg)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble(JNIEnv* env, jobject obj, - jdouble arg) { - TVMValue value; - value.v_float64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kDLFloat); +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgDouble(JNIEnv* env, + jobject obj, + jdouble arg) { + TVMFFIJVMStack::ThreadLocal()->packed_args.emplace_back(static_cast(arg)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(JNIEnv* env, jobject obj, - jstring arg) { - TVMValue value; +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgString(JNIEnv* env, + jobject obj, + jstring arg) { jstring garg = reinterpret_cast(env->NewGlobalRef(arg)); - value.v_str = env->GetStringUTFChars(garg, 0); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kTVMStr); - // release string args later - e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str)); + const char* str = env->GetStringUTFChars(garg, 0); + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + stack->str_args.emplace_back(garg, str); + stack->packed_args.emplace_back(str); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* env, jobject obj, - jlong arg, jint argType) { - TVMValue value; - value.v_handle = reinterpret_cast(arg); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(static_cast(argType)); +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgHandle(JNIEnv* env, + jobject obj, + jlong arg, + jint argTypeIndex) { + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + TVMFFIAny temp; + temp.v_int64 = static_cast(arg); + temp.type_index = static_cast(argTypeIndex); + stack->packed_args.emplace_back(tvm::ffi::AnyView::CopyFromTVMFFIAny(temp)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDevice(JNIEnv* env, jobject obj, - jobject arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgDevice(JNIEnv* env, + jobject obj, + jobject arg) { jclass deviceClass = env->FindClass("org/apache/tvm/Device"); jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I"); jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I"); jint deviceType = env->GetIntField(arg, deviceTypeField); jint deviceId = env->GetIntField(arg, deviceIdField); - - TVMValue value; - value.v_int64 = deviceToInt64(deviceType, deviceId); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kDLDevice); + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + stack->packed_args.emplace_back(DLDevice{static_cast(deviceType), deviceId}); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, - jbyteArray arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgBytes(JNIEnv* env, + jobject obj, + jbyteArray arg) { jbyteArray garg = reinterpret_cast(env->NewGlobalRef(arg)); jbyte* data = env->GetByteArrayElements(garg, 0); - TVMByteArray* byteArray = new TVMByteArray(); + std::unique_ptr byteArray = std::make_unique(); byteArray->size = static_cast(env->GetArrayLength(garg)); byteArray->data = reinterpret_cast(data); - TVMValue value; - value.v_handle = reinterpret_cast(byteArray); - - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - e->tvmFuncArgValues.push_back(value); - e->tvmFuncArgTypes.push_back(kTVMBytes); - - e->tvmFuncArgPushedBytes.push_back(std::make_pair(garg, byteArray)); + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + stack->packed_args.emplace_back(byteArray.get()); + stack->byte_args.emplace_back(garg, std::move(byteArray)); // release (garg, data), byteArray later } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames(JNIEnv* env, jobject obj, - jobject jfuncNames) { - int outSize; - const char** outArray; - - int ret = TVMFuncListGlobalNames(&outSize, &outArray); - if (ret) { - return ret; - } - +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionListGlobalNames( + JNIEnv* env, jobject obj, jobject jfuncNames) { + TVM_FFI_SAFE_CALL_BEGIN(); jclass arrayClass = env->FindClass("java/util/List"); jmethodID arrayAppend = env->GetMethodID(arrayClass, "add", "(Ljava/lang/Object;)Z"); - // fill names - for (int i = 0; i < outSize; ++i) { - jstring jname = env->NewStringUTF(outArray[i]); + for (const auto& name : tvm::ffi::Function::ListGlobalNames()) { + jstring jname = env->NewStringUTF(name.c_str()); env->CallBooleanMethod(jfuncNames, arrayAppend, jname); env->DeleteLocalRef(jname); } env->DeleteLocalRef(arrayClass); - - return ret; -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree(JNIEnv* env, jobject obj, - jlong jhandle) { - return TVMFuncFree(reinterpret_cast(jhandle)); + TVM_FFI_SAFE_CALL_END(); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal(JNIEnv* env, jobject obj, - jstring jname, - jobject jhandle) { - TVMFunctionHandle handle; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionGetGlobal(JNIEnv* env, jobject obj, + jstring jname, + jobject jhandle) { const char* name = env->GetStringUTFChars(jname, 0); - int ret = TVMFuncGetGlobal(name, &handle); + TVMFFIByteArray name_bytes{name, strlen(name)}; + TVMFFIObjectHandle handle; + int ret = TVMFFIFunctionGetGlobal(&name_bytes, &handle); env->ReleaseStringUTFChars(jname, name); setLongField(env, jhandle, reinterpret_cast(handle)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall(JNIEnv* env, jobject obj, - jlong jhandle, jobject jretVal) { - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - int numArgs = e->tvmFuncArgValues.size(); - - TVMValue retVal; - int retTypeCode; - - // function can be invoked recursively, - // thus we copy the pushed arguments here. - auto argValues = e->tvmFuncArgValues; - auto argTypes = e->tvmFuncArgTypes; - auto pushedStrs = e->tvmFuncArgPushedStrs; - auto pushedBytes = e->tvmFuncArgPushedBytes; - - e->tvmFuncArgPushedStrs.clear(); - e->tvmFuncArgPushedBytes.clear(); - e->tvmFuncArgTypes.clear(); - e->tvmFuncArgValues.clear(); - - int ret = TVMFuncCall(reinterpret_cast(jhandle), &argValues[0], &argTypes[0], - numArgs, &retVal, &retTypeCode); - - if (ret != 0) { - return ret; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionCall(JNIEnv* env, jobject obj, + jlong jhandle, + jobject jretVal) { + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + TVMFFIAny ret_val; + ret_val.type_index = tvm::ffi::TypeIndex::kTVMFFINone; + ret_val.v_int64 = 0; + int ret = TVMFFIFunctionCall(reinterpret_cast(jhandle), + reinterpret_cast(stack->packed_args.data()), + stack->packed_args.size(), &ret_val); + // release all temp resources + for (auto& str_pair : stack->str_args) { + env->ReleaseStringUTFChars(str_pair.first, str_pair.second); + env->DeleteGlobalRef(str_pair.first); } - for (auto iter = pushedStrs.cbegin(); iter != pushedStrs.cend(); iter++) { - env->ReleaseStringUTFChars(iter->first, iter->second); - env->DeleteGlobalRef(iter->first); - } - for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) { + for (auto& byte_pair : stack->byte_args) { env->ReleaseByteArrayElements( - iter->first, reinterpret_cast(const_cast(iter->second->data)), 0); - env->DeleteGlobalRef(iter->first); - delete iter->second; + byte_pair.first, reinterpret_cast(const_cast(byte_pair.second->data)), 0); + env->DeleteGlobalRef(byte_pair.first); } + stack->str_args.clear(); + stack->byte_args.clear(); + stack->packed_args.clear(); // return TVMValue object to Java jclass refTVMValueCls = env->FindClass("org/apache/tvm/Base$RefTVMValue"); jfieldID refTVMValueFid = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); - env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, retVal, retTypeCode)); - + env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, ret_val)); env->DeleteLocalRef(refTVMValueCls); - return ret; } @@ -255,27 +221,24 @@ class JNIEnvPtrHelper { }; // Callback function -extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs, - TVMRetValueHandle ret, void* resourceHandle) { +extern "C" int funcInvokeCallback(void* self, const TVMFFIAny* args, int num_args, TVMFFIAny* ret) { JNIEnv* env; int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { _jvm->AttachCurrentThread(JNIEnvPtrHelper(&env), nullptr); } else { - CHECK(jniStatus == JNI_OK); + TVM_FFI_ICHECK(jniStatus == JNI_OK); } jclass tvmValueCls = env->FindClass("org/apache/tvm/TVMValue"); - jobjectArray jargs = env->NewObjectArray(numArgs, tvmValueCls, 0); - for (int i = 0; i < numArgs; ++i) { - TVMValue arg = args[i]; - int tcode = typeCodes[i]; - if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle || - tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle || - tcode == kTVMNDArrayHandle) { - TVMCbArgToReturn(&arg, &tcode); + jobjectArray jargs = env->NewObjectArray(num_args, tvmValueCls, 0); + + for (int i = 0; i < num_args; ++i) { + TVMFFIAny arg = args[i]; + if (args[i].type_index >= tvm::ffi::TypeIndex::kTVMFFIRawStr) { + TVMFFIAnyViewToOwnedAny(&args[i], &arg); } - jobject jarg = tvmRetValueToJava(env, arg, tcode); + jobject jarg = tvmRetValueToJava(env, arg); env->SetObjectArrayElement(jargs, i, jarg); } @@ -285,46 +248,39 @@ extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs, "(Lorg/apache/tvm/Function$Callback;[Lorg/apache/tvm/TVMValue;)Ljava/lang/Object;"); jmethodID pushArgToStack = env->GetStaticMethodID(clsFunc, "pushArgToStack", "(Ljava/lang/Object;)V"); - jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc, - reinterpret_cast(resourceHandle), jargs); + reinterpret_cast(self), jargs); - TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); - const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); - const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); + // the stack + TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal(); + const size_t prev_num_str_args = stack->str_args.size(); + const size_t prev_num_bytes_args = stack->byte_args.size(); // convert returned (java) TVMValue to (C) TVMValue env->CallStaticVoidMethod(clsFunc, pushArgToStack, jretValue); - TVMValue retValue = e->tvmFuncArgValues.back(); - e->tvmFuncArgValues.pop_back(); - - int retCode = e->tvmFuncArgTypes.back(); - e->tvmFuncArgTypes.pop_back(); - - // set back the return value - TVMCFuncSetReturn(ret, &retValue, &retCode, 1); + TVMFFIAny ret_val = stack->packed_args.back().CopyToTVMFFIAny(); + stack->packed_args.pop_back(); + TVMFFIAnyViewToOwnedAny(&ret_val, ret); // release allocated strings. - if (e->tvmFuncArgPushedStrs.size() > prevNumStrArg) { - const auto& pairArg = e->tvmFuncArgPushedStrs.back(); + if (stack->str_args.size() > prev_num_str_args) { + const auto& pairArg = stack->str_args.back(); env->ReleaseStringUTFChars(pairArg.first, pairArg.second); env->DeleteGlobalRef(pairArg.first); - e->tvmFuncArgPushedStrs.pop_back(); + stack->str_args.pop_back(); } // release allocated bytes. - if (e->tvmFuncArgPushedBytes.size() > prevNumBytesArg) { - const auto& pairArg = e->tvmFuncArgPushedBytes.back(); + if (stack->byte_args.size() > prev_num_bytes_args) { + const auto& pairArg = stack->byte_args.back(); env->ReleaseByteArrayElements( pairArg.first, reinterpret_cast(const_cast(pairArg.second->data)), 0); env->DeleteGlobalRef(pairArg.first); - delete pairArg.second; - e->tvmFuncArgPushedBytes.pop_back(); + stack->byte_args.pop_back(); } env->DeleteLocalRef(clsFunc); env->DeleteLocalRef(tvmValueCls); - return 0; } @@ -335,90 +291,43 @@ extern "C" void funcFreeCallback(void* resourceHandle) { if (jniStatus == JNI_EDETACHED) { _jvm->AttachCurrentThread(JNIEnvPtrHelper(&env), nullptr); } else { - CHECK(jniStatus == JNI_OK); + TVM_FFI_ICHECK(jniStatus == JNI_OK); } env->DeleteGlobalRef(reinterpret_cast(resourceHandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc(JNIEnv* env, jobject obj, - jobject jfunction, - jobject jretHandle) { - TVMFunctionHandle out; - int ret = - TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), - reinterpret_cast(env->NewGlobalRef(jfunction)), - reinterpret_cast(&funcFreeCallback), &out); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionCreateFromCallback( + JNIEnv* env, jobject obj, jobject jfunction, jobject jretHandle) { + TVMFFIObjectHandle out; + int ret = TVMFFIFunctionCreate(reinterpret_cast(env->NewGlobalRef(jfunction)), + funcInvokeCallback, funcFreeCallback, &out); setLongField(env, jretHandle, reinterpret_cast(out)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal(JNIEnv* env, jobject obj, - jstring jname, - jlong jhandle, - jint joverride) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionSetGlobal(JNIEnv* env, jobject obj, + jstring jname, + jlong jhandle, + jint joverride) { const char* name = env->GetStringUTFChars(jname, 0); - int ret = TVMFuncRegisterGlobal(name, reinterpret_cast(jhandle), - reinterpret_cast(joverride)); + TVMFFIByteArray name_bytes{name, strlen(name)}; + int ret = TVMFFIFunctionSetGlobal(&name_bytes, reinterpret_cast(jhandle), + reinterpret_cast(joverride)); env->ReleaseStringUTFChars(jname, name); return ret; } // Module -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree(JNIEnv* env, jobject obj, - jlong jhandle) { - return TVMModFree(reinterpret_cast(jhandle)); -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport(JNIEnv* env, jobject obj, - jlong jmod, jlong jdep) { - return TVMModImport(reinterpret_cast(jmod), - reinterpret_cast(jdep)); -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction(JNIEnv* env, jobject obj, - jlong jhandle, jstring jname, - jint jimport, jobject jret) { - TVMFunctionHandle retFunc; - - const char* name = env->GetStringUTFChars(jname, 0); - int ret = TVMModGetFunction(reinterpret_cast(jhandle), name, - reinterpret_cast(jimport), &retFunc); - env->ReleaseStringUTFChars(jname, name); - - setLongField(env, jret, reinterpret_cast(retFunc)); - - return ret; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIObjectFree(JNIEnv* env, jobject obj, + jlong jhandle) { + return TVMFFIObjectFree(reinterpret_cast(jhandle)); } // NDArray -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree(JNIEnv* env, jobject obj, - jlong jhandle) { - return TVMArrayFree(reinterpret_cast(jhandle)); -} - -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc(JNIEnv* env, jobject obj, - jlongArray jshape, jint jdtypeCode, - jint jdtypeBits, jint jdtypeLanes, - jint jdeviceType, jint jdeviceId, - jobject jret) { - int ndim = static_cast(env->GetArrayLength(jshape)); - - TVMArrayHandle out; - - jlong* shapeArray = env->GetLongArrayElements(jshape, NULL); - int ret = TVMArrayAlloc(reinterpret_cast(shapeArray), ndim, - static_cast(jdtypeCode), static_cast(jdtypeBits), - static_cast(jdtypeLanes), static_cast(jdeviceType), - static_cast(jdeviceId), &out); - env->ReleaseLongArrayElements(jshape, shapeArray, 0); - - setLongField(env, jret, reinterpret_cast(out)); - - return ret; -} -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, jobject obj, - jlong jhandle, jobject jshape) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorGetShape(JNIEnv* env, jobject obj, + jlong jhandle, + jobject jshape) { DLTensor* array = reinterpret_cast(jhandle); int64_t* shape = array->shape; int ndim = array->ndim; @@ -440,45 +349,72 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, return 0; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo(JNIEnv* env, jobject obj, - jlong jfrom, jlong jto) { - return TVMArrayCopyFromTo(reinterpret_cast(jfrom), - reinterpret_cast(jto), NULL); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromTo(JNIEnv* env, + jobject obj, + jlong jfrom, + jlong jto) { + TVM_FFI_SAFE_CALL_BEGIN(); + static auto fcopy_from_to = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyFromTo"); + fcopy_from_to(reinterpret_cast(jfrom), reinterpret_cast(jto)); + TVM_FFI_SAFE_CALL_END(); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(JNIEnv* env, jobject obj, - jbyteArray jarr, - jlong jfrom, jlong jto) { - jbyte* data = env->GetByteArrayElements(jarr, NULL); - - DLTensor* from = reinterpret_cast(jfrom); - from->data = static_cast(data); - - int ret = TVMArrayCopyFromTo(static_cast(from), - reinterpret_cast(jto), NULL); - - from->data = NULL; - env->ReleaseByteArrayElements(jarr, data, 0); - - return ret; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromJArray(JNIEnv* env, + jobject obj, + jbyteArray jarr, + jlong jto) { + TVM_FFI_SAFE_CALL_BEGIN(); + jbyte* pdata = env->GetByteArrayElements(jarr, NULL); + DLTensor* to = reinterpret_cast(jto); + size_t size = tvm::ffi::GetDataSize(*to); + static auto fcopy_from_bytes = + tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyFromBytes"); + fcopy_from_bytes(to, static_cast(pdata), size); + env->ReleaseByteArrayElements(jarr, pdata, 0); + TVM_FFI_SAFE_CALL_END(); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(JNIEnv* env, jobject obj, - jlong jfrom, - jbyteArray jarr) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyToJArray(JNIEnv* env, + jobject obj, + jlong jfrom, + jbyteArray jarr) { + TVM_FFI_SAFE_CALL_BEGIN(); DLTensor* from = reinterpret_cast(jfrom); - int size = static_cast(env->GetArrayLength(jarr)); + size_t size = tvm::ffi::GetDataSize(*from); jbyte* pdata = env->GetByteArrayElements(jarr, NULL); - int ret = 0; - if (memcpy(static_cast(pdata), from->data, size) == NULL) { - ret = 1; - } - env->ReleaseByteArrayElements(jarr, pdata, 0); // copy back to java array automatically - return ret; + static auto fcopy_to_bytes = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayCopyToBytes"); + fcopy_to_bytes(from, static_cast(pdata), size); + env->ReleaseByteArrayElements(jarr, static_cast(pdata), + 0); // copy back to java array automatically + TVM_FFI_SAFE_CALL_END(); +} + +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jobject obj, + jint jdeviceType, + jint jdeviceId) { + TVM_FFI_SAFE_CALL_BEGIN(); + static auto fsync = tvm::ffi::Function::GetGlobalRequired("runtime.Device_StreamSync"); + DLDevice device{static_cast(jdeviceType), jdeviceId}; + fsync(device, nullptr); + TVM_FFI_SAFE_CALL_END(); } -// Device -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jint deviceType, - jint deviceId) { - return TVMSynchronize(static_cast(deviceType), static_cast(deviceId), NULL); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmNDArrayEmpty( + JNIEnv* env, jobject obj, jlongArray jshape, jint jdtypeCode, jint jdtypeBits, jint jdtypeLanes, + jint jdeviceType, jint jdeviceId, jobject jret) { + TVM_FFI_SAFE_CALL_BEGIN(); + int ndim = static_cast(env->GetArrayLength(jshape)); + jlong* shapeArray = env->GetLongArrayElements(jshape, NULL); + tvm::ffi::Shape shape(shapeArray, shapeArray + ndim); + DLDataType dtype; + dtype.code = static_cast(jdtypeCode); + dtype.bits = static_cast(jdtypeBits); + dtype.lanes = static_cast(jdtypeLanes); + DLDevice device{static_cast(jdeviceType), jdeviceId}; + env->ReleaseLongArrayElements(jshape, shapeArray, 0); + static auto fempty = tvm::ffi::Function::GetGlobalRequired("runtime.TVMArrayAllocWithScope"); + tvm::ffi::NDArray out = fempty(shape, dtype, device, nullptr).cast(); + void* handle = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out)); + setLongField(env, jret, reinterpret_cast(handle)); + TVM_FFI_SAFE_CALL_END(); } diff --git a/python/setup.py b/python/setup.py index 3900e5b3d02d..679f5078d3c1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -20,6 +20,7 @@ import pathlib import shutil import sys +import sys from setuptools import find_packages from setuptools.dist import Distribution @@ -42,7 +43,7 @@ def get_lib_path(): """Get library path, name and version""" # We can not import `libinfo.py` in setup.py directly since __init__.py # Will be invoked which introduces dependencies - libinfo_py = os.path.join(CURRENT_DIR, "./tvm/_ffi/libinfo.py") + libinfo_py = os.path.join(CURRENT_DIR, "./tvm/libinfo.py") libinfo = {"__file__": libinfo_py} exec(compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"), libinfo, libinfo) version = libinfo["__version__"] @@ -145,7 +146,15 @@ def config_cython(): try: from Cython.Build import cythonize - subdir = "_cy3" + # for python 3.12+, use limited API for future compact + limited_api_kwargs = {} + if sys.version_info >= (3, 12): + limited_api_kwargs = { + "define_macros": [ + ("Py_LIMITED_API", 0x030C0000), + ], + "py_limited_api": True, + } ret = [] extra_compile_args = ["-std=c++17", "-DDMLC_USE_LOGGING_LIBRARY="] @@ -179,6 +188,7 @@ def config_cython(): library_dirs=library_dirs, libraries=libraries, language="c++", + **limited_api_kwargs, ) ) return cythonize(ret, compiler_directives={"language_level": 3}) @@ -242,7 +252,6 @@ def long_description_contents(): license="Apache", # See https://pypi.org/classifiers/ classifiers=[ - "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Intended Audience :: Education", @@ -250,7 +259,6 @@ def long_description_contents(): ], keywords="machine learning", zip_safe=False, - entry_points={"console_scripts": ["tvmc = tvm.driver.tvmc.main:main"]}, install_requires=requirements["core"][1], extras_require=extras_require, packages=find_packages(), diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 8563c84ab398..150e5d4b1dbc 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -19,13 +19,12 @@ import multiprocessing import sys import os -import traceback # top-level alias # tvm._ffi -from ._ffi.base import TVMError, __version__, _RUNTIME_ONLY +from .base import TVMError, __version__, _RUNTIME_ONLY -from ._ffi import register_object, register_func, get_global_func +from .ffi import register_object, register_func, get_global_func # top-level alias # tvm.runtime diff --git a/python/tvm/_ffi/__init__.py b/python/tvm/_ffi/__init__.py deleted file mode 100644 index 559ca84635bd..000000000000 --- a/python/tvm/_ffi/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""C interfacing code. - -This namespace contains everything that interacts with C code. -Most TVM C related object are ctypes compatible, which means -they contains a handle field that is ctypes.c_void_p and can -be used via ctypes function calls. - -Some performance critical functions are implemented by cython -and have a ctypes fallback implementation. -""" -from . import _pyversion -from . import base -from .registry import register_object, register_func -from .registry import _init_api, get_global_func -from ..ffi import register_error diff --git a/python/tvm/_ffi/_pyversion.py b/python/tvm/_ffi/_pyversion.py deleted file mode 100644 index b661cfd875fc..000000000000 --- a/python/tvm/_ffi/_pyversion.py +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Python version check -""" -import sys - -# ---------------------------- -# Python3 version. -# ---------------------------- -if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 6): - PY3STATEMENT = "The minimal Python requirement is Python 3.6" - raise Exception(PY3STATEMENT) diff --git a/python/tvm/_ffi/registry.py b/python/tvm/_ffi/registry.py deleted file mode 100644 index b11cab48b841..000000000000 --- a/python/tvm/_ffi/registry.py +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=invalid-name, unused-import, wrong-import-position -"""FFI registry to register function and objects.""" - -import tvm.ffi - -from tvm.ffi import register_object, register_func, get_global_func - -from tvm.ffi.registry import ( - list_global_func_names, - remove_global_func, - _init_api, -) diff --git a/python/tvm/arith/_ffi_api.py b/python/tvm/arith/_ffi_api.py index c551e5651563..e05405b0fcc6 100644 --- a/python/tvm/arith/_ffi_api.py +++ b/python/tvm/arith/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.arith""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("arith", __name__) +tvm.ffi._init_api("arith", __name__) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index f8069a717da3..919272a2734b 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -19,7 +19,7 @@ import enum from typing import Union -import tvm._ffi +import tvm.ffi from tvm import tir, ir from tvm.runtime import Object @@ -46,7 +46,7 @@ class Extension(enum.Flag): ComparisonOfProductAndSum = 1 << 3 -@tvm._ffi.register_object("arith.ModularSet") +@tvm.ffi.register_object("arith.ModularSet") class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" @@ -54,7 +54,7 @@ def __init__(self, coeff, base): self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base) -@tvm._ffi.register_object("arith.ConstIntBound") +@tvm.ffi.register_object("arith.ConstIntBound") class ConstIntBound(Object): """Represent constant integer bound diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index d38f5e805f39..f779df5d4c92 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Integer set.""" -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from . import _ffi_api @@ -64,7 +64,7 @@ def single_point(point): return _ffi_api.intset_single_point(point) -@tvm._ffi.register_object("arith.IntervalSet") +@tvm.ffi.register_object("arith.IntervalSet") class IntervalSet(IntSet): """Represent set of continuous interval [min_value, max_value] @@ -81,7 +81,7 @@ def __init__(self, min_value, max_value): self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value) -@tvm._ffi.register_object("arith.PresburgerSet") +@tvm.ffi.register_object("arith.PresburgerSet") class PresburgerSet(IntSet): """Represent of Presburger Set""" diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 6e8a010eec16..a97cda10f8eb 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """integer constraints data structures and solvers""" -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from . import _ffi_api -@tvm._ffi.register_object("arith.IntGroupBounds") +@tvm.ffi.register_object("arith.IntGroupBounds") class IntGroupBounds(Object): """Represent integer grouped bounds which are classified into lower bounds (include), upper bounds (include) and equalities. @@ -66,7 +66,7 @@ def find_best_range(self): return _ffi_api.IntGroupBounds_FindBestRange(self) -@tvm._ffi.register_object("arith.IntConstraints") +@tvm.ffi.register_object("arith.IntConstraints") class IntConstraints(Object): """Represent a set of integer constraints including variables, their ranges and the relations between them (either equations or inequalities) @@ -85,7 +85,7 @@ def __init__(self, variables, ranges, relations): self.__init_handle_by_constructor__(_ffi_api.IntConstraints, variables, ranges, relations) -@tvm._ffi.register_object("arith.IntConstraintsTransform") +@tvm.ffi.register_object("arith.IntConstraintsTransform") class IntConstraintsTransform(Object): """We can have different set of variables to represent the same integer constraints. For example, the following two constrains are equivalent, diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index f19dd0a1bac9..dbb4087f325f 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" Iterator (quasi)affine mapping patterns.""" +"""Iterator (quasi)affine mapping patterns.""" from enum import IntEnum -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from tvm.ir import PrimExpr from . import _ffi_api @@ -26,7 +26,7 @@ class IterMapExpr(PrimExpr): """Base class of all IterMap expressions.""" -@tvm._ffi.register_object("arith.IterMark") +@tvm.ffi.register_object("arith.IterMark") class IterMark(Object): """Mark the source as an iterator in [0, extent). @@ -43,7 +43,7 @@ def __init__(self, source, extent): self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) -@tvm._ffi.register_object("arith.IterSplitExpr") +@tvm.ffi.register_object("arith.IterSplitExpr") class IterSplitExpr(IterMapExpr): """Split of an iterator. @@ -70,7 +70,7 @@ def __init__(self, source, lower_factor, extent, scale): ) -@tvm._ffi.register_object("arith.IterSumExpr") +@tvm.ffi.register_object("arith.IterSumExpr") class IterSumExpr(IterMapExpr): """Fuse multiple iterators by summing them with scaling. diff --git a/python/tvm/_ffi/base.py b/python/tvm/base.py similarity index 86% rename from python/tvm/_ffi/base.py rename to python/tvm/base.py index 18ed40fb4cb9..13b4fc8d443a 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/base.py @@ -16,22 +16,23 @@ # under the License. # coding: utf-8 # pylint: disable=invalid-name, import-outside-toplevel -"""Base library for TVM FFI.""" +"""Base library for TVM.""" import ctypes import os import sys - -import numpy as np - from . import libinfo +# ---------------------------- +# Python3 version. +# ---------------------------- +if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 8): + PY3STATEMENT = "The minimal Python requirement is Python 3.8" + raise Exception(PY3STATEMENT) + # ---------------------------- # library loading # ---------------------------- -string_types = (str,) -integer_types = (int, np.int32) -numeric_types = integer_types + (float, np.float16, np.float32) def _load_lib(): @@ -62,7 +63,7 @@ def _load_lib(): if _RUNTIME_ONLY: - from ..ffi import registry as _tvm_ffi_registry + from .ffi import registry as _tvm_ffi_registry _tvm_ffi_registry._SKIP_UNKNOWN_OBJECTS = True diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 110f80db6186..04a69baee9c1 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -23,7 +23,7 @@ import sys from typing import Dict -from .._ffi.base import py_str +from ..base import py_str from . import tar as _tar from . import utils as _utils diff --git a/python/tvm/contrib/clang.py b/python/tvm/contrib/clang.py index 16c465dc22ab..4d2769436d06 100644 --- a/python/tvm/contrib/clang.py +++ b/python/tvm/contrib/clang.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name import subprocess -from tvm._ffi.base import py_str +from tvm.base import py_str import tvm.target from . import utils diff --git a/python/tvm/contrib/coreml_runtime.py b/python/tvm/contrib/coreml_runtime.py index aa4f21279967..def5d3c2e06e 100644 --- a/python/tvm/contrib/coreml_runtime.py +++ b/python/tvm/contrib/coreml_runtime.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """CoreML runtime that load and run coreml models.""" -import tvm._ffi +import tvm.ffi from ..rpc import base as rpc_base @@ -41,7 +41,7 @@ def create(symbol, compiled_model_path, device): if device_type >= rpc_base.RPC_SESS_MASK: fcreate = device._rpc_sess.get_function(runtime_func) else: - fcreate = tvm._ffi.get_global_func(runtime_func) + fcreate = tvm.ffi.get_global_func(runtime_func) assert fcreate, "Cannot find `tvm.coreml_runtime.create` function." return CoreMLModule(fcreate(symbol, compiled_model_path)) diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 9d3f80a5c74b..1c80d4a3b9e1 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -20,7 +20,7 @@ import numpy as np import tvm -import tvm._ffi +import tvm.ffi from tvm import te # algos can be read from cudnn.h @@ -349,7 +349,7 @@ def _conv_find_algo( dims - 2, pad, stride, dilation, x_shape, w_shape ) yshape = np.array(y_shape, dtype=np.int32) - func = tvm._ffi.get_global_func(func_name) + func = tvm.ffi.get_global_func(func_name) return func( tensor_format, dims - 2, diff --git a/python/tvm/contrib/cutlass/_ffi_api.py b/python/tvm/contrib/cutlass/_ffi_api.py index e71eb8c13f19..be71b0d48f13 100644 --- a/python/tvm/contrib/cutlass/_ffi_api.py +++ b/python/tvm/contrib/cutlass/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI API for CUTLASS BYOC.""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("contrib.cutlass", __name__) +tvm.ffi._init_api("contrib.cutlass", __name__) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index e12ce93e270b..0aea5bf1416a 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -26,7 +26,7 @@ import tvm from tvm import relax, runtime -from tvm._ffi.registry import register_func +from tvm.ffi.registry import register_func from tvm.contrib.nvcc import get_cuda_version from tvm.topi.utils import get_const_tuple diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 5d04cf13e693..6fa349b28e44 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -24,7 +24,7 @@ import subprocess import tempfile -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from tvm.tir import IntImm @@ -461,7 +461,7 @@ def _get_optional_int_annotation(annotations, key, default=None): return int(value) -@tvm._ffi.register_func("contrib.cutlass.instantiate_template") +@tvm.ffi.register_func("contrib.cutlass.instantiate_template") def instantiate_template(func_name, annotations, func_args): """Return CUTLASS host code based on a template and the provided annotations. @@ -487,7 +487,7 @@ def instantiate_template(func_name, annotations, func_args): if k in annotations: attrs[k] = annotations[k] - headers = ["tvm/runtime/registry.h"] + headers = ["tvm/ffi/function.h"] if "relu" in func_name: headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h") diff --git a/python/tvm/contrib/emcc.py b/python/tvm/contrib/emcc.py index 3beb096b6747..a9f0bec0cf9d 100644 --- a/python/tvm/contrib/emcc.py +++ b/python/tvm/contrib/emcc.py @@ -20,8 +20,8 @@ import subprocess from pathlib import Path -from tvm._ffi.base import py_str -from tvm._ffi.libinfo import find_lib_path +from tvm.base import py_str +from tvm.libinfo import find_lib_path def create_tvmjs_wasm(output, objects, options=None, cc="emcc", libs=None): diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index d1ca5227fd0f..f4b02ff80f73 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -35,7 +35,7 @@ from typing import Union from tvm.contrib.hexagon.hexagon_profiler import HexagonProfiler -from ..._ffi import libinfo +from ...ffi import libinfo from .session import Session from .tools import HEXAGON_SIMULATOR_NAME diff --git a/python/tvm/contrib/hexagon/tools.py b/python/tvm/contrib/hexagon/tools.py index 3b129b03323b..5ee89713d9a5 100644 --- a/python/tvm/contrib/hexagon/tools.py +++ b/python/tvm/contrib/hexagon/tools.py @@ -29,7 +29,7 @@ import tvm import tvm.contrib.cc as cc -from ..._ffi.registry import register_func +from ...ffi.registry import register_func # Linking Hexagon shared libraries. diff --git a/python/tvm/contrib/miopen.py b/python/tvm/contrib/miopen.py index 0e336c1c82b9..22b08f38ca76 100644 --- a/python/tvm/contrib/miopen.py +++ b/python/tvm/contrib/miopen.py @@ -19,7 +19,7 @@ import ctypes import numpy as np import tvm -import tvm._ffi +import tvm.ffi from tvm import te @@ -94,7 +94,7 @@ def conv2d_forward( oshape = np.zeros((len(x.shape)), dtype=np.int32) xshape = x.shape wshape = w.shape - setup_func = tvm._ffi.get_global_func("tvm.contrib.miopen.conv2d.setup") + setup_func = tvm.ffi.get_global_func("tvm.contrib.miopen.conv2d.setup") algo = setup_func( conv_mode, data_type, diff --git a/python/tvm/contrib/mrvl.py b/python/tvm/contrib/mrvl.py index 3cf393b34160..36c932cd1a1d 100644 --- a/python/tvm/contrib/mrvl.py +++ b/python/tvm/contrib/mrvl.py @@ -24,10 +24,10 @@ import base64 import numpy as np import tvm -import tvm._ffi +import tvm.ffi -@tvm._ffi.register_func("tvm.mrvl.find_value_in_KV_pair") +@tvm.ffi.register_func("tvm.mrvl.find_value_in_KV_pair") def find_value_in_KV_pair(json_input: str, key_to_find: str) -> str: """This function takes the graph_json string and key to be searched in the json string, using json parser routine it loads the json string @@ -54,7 +54,7 @@ def find_value_in_KV_pair(json_input: str, key_to_find: str) -> str: return value -@tvm._ffi.register_func("tvm.mrvl.GetNodesJSONString") +@tvm.ffi.register_func("tvm.mrvl.GetNodesJSONString") def get_nodes_json_string(graph_json): """This takes the graph_json string from MrvlJSONSerializer and adds / modifies the json string to a form suitable for the Marvell Backend. @@ -206,7 +206,7 @@ def get_nodes_json_string(graph_json): return nodes_json_string -@tvm._ffi.register_func("tvm.mrvl.ModifyConstNames") +@tvm.ffi.register_func("tvm.mrvl.ModifyConstNames") def modify_const_names(nodes_json_str, consts_json_str): """This takes the graph module returned by build an generates nodes and constant meta data suitable for compilation by the back end. @@ -329,7 +329,7 @@ def get_working_dir(): return os.getcwd() -@tvm._ffi.register_func("tvm.mrvl.WriteJsonFile") +@tvm.ffi.register_func("tvm.mrvl.WriteJsonFile") def write_json_file(json_string, json_filename): """Generate json file under working directory""" working_dir = get_working_dir() @@ -351,7 +351,7 @@ def delete_temp_files(symbol_name): shutil.rmtree(bin_folder) -@tvm._ffi.register_func("tvm.mrvl.CompileModel") +@tvm.ffi.register_func("tvm.mrvl.CompileModel") def compile_model( symbol_name, nodes_json_string, @@ -414,7 +414,7 @@ def compile_model( raise RuntimeError(error_msg) -@tvm._ffi.register_func("tvm.mrvl.CleanUpSim") +@tvm.ffi.register_func("tvm.mrvl.CleanUpSim") def clean_up_sim(bin_file, input_json, input_bin, out_bin_prefix, num_outputs): os.remove(bin_file) os.remove(input_json) @@ -424,7 +424,7 @@ def clean_up_sim(bin_file, input_json, input_bin, out_bin_prefix, num_outputs): os.remove(out_bin) -@tvm._ffi.register_func("tvm.mrvl.SearchPath") +@tvm.ffi.register_func("tvm.mrvl.SearchPath") def search_path(file_name): path = shutil.which(file_name) if path is None: @@ -432,7 +432,7 @@ def search_path(file_name): return os.path.dirname(path) -@tvm._ffi.register_func("tvm.mrvl.JsonToBin") +@tvm.ffi.register_func("tvm.mrvl.JsonToBin") def convert_json_to_bin(json_file, input_bin_file): with open(json_file) as input_json: data = json.load(input_json) @@ -442,7 +442,7 @@ def convert_json_to_bin(json_file, input_bin_file): f.write(data_b) -@tvm._ffi.register_func("tvm.mrvl.RunSim") +@tvm.ffi.register_func("tvm.mrvl.RunSim") def run_simulation(run_command, sim_directory): cwd_path = get_working_dir() os.mkdir(sim_directory) @@ -452,6 +452,6 @@ def run_simulation(run_command, sim_directory): shutil.rmtree(sim_directory) -@tvm._ffi.register_func("tvm.mrvl.TempDir") +@tvm.ffi.register_func("tvm.mrvl.TempDir") def get_temp_dir(): return tempfile.gettempdir() diff --git a/python/tvm/contrib/msc/core/_ffi_api.py b/python/tvm/contrib/msc/core/_ffi_api.py index c0b0e21267ea..f7c975aff98a 100644 --- a/python/tvm/contrib/msc/core/_ffi_api.py +++ b/python/tvm/contrib/msc/core/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.core._ffi_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("msc.core", __name__) +tvm.ffi._init_api("msc.core", __name__) diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 172f40e06a31..9aa5bde93380 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -25,7 +25,7 @@ from tvm.contrib.msc.core import utils as msc_utils -@tvm._ffi.register_object("msc.core.MSCTensor") +@tvm.ffi.register_object("msc.core.MSCTensor") class MSCTensor(Object): """Tensor in MSCGraph @@ -198,7 +198,7 @@ class BaseJoint(Object): """Base class of all MSC Nodes.""" -@tvm._ffi.register_object("msc.core.MSCJoint") +@tvm.ffi.register_object("msc.core.MSCJoint") class MSCJoint(BaseJoint): """Node in MSCGraph @@ -423,7 +423,7 @@ def equal(self, other: BaseJoint) -> bool: return msc_utils.dict_equal(self.get_attrs(), other.get_attrs()) -@tvm._ffi.register_object("msc.core.MSCPrim") +@tvm.ffi.register_object("msc.core.MSCPrim") class MSCPrim(BaseJoint): """Prim in MSCGraph @@ -447,7 +447,7 @@ def __init__( self.__init_handle_by_constructor__(_ffi_api.MSCPrim, index, name, optype, attrs, parents) -@tvm._ffi.register_object("msc.core.WeightJoint") +@tvm.ffi.register_object("msc.core.WeightJoint") class WeightJoint(BaseJoint): """Node in WeightGraph @@ -565,7 +565,7 @@ class BaseGraph(Object): """Base class of all MSC Graphs.""" -@tvm._ffi.register_object("msc.core.MSCGraph") +@tvm.ffi.register_object("msc.core.MSCGraph") class MSCGraph(BaseGraph): """The MSCGraph @@ -954,7 +954,7 @@ def visualize(self, path: Optional[str] = None) -> str: return graph_proto -@tvm._ffi.register_object("msc.core.WeightGraph") +@tvm.ffi.register_object("msc.core.WeightGraph") class WeightGraph(Object): """The WeightGraph diff --git a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py index d43984b9c292..5b85e16a53ba 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.framework.tensorflow._ffi_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("msc.framework.tensorflow", __name__) +tvm.ffi._init_api("msc.framework.tensorflow", __name__) diff --git a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py index c0fa9c2c0559..4db71f3a19de 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.framework.tensorrt._ffi_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("msc.framework.tensorrt", __name__) +tvm.ffi._init_api("msc.framework.tensorrt", __name__) diff --git a/python/tvm/contrib/msc/framework/torch/_ffi_api.py b/python/tvm/contrib/msc/framework/torch/_ffi_api.py index 190e7507fb07..d12fcf2e2f87 100644 --- a/python/tvm/contrib/msc/framework/torch/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/torch/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.framework.torch._ffi_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("msc.framework.torch", __name__) +tvm.ffi._init_api("msc.framework.torch", __name__) diff --git a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py index e82612a6403f..a3683181b0e4 100644 --- a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.framework.tvm._ffi_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("msc.framework.tvm", __name__) +tvm.ffi._init_api("msc.framework.tvm", __name__) diff --git a/python/tvm/contrib/msc/plugin/_ffi_api.py b/python/tvm/contrib/msc/plugin/_ffi_api.py index 0e12c29242d1..c566d3b0d332 100644 --- a/python/tvm/contrib/msc/plugin/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.plugin._ffi_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("msc.plugin", __name__) +tvm.ffi._init_api("msc.plugin", __name__) diff --git a/python/tvm/contrib/msc/plugin/op/_ffi_api.py b/python/tvm/contrib/msc/plugin/op/_ffi_api.py index 2111e11227a1..0d8ad3c5e457 100644 --- a/python/tvm/contrib/msc/plugin/op/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/op/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """tvm.contrib.msc.plugin.op._ffi_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("msc.plugin.op", __name__) +tvm.ffi._init_api("msc.plugin.op", __name__) diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py index 14820c0ca8ab..c1441c496ae8 100644 --- a/python/tvm/contrib/ndk.py +++ b/python/tvm/contrib/ndk.py @@ -25,8 +25,8 @@ import tempfile from pathlib import Path -from .._ffi import register_func -from .._ffi.base import py_str +from ..ffi import register_func +from ..base import py_str from . import utils as _utils, tar as _tar, cc as _cc from .cc import get_target_by_dump_machine diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index 010bef533c00..1b4f51850805 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -17,7 +17,7 @@ """External function interface to NNPACK libraries.""" import tvm from tvm import te -import tvm._ffi +import tvm.ffi def is_available(): @@ -232,4 +232,4 @@ def convolution_inference_weight_transform( ) -tvm._ffi._init_api("tvm.contrib.nnpack") +tvm.ffi._init_api("tvm.contrib.nnpack") diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index c8b749b36bf1..45e2793fbb6f 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -22,10 +22,10 @@ import subprocess import warnings -import tvm._ffi +import tvm.ffi from tvm.target import Target -from .._ffi.base import py_str +from ..base import py_str from . import utils @@ -198,14 +198,14 @@ def get_cuda_version(cuda_path=None): raise RuntimeError("Cannot read cuda version file") -@tvm._ffi.register_func +@tvm.ffi.register_func def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument """use nvcc to generate fatbin code for better optimization""" ptx = compile_cuda(code, target_format="fatbin") return ptx -@tvm._ffi.register_func("tvm_callback_libdevice_path") +@tvm.ffi.register_func("tvm_callback_libdevice_path") def find_libdevice_path(arch): """Utility function to find libdevice @@ -270,7 +270,7 @@ def callback_libdevice_path(arch): return "" -@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version") +@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version") def get_target_compute_version(target=None): """Utility function to get compute capability of compilation target. @@ -415,7 +415,7 @@ def have_cudagraph(): return False -@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16") +@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16") def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not @@ -431,7 +431,7 @@ def have_bf16(compute_version): return False -@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8") +@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8") def have_fp8(compute_version): """Whether fp8 support is provided in the specified compute capability or not @@ -449,7 +449,7 @@ def have_fp8(compute_version): return False -@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp4") +@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp4") def have_fp4(compute_version): """Whether fp4 support is provided in the specified compute capability or not diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index 72efb8064312..a5d09fb9316a 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -23,7 +23,6 @@ import sys import functools -from .._ffi.base import string_types try: import cPickle as pickle @@ -115,7 +114,7 @@ def memoize(key, save_at_exit=False): def _register(f): """Registration function""" - allow_types = (string_types, int, float, tuple) + allow_types = (str, int, float, tuple) fkey = key + "." + f.__name__ + ".pkl" if fkey not in Cache.cache_by_key: Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit) diff --git a/python/tvm/contrib/random.py b/python/tvm/contrib/random.py index bbc74fccac94..6a17693b9162 100644 --- a/python/tvm/contrib/random.py +++ b/python/tvm/contrib/random.py @@ -17,7 +17,7 @@ """External function interface to random library.""" import tvm from tvm import te -import tvm._ffi +import tvm.ffi def randint(low, high, size, dtype="int32"): @@ -112,4 +112,4 @@ def normal(loc, scale, size): ) -tvm._ffi._init_api("tvm.contrib.random") +tvm.ffi._init_api("tvm.contrib.random") diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py index f3427463b3e0..6e6a985c2732 100644 --- a/python/tvm/contrib/rocm.py +++ b/python/tvm/contrib/rocm.py @@ -20,8 +20,8 @@ import os from os.path import join, exists -import tvm._ffi -from tvm._ffi.base import py_str +import tvm.ffi +from tvm.base import py_str import tvm.runtime import tvm.target @@ -99,7 +99,7 @@ def rocm_link(in_file, out_file, lld=None): raise RuntimeError(msg) -@tvm._ffi.register_func("tvm_callback_rocm_link") +@tvm.ffi.register_func("tvm_callback_rocm_link") def callback_rocm_link(obj_bin): """Links object file generated from LLVM to HSA Code Object @@ -123,7 +123,7 @@ def callback_rocm_link(obj_bin): return cobj_bin -@tvm._ffi.register_func("tvm_callback_rocm_bitcode_path") +@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path") def callback_rocm_bitcode_path(rocdl_dir=None): """Utility function to find ROCm device library bitcodes @@ -227,7 +227,7 @@ def have_matrixcore(compute_version=None): return False -@tvm._ffi.register_func("tvm_callback_rocm_get_arch") +@tvm.ffi.register_func("tvm_callback_rocm_get_arch") def get_rocm_arch(rocm_path=None): """Utility function to get the AMD GPU architecture diff --git a/python/tvm/contrib/spirv.py b/python/tvm/contrib/spirv.py index 94b24d0c7b09..0484562ee737 100644 --- a/python/tvm/contrib/spirv.py +++ b/python/tvm/contrib/spirv.py @@ -18,7 +18,7 @@ import subprocess import os from . import utils -from .._ffi.base import py_str +from ..base import py_str def optimize(spv_bin): diff --git a/python/tvm/contrib/tar.py b/python/tvm/contrib/tar.py index 67175b8b278c..7322ed447197 100644 --- a/python/tvm/contrib/tar.py +++ b/python/tvm/contrib/tar.py @@ -22,7 +22,7 @@ import shutil import subprocess from . import utils -from .._ffi.base import py_str +from ..base import py_str def tar(output, files): diff --git a/python/tvm/contrib/tflite_runtime.py b/python/tvm/contrib/tflite_runtime.py index 1558e36d51af..aceeefd248f4 100644 --- a/python/tvm/contrib/tflite_runtime.py +++ b/python/tvm/contrib/tflite_runtime.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """TFLite runtime that load and run tflite models.""" -import tvm._ffi +import tvm.ffi from ..rpc import base as rpc_base @@ -45,7 +45,7 @@ def create(tflite_model_bytes, device, runtime_target="cpu"): if device_type >= rpc_base.RPC_SESS_MASK: fcreate = device._rpc_sess.get_function(runtime_func) else: - fcreate = tvm._ffi.get_global_func(runtime_func) + fcreate = tvm.ffi.get_global_func(runtime_func) return TFLiteModule(fcreate(bytearray(tflite_model_bytes), device)) diff --git a/python/tvm/contrib/thrust.py b/python/tvm/contrib/thrust.py index 8f3178429589..9a05cfafbac3 100644 --- a/python/tvm/contrib/thrust.py +++ b/python/tvm/contrib/thrust.py @@ -17,7 +17,7 @@ """Utilities for thrust""" import logging -from tvm._ffi import get_global_func +from tvm.ffi import get_global_func def maybe_warn(target, func_name): diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index d936c8a2276c..e24b88a3f8c3 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -34,7 +34,7 @@ ml_dtypes = None import tvm -from tvm._ffi.libinfo import find_lib_path +from tvm.libinfo import find_lib_path from tvm.runtime import DataType from .emcc import create_tvmjs_wasm diff --git a/python/tvm/contrib/xcode.py b/python/tvm/contrib/xcode.py index d12367330dde..adfc2dcd8459 100644 --- a/python/tvm/contrib/xcode.py +++ b/python/tvm/contrib/xcode.py @@ -21,7 +21,7 @@ import sys import subprocess import json -from .._ffi.base import py_str +from ..base import py_str from . import utils diff --git a/python/tvm/dlight/analysis/common_analysis.py b/python/tvm/dlight/analysis/common_analysis.py index be260b894203..a3499274e5a8 100644 --- a/python/tvm/dlight/analysis/common_analysis.py +++ b/python/tvm/dlight/analysis/common_analysis.py @@ -20,7 +20,7 @@ from typing_extensions import Literal from tvm import ir, tir -from tvm._ffi import get_global_func +from tvm.ffi import get_global_func from tvm.target.target import Target from tvm.tir import Schedule from tvm.tir.schedule import BlockRV diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py index d3979ce0e4c3..b1564bf61fa9 100644 --- a/python/tvm/dlight/gpu/general_reduction.py +++ b/python/tvm/dlight/gpu/general_reduction.py @@ -61,6 +61,23 @@ def apply( # pylint: disable=too-many-locals # Align the number of block iters of the last block. num_last_block_iter = len(block_infos[-1].dom_kind()) if num_last_block_iter < len(dom_kind): + # If the last block is a scalar value, there is nothing left to + # tile/parallelise, and `iters` is an empty tuple. + # Add a unit thread loop so the final write happens inside a valid + # GPU thread environment. + if num_last_block_iter == 0: + # Put every block (both the running reductions and the final + # scalar write) inside a trivial GPU thread. The very first block + # gets a `blockIdx.x` wrapper so that kernels still have a unique + # block scope. + for i, info in enumerate(block_infos): + loop_rv = sch.add_unit_loop(info.block_rv) + if i == 0: + sch.bind(loop_rv, "blockIdx.x") + else: + sch.bind(loop_rv, "threadIdx.x") + + return sch def f_layout_mapping(*iters): analyzer = arith.Analyzer() diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index c423656d78f5..1ceecc9c94c6 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.driver""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("driver", __name__) +tvm.ffi._init_api("driver", __name__) diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py index b1f1554b56f9..ecfeaa1ebb88 100644 --- a/python/tvm/exec/disco_worker.py +++ b/python/tvm/exec/disco_worker.py @@ -22,7 +22,7 @@ from typing import Callable import tvm -from tvm._ffi import get_global_func, register_func +from tvm.ffi import get_global_func, register_func from tvm.runtime import NDArray, ShapeTuple, String from tvm.runtime.ndarray import array diff --git a/python/tvm/ffi/convert.py b/python/tvm/ffi/convert.py index 467f7a2fb491..5b25ddae259b 100644 --- a/python/tvm/ffi/convert.py +++ b/python/tvm/ffi/convert.py @@ -54,6 +54,11 @@ def convert(value: Any) -> Any: return core._convert_to_ffi_func(value) elif value is None: return None + elif hasattr(value, "__dlpack__"): + return core.from_dlpack( + value, + required_alignment=core.__dlpack_auto_import_required_alignment__, + ) elif isinstance(value, Exception): return core._convert_to_ffi_error(value) else: diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi index 8fe23cd23b29..8b9c1f3d947b 100644 --- a/python/tvm/ffi/cython/base.pxi +++ b/python/tvm/ffi/cython/base.pxi @@ -150,7 +150,7 @@ cdef extern from "tvm/ffi/c_api.h": int TVMFFIEnvRegisterCAPI(TVMFFIByteArray* name, void* ptr) nogil int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil - int TVMFFIDataTypeToString(DLDataType dtype, TVMFFIObjectHandle* out) nogil + int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out) nogil const TVMFFIByteArray* TVMFFITraceback(const char* filename, int lineno, const char* func) nogil; int TVMFFINDArrayFromDLPack(DLManagedTensor* src, int32_t require_alignment, int32_t require_contiguous, TVMFFIObjectHandle* out) nogil diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi index 30f9f274b4af..80ec5d9364b1 100644 --- a/python/tvm/ffi/cython/dtype.pxi +++ b/python/tvm/ffi/cython/dtype.pxi @@ -94,7 +94,7 @@ cdef class DataType: def __str__(self): cdef TVMFFIObjectHandle dtype_str cdef TVMFFIByteArray* bytes - CHECK_CALL(TVMFFIDataTypeToString(self.cdtype, &dtype_str)) + CHECK_CALL(TVMFFIDataTypeToString(&(self.cdtype), &dtype_str)) bytes = TVMFFIBytesGetByteArrayPtr(dtype_str) res = py_str(PyBytes_FromStringAndSize(bytes.data, bytes.size)) CHECK_CALL(TVMFFIObjectFree(dtype_str)) diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi index be80023c85b4..294a1246b27b 100644 --- a/python/tvm/ffi/cython/function.pxi +++ b/python/tvm/ffi/cython/function.pxi @@ -17,6 +17,11 @@ import ctypes from numbers import Real, Integral +try: + import torch +except ImportError: + torch = None + cdef inline object make_ret(TVMFFIAny result): """convert result to return value.""" @@ -71,6 +76,17 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except elif isinstance(arg, Object): out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) out[i].v_ptr = (arg).chandle + elif torch is not None and isinstance(arg, torch.Tensor): + arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg), + required_alignment=__dlpack_auto_import_required_alignment__) + out[i].type_index = kTVMFFINDArray + out[i].v_ptr = (arg).chandle + temp_args.append(arg) + elif hasattr(arg, "__dlpack__"): + arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__) + out[i].type_index = kTVMFFINDArray + out[i].v_ptr = (arg).chandle + temp_args.append(arg) elif isinstance(arg, PyNativeObject): arg = arg.__tvm_ffi_object__ out[i].type_index = TVMFFIObjectGetTypeIndex((arg).chandle) diff --git a/python/tvm/ffi/cython/ndarray.pxi b/python/tvm/ffi/cython/ndarray.pxi index cadf3de4fd6e..b8534b41b38b 100644 --- a/python/tvm/ffi/cython/ndarray.pxi +++ b/python/tvm/ffi/cython/ndarray.pxi @@ -16,8 +16,10 @@ # under the License. __dlpack_version__ = (1, 1) +__dlpack_auto_import_required_alignment__ = 8 _CLASS_NDARRAY = None + def _set_class_ndarray(cls): global _CLASS_NDARRAY _CLASS_NDARRAY = cls diff --git a/python/tvm/generic.py b/python/tvm/generic.py deleted file mode 100644 index 7c46312c2ea5..000000000000 --- a/python/tvm/generic.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Generic operators.""" -# pylint:disable=unused-wildcard-import, wildcard-import -from .tir.generic import * diff --git a/python/tvm/ir/_ffi_analysis_api.py b/python/tvm/ir/_ffi_analysis_api.py index 0013ec3b5026..ca38c2309f41 100644 --- a/python/tvm/ir/_ffi_analysis_api.py +++ b/python/tvm/ir/_ffi_analysis_api.py @@ -16,7 +16,7 @@ # under the License. """FFI APIs for tvm.ir.analysis""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("ir.analysis", __name__) +tvm.ffi._init_api("ir.analysis", __name__) diff --git a/python/tvm/ir/_ffi_api.py b/python/tvm/ir/_ffi_api.py index d3a9505c38d0..6434a3925e98 100644 --- a/python/tvm/ir/_ffi_api.py +++ b/python/tvm/ir/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.ir""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("ir", __name__) +tvm.ffi._init_api("ir", __name__) diff --git a/python/tvm/ir/_ffi_instrument_api.py b/python/tvm/ir/_ffi_instrument_api.py index bf62caf30e5a..d88faf7fddd0 100644 --- a/python/tvm/ir/_ffi_instrument_api.py +++ b/python/tvm/ir/_ffi_instrument_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.instrument""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("instrument", __name__) +tvm.ffi._init_api("instrument", __name__) diff --git a/python/tvm/ir/_ffi_transform_api.py b/python/tvm/ir/_ffi_transform_api.py index bb01b559c3d8..1a27fc58776c 100644 --- a/python/tvm/ir/_ffi_transform_api.py +++ b/python/tvm/ir/_ffi_transform_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.transform""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("transform", __name__) +tvm.ffi._init_api("transform", __name__) diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index fc63138043fa..6565a8de37b4 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. """TVM Attribute module, which is mainly used for defining attributes of operators.""" -import tvm._ffi +import tvm.ffi from tvm.runtime import Object import tvm.runtime._ffi_node_api from . import _ffi_api -@tvm._ffi.register_object +@tvm.ffi.register_object class Attrs(Object): """Attribute node, which is mainly use for defining attributes of operators. @@ -93,7 +93,7 @@ def __getitem__(self, item): return self.__getattr__(item) -@tvm._ffi.register_object +@tvm.ffi.register_object class DictAttrs(Attrs): """Dictionary attributes.""" diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 50b2c595b33f..a31be4c40ccb 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. """Common base structures.""" -import tvm._ffi +import tvm.ffi import tvm.error -from tvm._ffi import get_global_func, register_object +from tvm.ffi import get_global_func, register_object from tvm.runtime import Object, _ffi_node_api from . import _ffi_api, json_compact diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index c4d4fcc57807..ac4adc3306e6 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -22,7 +22,7 @@ and the DiagnosticRenderer. """ import enum -import tvm._ffi +import tvm.ffi from . import _ffi_api from ... import get_global_func, register_func, Object @@ -69,7 +69,7 @@ class DiagnosticLevel(enum.IntEnum): HELP = 50 -@tvm._ffi.register_object("Diagnostic") +@tvm.ffi.register_object("Diagnostic") class Diagnostic(Object): """A single diagnostic object from TVM.""" @@ -77,7 +77,7 @@ def __init__(self, level, span, message): self.__init_handle_by_constructor__(_ffi_api.Diagnostic, level, span, message) -@tvm._ffi.register_object("DiagnosticRenderer") +@tvm.ffi.register_object("DiagnosticRenderer") class DiagnosticRenderer(Object): """ A diagnostic renderer, which given a diagnostic context produces a "rendered" @@ -100,7 +100,7 @@ def render(self, ctx): # Register the diagnostic context. -@tvm._ffi.register_object("DiagnosticContext") +@tvm.ffi.register_object("DiagnosticContext") class DiagnosticContext(Object): """ A diagnostic context which records active errors diff --git a/python/tvm/ir/diagnostics/_ffi_api.py b/python/tvm/ir/diagnostics/_ffi_api.py index 430fd17f4d8a..fb157c977510 100644 --- a/python/tvm/ir/diagnostics/_ffi_api.py +++ b/python/tvm/ir/diagnostics/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI for TVM diagnostics.""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("diagnostics", __name__) +tvm.ffi._init_api("diagnostics", __name__) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index d140b5867c6e..197b0831bf25 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -18,7 +18,7 @@ from numbers import Number from typing import Optional -import tvm._ffi +import tvm.ffi from ..runtime import Object, Scriptable from . import _ffi_api @@ -71,7 +71,7 @@ def struct_info(self) -> Optional["tvm.relax.StructInfo"]: return _ffi_api.ExprStructInfo(self) -@tvm._ffi.register_object("GlobalVar") +@tvm.ffi.register_object("GlobalVar") class GlobalVar(RelaxExpr): """A global variable in the IR. @@ -117,7 +117,7 @@ def __call__(self, *args: RelaxExpr) -> BaseExpr: raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}") -@tvm._ffi.register_object +@tvm.ffi.register_object class Range(Node, Scriptable): """Represent a range in TVM. diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 4aee761af698..1e1505858f50 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -19,13 +19,13 @@ import inspect import functools -import tvm._ffi +import tvm.ffi import tvm.runtime from . import _ffi_instrument_api -@tvm._ffi.register_object("instrument.PassInstrument") +@tvm.ffi.register_object("instrument.PassInstrument") class PassInstrument(tvm.runtime.Object): """A pass instrument implementation. @@ -225,7 +225,7 @@ def create_pass_instrument(pi_cls): return create_pass_instrument -@tvm._ffi.register_object("instrument.PassInstrument") +@tvm.ffi.register_object("instrument.PassInstrument") class PassTimingInstrument(tvm.runtime.Object): """A wrapper to create a passes time instrument that implemented in C++""" diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 8347e218beb9..6033dc6f8066 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -20,8 +20,7 @@ from typing import Dict, Union -import tvm._ffi -from tvm._ffi.base import string_types +import tvm.ffi from tvm.runtime import Scriptable from tvm.runtime.object import Object @@ -31,7 +30,7 @@ from .base import Node -@tvm._ffi.register_object("IRModule") +@tvm.ffi.register_object("IRModule") class IRModule(Node, Scriptable): """IRModule that holds functions and type definitions. @@ -49,7 +48,7 @@ def __init__(self, functions=None, attrs=None, global_infos=None): elif isinstance(functions, dict): mapped_funcs = {} for k, v in functions.items(): - if isinstance(k, string_types): + if isinstance(k, str): k = _expr.GlobalVar(k) if not isinstance(k, _expr.GlobalVar): raise TypeError("Expect functions to be Dict[GlobalVar, Function]") @@ -98,7 +97,7 @@ def __setitem__(self, var, val): def _add(self, var, val, update=True): if isinstance(val, _expr.RelaxExpr): - if isinstance(var, string_types): + if isinstance(var, str): if _ffi_api.Module_ContainGlobalVar(self, var): var = _ffi_api.Module_GetGlobalVar(self, var) else: @@ -118,7 +117,7 @@ def __getitem__(self, var): val: Union[Function, Type] The definition referenced by :code:`var` (either a function or type). """ - if isinstance(var, string_types): + if isinstance(var, str): return _ffi_api.Module_Lookup_str(self, var) assert isinstance(var, _expr.GlobalVar) return _ffi_api.Module_Lookup(self, var) diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 932aef24c60d..41105c4549dd 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -16,13 +16,13 @@ # under the License. # pylint: disable=invalid-name """Primitive operators in the TVM IR.""" -import tvm._ffi +import tvm.ffi from . import _ffi_api from .expr import RelaxExpr -@tvm._ffi.register_object("Op") +@tvm.ffi.register_object("Op") class Op(RelaxExpr): """Primitive operator in the IR.""" diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index a501e8849e03..046432edfd99 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -20,7 +20,7 @@ from . import _ffi_api -@tvm._ffi.register_object("NameSupply") +@tvm.ffi.register_object("NameSupply") class NameSupply(Object): """NameSupply that can be used to generate unique names. @@ -77,7 +77,7 @@ def contains_name(self, name, add_prefix=True): return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) -@tvm._ffi.register_object("GlobalVarSupply") +@tvm.ffi.register_object("GlobalVarSupply") class GlobalVarSupply(Object): """GlobalVarSupply that holds a mapping between names and GlobalVars. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 644909f4d481..45050d44af0b 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -19,13 +19,13 @@ import inspect import functools -import tvm._ffi +import tvm.ffi import tvm.runtime from . import _ffi_transform_api -@tvm._ffi.register_object("transform.PassInfo") +@tvm.ffi.register_object("transform.PassInfo") class PassInfo(tvm.runtime.Object): """The class contains the meta data required by a pass. It is the container of information needed by running an optimization or analysis. @@ -50,7 +50,7 @@ def __init__(self, opt_level, name, required=None, traceable=False): ) -@tvm._ffi.register_object("transform.PassContext") +@tvm.ffi.register_object("transform.PassContext") class PassContext(tvm.runtime.Object): """The basis where a TVM optimization/analysis runs on. Each pass context contains a number of auxiliary information that is used @@ -209,7 +209,7 @@ def get_tuning_api_database(self): return _ffi_transform_api.GetTuningAPIDatabase(self) -@tvm._ffi.register_object("transform.Pass") +@tvm.ffi.register_object("transform.Pass") class Pass(tvm.runtime.Object): """The base class of all passes. All methods here are just simple wrappers that are implemented in the backend. They are defined for users to @@ -238,7 +238,7 @@ def __call__(self, mod): return _ffi_transform_api.RunPass(self, mod) -@tvm._ffi.register_object("transform.ModulePass") +@tvm.ffi.register_object("transform.ModulePass") class ModulePass(Pass): """A pass that works on tvm.IRModule. Users don't need to interact with this class directly. Instead, a module pass should be created through @@ -249,7 +249,7 @@ class ModulePass(Pass): """ -@tvm._ffi.register_object("transform.Sequential") +@tvm.ffi.register_object("transform.Sequential") class Sequential(Pass): """A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class. diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 3d372012b649..9ec8ef8fbd02 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -16,7 +16,7 @@ # under the License. """Unified type system in the project.""" import tvm -import tvm._ffi +import tvm.ffi from tvm.runtime import Scriptable from . import _ffi_api @@ -38,7 +38,7 @@ def same_as(self, other): return super().__eq__(other) -@tvm._ffi.register_object("PrimType") +@tvm.ffi.register_object("PrimType") class PrimType(Type): """Primitive data type in the low level IR @@ -52,7 +52,7 @@ def __init__(self, dtype): self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype) -@tvm._ffi.register_object("PointerType") +@tvm.ffi.register_object("PointerType") class PointerType(Type): """PointerType used in the low-level TIR. @@ -69,7 +69,7 @@ def __init__(self, element_type, storage_scope=""): self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type, storage_scope) -@tvm._ffi.register_object("TupleType") +@tvm.ffi.register_object("TupleType") class TupleType(Type): """The type of tuple values. @@ -83,7 +83,7 @@ def __init__(self, fields): self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) -@tvm._ffi.register_object("FuncType") +@tvm.ffi.register_object("FuncType") class FuncType(Type): """Function type. diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py index dba42dbce4a1..d0175fda5706 100644 --- a/python/tvm/ir/type_relation.py +++ b/python/tvm/ir/type_relation.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. """Type relation and function for type checking.""" -import tvm._ffi +import tvm.ffi from .type import Type, TypeConstraint from . import _ffi_api -@tvm._ffi.register_object("TypeCall") +@tvm.ffi.register_object("TypeCall") class TypeCall(Type): """Type function application. @@ -43,7 +43,7 @@ def __init__(self, func, args): self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) -@tvm._ffi.register_object("TypeRelation") +@tvm.ffi.register_object("TypeRelation") class TypeRelation(TypeConstraint): """User defined type relation, it is an input-output relation on types. diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/libinfo.py similarity index 98% rename from python/tvm/_ffi/libinfo.py rename to python/tvm/libinfo.py index 55d4d8165aee..d05f448540aa 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/libinfo.py @@ -47,8 +47,8 @@ def get_dll_directories(): # An installed TVM's curr_path will look something like: # $PREFIX/lib/python3.6/site-packages/tvm/_ffi ffi_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - source_dir = os.path.join(ffi_dir, "..", "..", "..") - install_lib_dir = os.path.join(ffi_dir, "..", "..", "..", "..") + source_dir = os.path.join(ffi_dir, "..", "..") + install_lib_dir = os.path.join(ffi_dir, "..", "..", "..") dll_path = [] @@ -65,7 +65,7 @@ def get_dll_directories(): dll_path.extend(split_env_var("PATH", ";")) # Pip lib directory - dll_path.append(os.path.join(ffi_dir, "..")) + dll_path.append(ffi_dir) # Default cmake build directory dll_path.append(os.path.join(source_dir, "build")) dll_path.append(os.path.join(source_dir, "build", "Release")) diff --git a/python/tvm/meta_schedule/_ffi_api.py b/python/tvm/meta_schedule/_ffi_api.py index 24022191a8b4..89b8df086001 100644 --- a/python/tvm/meta_schedule/_ffi_api.py +++ b/python/tvm/meta_schedule/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.meta_schedule""" -from .._ffi import _init_api +from ..ffi import _init_api _init_api("meta_schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py index 7390c544a50b..69c8d6d4c5dc 100644 --- a/python/tvm/meta_schedule/arg_info.py +++ b/python/tvm/meta_schedule/arg_info.py @@ -17,7 +17,7 @@ """The argument information""" from typing import Any, List, Union -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.ir import IRModule from tvm.runtime import DataType, Object, ShapeTuple from tvm.tir import PrimFunc diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index 221077cfbd6c..f323e15bd532 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -21,7 +21,7 @@ from typing_extensions import Literal # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.ir import IRModule from tvm.runtime import NDArray, Object from tvm.target import Target diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index ae9ad6574e34..ff738c6265c3 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -19,7 +19,7 @@ import tempfile from typing import Callable, Dict, List, Optional, Union -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.ir import IRModule from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict from tvm.target import Target diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index 541154d4cc59..9abd50b94c75 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -24,7 +24,7 @@ # isort: on import numpy as np # type: ignore -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from .. import _ffi_api diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py index 9167d30e9008..9191eee6a68f 100644 --- a/python/tvm/meta_schedule/cost_model/mlp_model.py +++ b/python/tvm/meta_schedule/cost_model/mlp_model.py @@ -542,7 +542,7 @@ def load( # pylint: disable=too-many-locals "_workload.json", "_candidates.json" ), ) - except tvm._ffi.base.TVMError: + except tvm.base.TVMError: continue candidates, results = [], [] tuning_records = database.get_all_tuning_records() diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py index aaee58fc94c8..5806454cdddb 100644 --- a/python/tvm/meta_schedule/cost_model/xgb_model.py +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -537,13 +537,17 @@ def _mean_cost(x: RunnerResult) -> float: self.last_train_size = self.data_size # Step 5. Re-train the model - self._train( - xs=list(itertools_chain.from_iterable([g.features for g in self.data.values()])), - ys=np.concatenate( - [g.min_cost / g.costs for g in self.data.values()], - axis=0, - ), - ) + with np.errstate(divide="ignore", invalid="ignore"): + feature_list = list( + itertools_chain.from_iterable([g.features for g in self.data.values()]) + ) + cost_ratio_list = [ + np.divide(g.min_cost, g.costs, out=np.zeros_like(g.costs), where=g.costs != 0) + for g in self.data.values() + ] + cost_ratios = np.concatenate(cost_ratio_list, axis=0) + + self._train(xs=feature_list, ys=cost_ratios) def predict( self, diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 601571089592..7abaead68018 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -22,7 +22,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.ir.module import IRModule from tvm.runtime import Object from tvm.target import Target diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index 102a13b90d98..f3b188493767 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ b/python/tvm/meta_schedule/database/json_database.py @@ -18,7 +18,7 @@ import os.path as osp from typing import Optional -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .database import Database diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 34a6a141970a..53755333839c 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A database that stores TuningRecords in memory""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .database import Database diff --git a/python/tvm/meta_schedule/database/ordered_union_database.py b/python/tvm/meta_schedule/database/ordered_union_database.py index 35b0a9e282c1..a451d8ee2fd1 100644 --- a/python/tvm/meta_schedule/database/ordered_union_database.py +++ b/python/tvm/meta_schedule/database/ordered_union_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .database import Database diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py index c7d175cb79d3..3b7dfa79f6bf 100644 --- a/python/tvm/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -17,7 +17,7 @@ """A database for injecting handcrafted schedule functions.""" from typing import Callable -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.tir import Schedule from .. import _ffi_api diff --git a/python/tvm/meta_schedule/database/union_database.py b/python/tvm/meta_schedule/database/union_database.py index ae55ebe79614..7f896c1da61f 100644 --- a/python/tvm/meta_schedule/database/union_database.py +++ b/python/tvm/meta_schedule/database/union_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .database import Database diff --git a/python/tvm/meta_schedule/extracted_task.py b/python/tvm/meta_schedule/extracted_task.py index b69a38ef6dc0..0cdede120b6f 100644 --- a/python/tvm/meta_schedule/extracted_task.py +++ b/python/tvm/meta_schedule/extracted_task.py @@ -17,7 +17,7 @@ """Extracted tasks from high-level IR.""" from typing import List -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object from tvm.target import Target diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py index c14c97e0f526..bd37214db997 100644 --- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -22,7 +22,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from tvm.runtime.ndarray import NDArray diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py index f6a456d707be..b1098bd4ea7c 100644 --- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py +++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py @@ -18,7 +18,7 @@ """We extract one feature vector per BufferStoreNode statement in a TIR Stmt, so we call this feature as "per-store" feature. """ -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .feature_extractor import FeatureExtractor diff --git a/python/tvm/meta_schedule/measure_callback/add_to_database.py b/python/tvm/meta_schedule/measure_callback/add_to_database.py index ab61e87f647d..f40dffeaad44 100644 --- a/python/tvm/meta_schedule/measure_callback/add_to_database.py +++ b/python/tvm/meta_schedule/measure_callback/add_to_database.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A callback that adds the measurement results into the database""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .measure_callback import MeasureCallback diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py index d4a10c1e4009..17a7f45460e9 100644 --- a/python/tvm/meta_schedule/measure_callback/measure_callback.py +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -23,7 +23,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from .. import _ffi_api diff --git a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py index 4b2e1ab7f428..82c18f8f9065 100644 --- a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py +++ b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A callback that removes the build artifacts from the disk""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .measure_callback import MeasureCallback diff --git a/python/tvm/meta_schedule/measure_callback/update_cost_model.py b/python/tvm/meta_schedule/measure_callback/update_cost_model.py index c6ee1d26fe6d..5b8b0306d421 100644 --- a/python/tvm/meta_schedule/measure_callback/update_cost_model.py +++ b/python/tvm/meta_schedule/measure_callback/update_cost_model.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A measure callback that updates the cost model""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .measure_callback import MeasureCallback diff --git a/python/tvm/meta_schedule/mutator/mutate_compute_location.py b/python/tvm/meta_schedule/mutator/mutate_compute_location.py index bb361247bf62..5ebe04a6b13a 100644 --- a/python/tvm/meta_schedule/mutator/mutate_compute_location.py +++ b/python/tvm/meta_schedule/mutator/mutate_compute_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """A mutator that mutates the compute-at location decision of SampleComputeLocation""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py index c66dddb825f4..c7736fdcf71d 100644 --- a/python/tvm/meta_schedule/mutator/mutate_parallel.py +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the parallel extent""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py index 6a2553f94346..2225ca76c77d 100644 --- a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py +++ b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the thread binding extent""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_tile_size.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py index ff432a6633b9..90cccdc3f5db 100644 --- a/python/tvm/meta_schedule/mutator/mutate_tile_size.py +++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the decision of instruction Sample-Perfect-Tile""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py index f81953d008d4..9575c3fc22d9 100644 --- a/python/tvm/meta_schedule/mutator/mutate_unroll.py +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates auto unroll step""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .mutator import Mutator diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 188cb30c5b69..6991c72bec41 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -22,7 +22,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Trace @@ -80,7 +80,7 @@ def create( "cuda", "cuda-tensorcore", "hexagon", - ] + ], ) -> Dict["Mutator", float]: """Create a list of default mutators. diff --git a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py index 0dcff9bf45a3..5c50b2064426 100644 --- a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py +++ b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that checks if the IRModule has any strided memory copies""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py index 5515d288e0e7..34c13aded935 100644 --- a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that checks if the IRModule has any loop with non-constant extent""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index af7fe9e9c502..33daabc3951c 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -22,7 +22,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Schedule diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py index e2d7c2212382..20c354ce601d 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -17,7 +17,7 @@ """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized cooperative fetching in loop bindings.""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_layout.py b/python/tvm/meta_schedule/postproc/rewrite_layout.py index 10addefee542..13556f1909d2 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_layout.py +++ b/python/tvm/meta_schedule/postproc/rewrite_layout.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that rewrites the layout of input tensor""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py index abe7288acba9..0be7cdbe118f 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -17,7 +17,7 @@ """A postprocessor that applies parallelization, vectorization and auto unrolling according to the annotation of each block""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py index 7e15ed493ccb..30c8cf9b0699 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that rewrites reduction block by moving the init block out.""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py index 85075c41b43c..e04ddcbdf223 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py +++ b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that tensorize related components.""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py index aef5bca690e4..ca4c9cdcd624 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that adds thread binding to unbound blocks""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py index 501e4423196c..1a74eadaa906 100644 --- a/python/tvm/meta_schedule/postproc/verify_gpu_code.py +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that verifies if the GPU code is correct""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py index 28d202d5b338..51a38624d28e 100644 --- a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py +++ b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py @@ -16,7 +16,7 @@ # under the License. """A postprocessor that verifies the VTCM usage of a given schedule.""" -from tvm._ffi.registry import register_object +from tvm.ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py index 7b7bb6e6d17f..65c1079d65b0 100644 --- a/python/tvm/meta_schedule/profiler.py +++ b/python/tvm/meta_schedule/profiler.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from typing import Dict, Optional -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from . import _ffi_api diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index 0e5adc4e9982..613405c8ad3b 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -23,7 +23,7 @@ # isort: on -from tvm._ffi import get_global_func, register_func +from tvm.ffi import get_global_func, register_func from tvm.ir import IRModule from tvm.ir.transform import PassContext from tvm.runtime import NDArray diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 1a8f78414e91..0c2609469a19 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -22,7 +22,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from .. import _ffi_api diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py index 72f9fc92f96e..ceb18a6c3aa6 100644 --- a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py +++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py @@ -17,7 +17,7 @@ """Add-rfactor Rule that add-rfactor to some blocks if needed""" from typing import Optional -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py index 29e25f992930..26f61aa8ceb6 100644 --- a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py @@ -16,7 +16,7 @@ # under the License. """Create a rule that applies customized rules registered using block attribute `schedule_rule`. The rule will be dispatched according to target keys.""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/auto_bind.py b/python/tvm/meta_schedule/schedule_rule/auto_bind.py index 99a91f606e32..ef34e45061f7 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_bind.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_bind.py @@ -17,7 +17,7 @@ """Auto-bind Rule that binds blocks to threads if needed""" from typing import List, Optional -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py index c84dbaf89b97..8cd122ec93d3 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -17,7 +17,7 @@ """Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" from typing import List, Optional -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py index f242e42aea4b..d2c780b72854 100644 --- a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py +++ b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py @@ -17,7 +17,7 @@ """Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed""" from typing import List -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 19651a2ce18e..2f389190d662 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Callable from tvm.tir.schedule import Schedule, BlockRV -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py index a79ea918670e..e9626c40e39c 100644 --- a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py @@ -18,7 +18,7 @@ each block in a follow-up post processor""" from typing import List, Optional -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py index 2355b0bfa8e5..81de07afbbed 100644 --- a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Rule that randomly select a compute-at location for a free block""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .schedule_rule import ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index b71a4c7c3538..5684e68c715f 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -25,7 +25,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import BlockRV, Schedule diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 44f32527fad9..1833ef23bda1 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Evolutionary Search Strategy""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .search_strategy import SearchStrategy diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py index f4660014241a..09e5c58d077a 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_func.py +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .search_strategy import SearchStrategy diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index e24ad5a5219a..a25596524451 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .search_strategy import SearchStrategy diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 8822b097945e..ab4a6fb7b636 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -24,7 +24,7 @@ from typing_extensions import Literal # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from tvm.tir.schedule import Schedule diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 930e8a51dc61..eee9ea0d0e5d 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Post Order Apply Space Generator.""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .space_generator import ( diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index 65956e843679..2cb1538a5abc 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Union of meta Schedule design space generators.""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .space_generator import ( diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index c1d9765067bc..8c9effa6e656 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -24,7 +24,7 @@ from typing_extensions import Literal # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object from tvm.tir.schedule import Schedule diff --git a/python/tvm/meta_schedule/space_generator/space_generator_union.py b/python/tvm/meta_schedule/space_generator/space_generator_union.py index e3d8f441d1ef..f512f6535550 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator_union.py +++ b/python/tvm/meta_schedule/space_generator/space_generator_union.py @@ -17,7 +17,7 @@ """Union of meta Schedule design space generators.""" from typing import List -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from .space_generator import ( diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index 963de8711e10..7bac23bb3fad 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Gradient Based Task Scheduler""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from ..logging import get_logger, get_logging_func diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index e5c7f14af424..6475b4102a1d 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Round Robin Task Scheduler""" -from tvm._ffi import register_object +from tvm.ffi import register_object from .. import _ffi_api from ..logging import get_logger, get_logging_func diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index d56d944474e9..9d6fec88b63b 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -22,7 +22,7 @@ # isort: on -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from .. import _ffi_api diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py index ec7a4237546b..14ff32c0178a 100644 --- a/python/tvm/meta_schedule/testing/validate_database.py +++ b/python/tvm/meta_schedule/testing/validate_database.py @@ -25,7 +25,7 @@ import tvm from tvm import meta_schedule as ms -from tvm._ffi import get_global_func, register_func +from tvm.ffi import get_global_func, register_func from tvm.ir import IRModule from tvm.support import describe from tvm.target import Target @@ -469,9 +469,9 @@ def f_with_args_run_evaluator_common( number=evaluator_config.number, repeat=evaluator_config.repeat, min_repeat_ms=evaluator_config.min_repeat_ms, - f_preproc="cache_flush_cpu_non_first_arg" - if evaluator_config.enable_cpu_cache_flush - else "", + f_preproc=( + "cache_flush_cpu_non_first_arg" if evaluator_config.enable_cpu_cache_flush else "" + ), ) repeated_costs: List[List[float]] = [] diff --git a/python/tvm/meta_schedule/tir_integration.py b/python/tvm/meta_schedule/tir_integration.py index bffade49a072..b171c9711802 100644 --- a/python/tvm/meta_schedule/tir_integration.py +++ b/python/tvm/meta_schedule/tir_integration.py @@ -22,7 +22,7 @@ # isort: on from tvm import ir, tir -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.target import Target from tvm.tir.expr import IntImm diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 6f76452a57b5..5512b7a2682b 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -24,7 +24,7 @@ # isort: on from tvm import IRModule -from tvm._ffi import register_object, register_func +from tvm.ffi import register_object, register_func from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 8932fcdc3eaa..61b32e1e324b 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -22,7 +22,7 @@ import numpy as np # type: ignore import psutil # type: ignore -from tvm._ffi import get_global_func, register_func +from tvm.ffi import get_global_func, register_func from tvm.error import TVMError from tvm.ir import Array, IRModule, Map from tvm.rpc import RPCSession diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py index a127e1c81378..db1ca055865a 100644 --- a/python/tvm/relax/_ffi_api.py +++ b/python/tvm/relax/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI API for Relax.""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax", __name__) +tvm.ffi._init_api("relax", __name__) diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py index 40ee05c3960d..fb44606f1122 100644 --- a/python/tvm/relax/analysis/_ffi_api.py +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.analysis", __name__) +tvm.ffi._init_api("relax.analysis", __name__) diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py index d1378b2eacc2..17d7a18a338d 100644 --- a/python/tvm/relax/backend/_ffi_api.py +++ b/python/tvm/relax/backend/_ffi_api.py @@ -16,6 +16,6 @@ # under the License. """FFI API for Relax backend.""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.backend", __name__) +tvm.ffi._init_api("relax.backend", __name__) diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py index b5caa688f221..139e5cc2b997 100644 --- a/python/tvm/relax/backend/metal/coreml.py +++ b/python/tvm/relax/backend/metal/coreml.py @@ -19,7 +19,7 @@ import os import shutil -import tvm._ffi +import tvm.ffi from tvm.contrib import coreml_runtime from tvm.contrib.xcode import compile_coreml @@ -144,7 +144,7 @@ def _conv2d_pattern(pattern_name): *default_unary_patterns(op_name="nn.avg_pool2d"), *conv2d_patterns(), *clip_patterns(), - *matmul_patterns() + *matmul_patterns(), # TODO(@tvm-team): enable when relax op is implemented # ("coreml.nn.batch_flatten", is_op("relax.nn.batch_flatten")(wildcard())), ] @@ -463,7 +463,7 @@ def compile(self, out_dir): compile_coreml(model, self.model_name, out_dir) -@tvm._ffi.register_func("relax.ext.coreml") +@tvm.ffi.register_func("relax.ext.coreml") def coreml_compiler(funcs, options, constant_names): """ Create a CoreML runtime from a Relax module. diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py index a9f6d878ad0d..22215206ac4b 100644 --- a/python/tvm/relax/binding_rewrite.py +++ b/python/tvm/relax/binding_rewrite.py @@ -20,13 +20,13 @@ from typing import Optional import tvm -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from . import Binding, DataflowBlock, Expr, Function, Var from . import _ffi_api -@tvm._ffi.register_object("relax.DataflowBlockRewrite") +@tvm.ffi.register_object("relax.DataflowBlockRewrite") class DataflowBlockRewrite(Object): """ A binding/statement-level dataflow block rewriter. diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 37866840bd68..e09a9fab263a 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -100,7 +100,7 @@ def __exit__(self, ptype, value, trace): self._bb.end_scope() -@tvm._ffi.register_object("relax.BlockBuilder") +@tvm.ffi.register_object("relax.BlockBuilder") class BlockBuilder(Object): """A builder to build Relax IR for testing and dev. diff --git a/python/tvm/relax/distributed/_ffi_api.py b/python/tvm/relax/distributed/_ffi_api.py index 57411e82613f..6544a8d35572 100644 --- a/python/tvm/relax/distributed/_ffi_api.py +++ b/python/tvm/relax/distributed/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.distributed""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.distributed", __name__) +tvm.ffi._init_api("relax.distributed", __name__) diff --git a/python/tvm/relax/distributed/global_info.py b/python/tvm/relax/distributed/global_info.py index 9639fd885cb6..3f549ecfa37e 100644 --- a/python/tvm/relax/distributed/global_info.py +++ b/python/tvm/relax/distributed/global_info.py @@ -26,7 +26,7 @@ from . import _ffi_api as ffi -@tvm._ffi.register_object("relax.distributed.DeviceMesh") +@tvm.ffi.register_object("relax.distributed.DeviceMesh") class DeviceMesh(GlobalInfo): """Device mesh express a view of topology of devices, represented by an n-d matrix of device ids. diff --git a/python/tvm/relax/distributed/struct_info.py b/python/tvm/relax/distributed/struct_info.py index b5e258516d05..50087b98841a 100644 --- a/python/tvm/relax/distributed/struct_info.py +++ b/python/tvm/relax/distributed/struct_info.py @@ -33,7 +33,7 @@ class PlacementSpecKind(enum.IntEnum): kReplica = 1 -@tvm._ffi.register_object("relax.distributed.PlacementSpec") +@tvm.ffi.register_object("relax.distributed.PlacementSpec") class PlacementSpec(Object): """Describes how data is distributed in one dimension of the device mesh @@ -80,7 +80,7 @@ def replica() -> "PlacementSpec": return _ffi_api.Replica() -@tvm._ffi.register_object("relax.distributed.Placement") +@tvm.ffi.register_object("relax.distributed.Placement") class Placement(Object): """Describes how data is distributed in each dimension of the device mesh @@ -110,7 +110,7 @@ def from_text(text: str) -> "Placement": return _ffi_api.PlacementFromText(text) -@tvm._ffi.register_object("relax.DTensorStructInfo") +@tvm.ffi.register_object("relax.DTensorStructInfo") class DTensorStructInfo(StructInfo): """StructInfo of a Distributed Tensor value. diff --git a/python/tvm/relax/distributed/transform/_ffi_api.py b/python/tvm/relax/distributed/transform/_ffi_api.py index d064ae2bb931..b694a67116d2 100644 --- a/python/tvm/relax/distributed/transform/_ffi_api.py +++ b/python/tvm/relax/distributed/transform/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.distributed.transform""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.distributed.transform", __name__) +tvm.ffi._init_api("relax.distributed.transform", __name__) diff --git a/python/tvm/relax/dpl/_ffi.py b/python/tvm/relax/dpl/_ffi.py index 6699e42bee63..72bf073bedfc 100644 --- a/python/tvm/relax/dpl/_ffi.py +++ b/python/tvm/relax/dpl/_ffi.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """DataFlow Pattern Language FFI bindings.""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.dpl", __name__) +tvm.ffi._init_api("relax.dpl", __name__) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 3d076e8fad35..42486ee28948 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -23,7 +23,7 @@ from typing import Dict, List, Optional, Tuple, Union import tvm -import tvm._ffi as tvm_ffi +import tvm.ffi as tvm_ffi from tvm.ir.container import Array from tvm.ir.expr import PrimExpr from tvm.ir.op import Op diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 96c69e9266a2..a9782057c8fb 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -20,7 +20,7 @@ from tvm.ir import IRModule from tvm.runtime import Object -from tvm._ffi import register_object +from tvm.ffi import register_object from .pattern import DFPattern from .context import PatternContext diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index 699860786072..0939dabe16ed 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -56,7 +56,7 @@ def __exit__(self, ptype, value, trace): self.exit_callback() -@tvm._ffi.register_object("relax.ExecBuilder") +@tvm.ffi.register_object("relax.ExecBuilder") class ExecBuilder(Object): """A builder to emit instructions and build executable for the virtual machine.""" diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 950de9eac022..0fa8c4df88f8 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -22,11 +22,10 @@ import numpy as _np # type: ignore import tvm -import tvm._ffi +import tvm.ffi import tvm.ir import tvm.relax from tvm import DataType -from tvm._ffi import base as _base from tvm.runtime import Object from tvm.runtime import ndarray as _nd @@ -43,7 +42,7 @@ GlobalVar = Union[tvm.ir.GlobalVar] -@tvm._ffi.register_object("relax.Id") +@tvm.ffi.register_object("relax.Id") class Id(Object): """Unique identifier(name) used in Var. Guaranteed to be stable across all passes. @@ -528,7 +527,7 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: return tvm.relax.Call(op, [self.tensor, axis]) -@tvm._ffi.register_object("relax.expr.Call") +@tvm.ffi.register_object("relax.expr.Call") class Call(ExprWithOp): """Function call node in Relax. @@ -577,7 +576,7 @@ def __init__( ) -@tvm._ffi.register_object("relax.expr.If") +@tvm.ffi.register_object("relax.expr.If") class If(ExprWithOp): """A conditional expression in Relax. @@ -609,7 +608,7 @@ def __init__( ) -@tvm._ffi.register_object("relax.expr.Tuple") +@tvm.ffi.register_object("relax.expr.Tuple") class Tuple(ExprWithOp): """Tuple expression that groups several fields together. @@ -644,7 +643,7 @@ def __len__(self) -> int: return len(self.fields) -@tvm._ffi.register_object("relax.expr.TupleGetItem") +@tvm.ffi.register_object("relax.expr.TupleGetItem") class TupleGetItem(ExprWithOp): """Get index-th item from a tuple. @@ -670,7 +669,7 @@ def __init__(self, tuple_value: Expr, index: int, span: Optional[Span] = None): ) -@tvm._ffi.register_object("relax.expr.ShapeExpr") +@tvm.ffi.register_object("relax.expr.ShapeExpr") class ShapeExpr(ExprWithOp): """A shape expression which allows users to construct a shape containing PrimExpr. @@ -708,7 +707,7 @@ def make_shape(shape: Union[List[Any], typing.Tuple[Any, ...]]) -> ShapeExpr: raise ValueError("Wrong type") -@tvm._ffi.register_object("relax.expr.Constant") +@tvm.ffi.register_object("relax.expr.Constant") class Constant(ExprWithOp): """Constant Tensor @@ -742,7 +741,7 @@ def __init__( ) -@tvm._ffi.register_object("relax.expr.Var") +@tvm.ffi.register_object("relax.expr.Var") class Var(ExprWithOp): """The variable class for all Relax bindings. @@ -789,7 +788,7 @@ def name_hint(self) -> str: return name -@tvm._ffi.register_object("relax.expr.DataflowVar") +@tvm.ffi.register_object("relax.expr.DataflowVar") class DataflowVar(Var): """A sub-type of the variable node used to mark dataflow variables from normal visible "function local" bindings. @@ -838,7 +837,7 @@ def __init__( ) -@tvm._ffi.register_object("relax.expr.PrimValue") +@tvm.ffi.register_object("relax.expr.PrimValue") class PrimValue(Expr, Scriptable): """The prim expr representing the value.""" @@ -850,7 +849,7 @@ def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span) # type: ignore -@tvm._ffi.register_object("relax.expr.StringImm") +@tvm.ffi.register_object("relax.expr.StringImm") class StringImm(Expr, Scriptable): """Represent a string literal constant.""" @@ -861,7 +860,7 @@ def __init__(self, value: str, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore -@tvm._ffi.register_object("relax.expr.DataTypeImm") +@tvm.ffi.register_object("relax.expr.DataTypeImm") class DataTypeImm(Expr, Scriptable): """Represent a data type constant.""" @@ -872,7 +871,7 @@ def __init__(self, value: Union[DataType, str], span: Optional[Span] = None) -> self.__init_handle_by_constructor__(_ffi_api.DataTypeImm, value, span) # type: ignore -@tvm._ffi.register_object("relax.expr.Binding") +@tvm.ffi.register_object("relax.expr.Binding") class Binding(Node, Scriptable): """The base class of a binding in Relax.""" @@ -880,7 +879,7 @@ class Binding(Node, Scriptable): span: Optional[Span] -@tvm._ffi.register_object("relax.expr.MatchCast") +@tvm.ffi.register_object("relax.expr.MatchCast") class MatchCast(Binding): """Runtime-match the value to the struct info. @@ -912,7 +911,7 @@ def __init__( ) -@tvm._ffi.register_object("relax.expr.VarBinding") +@tvm.ffi.register_object("relax.expr.VarBinding") class VarBinding(Binding): """Variable binding, bind he variable of the lhs with the rhs. @@ -934,7 +933,7 @@ def __init__(self, var: Var, value: Expr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.VarBinding, var, value, span) # type: ignore -@tvm._ffi.register_object("relax.expr.BindingBlock") +@tvm.ffi.register_object("relax.expr.BindingBlock") class BindingBlock(Node, Scriptable): """base class of binding block, bindings inside can be impure (with side effect or control flow)""" @@ -946,7 +945,7 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None self.__init_handle_by_constructor__(_ffi_api.BindingBlock, bindings, span) # type: ignore -@tvm._ffi.register_object("relax.expr.DataflowBlock") +@tvm.ffi.register_object("relax.expr.DataflowBlock") class DataflowBlock(BindingBlock): """dataflow block, bindings inside are pure (no side effect and no control flow)""" @@ -958,7 +957,7 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None self.__init_handle_by_constructor__(_ffi_api.DataflowBlock, bindings, span) # type: ignore -@tvm._ffi.register_object("relax.expr.SeqExpr") +@tvm.ffi.register_object("relax.expr.SeqExpr") class SeqExpr(ExprWithOp): """A sequence of binding blocks followed by an expression.""" @@ -970,7 +969,7 @@ def __init__(self, blocks: List[BindingBlock], body: Expr, span: Optional[Span] self.__init_handle_by_constructor__(_ffi_api.SeqExpr, blocks, body, span) # type: ignore -@tvm._ffi.register_object("relax.expr.Function") +@tvm.ffi.register_object("relax.expr.Function") class Function(BaseFunc, Scriptable): """A Relax function.""" @@ -1109,7 +1108,7 @@ def inline_functions( return _ffi_api.FunctionInlineFunctions(self, function_map) # type: ignore -@tvm._ffi.register_object("relax.expr.ExternFunc") +@tvm.ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc, ExprWithOp): """extern function, which represents a PackedFunc.""" @@ -1154,7 +1153,7 @@ def const( - bool maps to "bool" - other using the same default rule as numpy. """ - if isinstance(value, (_base.numeric_types, (bool, list))): + if isinstance(value, (Number, (bool, list))): value = _np.array(value, dtype=dtype) if not dtype: diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index a0bb6df84373..49d3b14505ba 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -261,7 +261,7 @@ def visit_var_def(self, var: Var): raise TypeError("Invalid type: {0}".format(type(var))) -@tvm._ffi.register_object("expr_functor.PyExprVisitor") +@tvm.ffi.register_object("expr_functor.PyExprVisitor") class _PyExprVisitor(Object): """ A TVM object to support customization of ExprVisitor on the python side. @@ -781,7 +781,7 @@ def visit_span(self, span: Span) -> None: return _ffi_api.ExprVisitorVisitSpan(self._outer(), span) # type: ignore -@tvm._ffi.register_object("expr_functor.PyExprMutator") +@tvm.ffi.register_object("expr_functor.PyExprMutator") class _PyExprMutator(Object): """ A TVM object to support customization of ExprMutator on the python side. @@ -1503,7 +1503,7 @@ def visit_with_new_scope(self, expr: Expr) -> Expr: def lookup_binding(self, var: Var) -> Optional[Expr]: """Look up the value bound to a variable. - Note: For function parameters, this function returns NullOpt. + Note: For function parameters, this function returns std::nullopt. Parameters ---------- diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index d89369c718b2..e7248b0f4b27 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -135,14 +135,14 @@ def shape_dtype_inference(a, b): of in-memory representation of tensors. More details: https://github.com/dmlc/dlpack/blob/v0.8/include/dlpack/dlpack.h#L163-L206. - To expose the symbol, `TVM_DLL_EXPORT_TYPED_FUNC(symbol, function)` is guaranteed available: + To expose the symbol, `TVM_FFI_DLL_EXPORT_TYPED_FUNC(symbol, function)` is guaranteed available: .. code-block:: C++ // those headers are guaranteed to be available #include #include - #include + #include namespace { // anonymous namespace hides the symbol `_my_func_impl` from other translation units @@ -151,7 +151,7 @@ def shape_dtype_inference(a, b): } } // expose symbol `my_func` instead of `_my_func_impl` - TVM_DLL_EXPORT_TYPED_FUNC(my_func, _my_func_impl); + TVM_FFI_DLL_EXPORT_TYPED_FUNC(my_func, _my_func_impl); **A compiler pass `AttachExternModules`.** It is introduced to attach a list of `nn.ExternModule`s into an IRModule at any stage of the compilation pipeline, diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 86be98cba786..ab416ef14176 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -930,6 +930,28 @@ def relu(x: Tensor, name: str = "relu") -> Tensor: return wrap_nested(_op.nn.relu(x._expr), name) +def relu6(x: Tensor, name: str = "relu6") -> Tensor: + r"""ReLU6 activation function. + + .. math:: + \text{ReLU6}(x) = \min(\max(x, 0), 6) + + Parameters + ---------- + x : Tensor + The input data. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.nn.relu6(x._expr), name) + + def silu(x: Tensor, name: str = "silu") -> Tensor: r"""Sigmoid Linear Unit function @@ -2065,7 +2087,7 @@ def extern( out: OutType, ) -> OutType: """Invoke an extern function during runtime. The extern function must be registered with the " - TVM runtime using `TVM_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python). + TVM runtime using `TVM_FFI_REGISTER_GLOBAL` (C++), or `tvm.register_func` (Python). Parameters ---------- diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 24217184b57c..bae308f435ce 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -60,13 +60,13 @@ def get_type(elem_type: Union[str, int]) -> str: return elem_type try: - from onnx.mapping import ( # pylint: disable=import-outside-toplevel - TENSOR_TYPE_TO_NP_TYPE, + from onnx.helper import ( # pylint: disable=import-outside-toplevel + tensor_dtype_to_np_dtype, ) except ImportError as exception: raise ImportError("Unable to import onnx which is required {}".format(exception)) - return str(TENSOR_TYPE_TO_NP_TYPE[elem_type]) + return str(tensor_dtype_to_np_dtype(elem_type)) def get_constant( @@ -564,15 +564,14 @@ def _impl_v18(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) -class BitwiseNot(BitwiseBase): +class BitwiseNot(OnnxOpConverter): """Converts an onnx BitwiseNot node into an equivalent Relax expression.""" - numpy_op = _np.bitwise_not - relax_op = relax.op.bitwise_not - @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params) + if isinstance(inputs[0], relax.Constant): + return relax.const(_np.bitwise_not(inputs[0].data.numpy()), inputs[0].struct_info.dtype) + return relax.op.bitwise_not(inputs[0]) class BitShift(BitwiseBase): @@ -2489,6 +2488,10 @@ def _impl_v17(cls, bb, inputs, attr, params): axis = attr.get("axis", -1) epsilon = attr.get("epsilon", 1e-05) + if bias is None: + seq_len = data.struct_info.shape[1].value + bias = relax.const([0.0] * seq_len, dtype="float32") + output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon) # Onnx layernorm has 3 outputs but only the first is used. # We construct two empty constants for this. @@ -3113,13 +3116,13 @@ def _get_convert_map(): "BitwiseAnd": BitwiseAnd, "BitwiseOr": BitwiseOr, "BitwiseXor": BitwiseXor, - "BitwiseNot": BitwiseNot, "BitShift": BitShift, "And": And, "Or": Or, "Xor": Xor, "Not": Not, # Unary operators + "BitwiseNot": BitwiseNot, "Log": Log, "Exp": Exp, "Acos": Acos, diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f8634f5da70e..485b7c088a15 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -510,12 +510,53 @@ def _linalg_vector_norm(self, node: fx.Node) -> relax.Var: ########## Neural Network ########## + def _adaptive_avg_pool1d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] if len(node.args) > 1 else node.kwargs["output_size"] + # Expand to 3D by adding batch dim if input is 2D + x_ndim = x.struct_info.ndim + if x_ndim == 2: + x = relax.op.expand_dims(x, axis=0) + + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW") + ) + # Remove added batch dim from result + if x_ndim == 2: + result = relax.op.squeeze(result, axis=[0]) + return result + def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] output_size = node.args[1] - return self.block_builder.emit( + # Expand to 4D by adding batch dim if input is 3D + x_ndim = x.struct_info.ndim + if x_ndim == 3: + x = relax.op.expand_dims(x, axis=0) + + result = self.block_builder.emit( relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + # Remove added batch dim from result + if x_ndim == 3: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _adaptive_avg_pool3d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] + # Expand to 5D by adding batch dim if input is 4D + x_ndim = x.struct_info.ndim + if x_ndim == 4: + x = relax.op.expand_dims(x, axis=0) + + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW") + ) + # Remove added batch dim from result + if x_ndim == 4: + result = relax.op.squeeze(result, axis=[0]) + return result def _addmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -539,6 +580,48 @@ def _addmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) return res + def _avg_pool1d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int]] = 1, + stride: Optional[Union[int, Tuple[int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + count_include_pad: Optional[bool] = True, + ) -> relax.Var: + # Expand to 3D by adding batch dim if input is 2D + x_ndim = x.struct_info.ndim + if x_ndim == 2: + x = relax.op.expand_dims(x, axis=0) + stride = kernel_size if stride is None or stride == [] else stride + + result = self.block_builder.emit( + relax.op.nn.avg_pool1d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + layout="NCW", + ) + ) + # Remove added batch dim from result + if x_ndim == 2: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _avg_pool1d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + count_include_pad = args[5] if len(args) > 5 else kwargs.get("count_include_pad", True) + + return self._avg_pool1d_impl(x, kernel_size, stride, padding, ceil_mode, count_include_pad) + def _avg_pool2d_impl( self, x: relax.Expr, @@ -547,8 +630,13 @@ def _avg_pool2d_impl( padding: Optional[int] = 0, ceil_mode: Optional[bool] = False, ) -> relax.Var: + # Expand to 4D by adding batch dim if input is 3D + x_ndim = x.struct_info.ndim + if x_ndim == 3: + x = relax.op.expand_dims(x, axis=0) stride = kernel_size if stride is None or stride == [] else stride - return self.block_builder.emit( + + result = self.block_builder.emit( relax.op.nn.avg_pool2d( x, pool_size=kernel_size, @@ -558,6 +646,10 @@ def _avg_pool2d_impl( layout="NCHW", ) ) + # Remove added batch dim from result + if x_ndim == 3: + result = relax.op.squeeze(result, axis=[0]) + return result def _avg_pool2d(self, node: fx.Node) -> relax.Var: args, kwargs = node.normalized_arguments(node) @@ -568,6 +660,48 @@ def _avg_pool2d(self, node: fx.Node) -> relax.Var: ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + def _avg_pool3d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int, int]] = (1, 1, 1), + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + count_include_pad: Optional[bool] = True, + ) -> relax.Var: + # Expand to 5D by adding batch dim if input is 4D + x_ndim = x.struct_info.ndim + if x_ndim == 4: + x = relax.op.expand_dims(x, axis=0) + stride = kernel_size if stride is None or stride == [] else stride + + result = self.block_builder.emit( + relax.op.nn.avg_pool3d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + layout="NCDHW", + ) + ) + # Remove added batch dim from result + if x_ndim == 4: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _avg_pool3d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + count_include_pad = args[5] if len(args) > 5 else kwargs.get("count_include_pad", True) + + return self._avg_pool3d_impl(x, kernel_size, stride, padding, ceil_mode, count_include_pad) + def _baddbmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] batch1 = self.env[node.args[1]] @@ -599,6 +733,7 @@ def _conv_transpose1d_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], + output_padding: Optional[Tuple], ) -> relax.Var: conv1d_transpose = self.block_builder.emit( relax.op.nn.conv1d_transpose( @@ -608,8 +743,9 @@ def _conv_transpose1d_impl( padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_dtype="float32", ) ) @@ -628,8 +764,9 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None stride = args[3] if len(args) > 3 else 1 padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 + output_padding = args[5] if len(args) > 5 else 0 groups = args[6] if len(args) > 6 else 1 + dilation = args[7] if len(args) > 7 else 1 return self._conv_transpose1d_impl( x, weight, @@ -638,6 +775,7 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, ) def _conv_transpose2d_impl( @@ -649,6 +787,7 @@ def _conv_transpose2d_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], + output_padding: Optional[Tuple], ) -> relax.Var: conv2d_transpose = self.block_builder.emit( relax.op.nn.conv2d_transpose( @@ -658,8 +797,9 @@ def _conv_transpose2d_impl( padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_dtype="float32", ) ) @@ -678,8 +818,9 @@ def _conv_transpose2d(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None stride = args[3] if len(args) > 3 else 1 padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 + output_padding = args[5] if len(args) > 5 else 0 groups = args[6] if len(args) > 6 else 1 + dilation = args[7] if len(args) > 7 else 1 return self._conv_transpose2d_impl( x, weight, @@ -688,6 +829,7 @@ def _conv_transpose2d(self, node: fx.Node) -> relax.Var: padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, ) def _conv1d_impl( @@ -837,6 +979,25 @@ def _conv3d(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _cross_entropy_loss( + self, + preds: relax.Expr, + targets: relax.Expr, + weights: Optional[relax.Expr], + reduction: str, + ignore_index: int, + ) -> relax.Expr: + log_probs = relax.op.nn.log_softmax(preds) + return self.block_builder.emit( + relax.op.nn.nll_loss( + log_probs, + targets, + weights, + reduction, + ignore_index, + ) + ) + def _einsum(self, node: fx.Node) -> relax.Var: import torch # type: ignore @@ -923,6 +1084,50 @@ def _linear(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + def _max_pool1d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int]] = 1, + stride: Optional[Union[int, Tuple[int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + # Expand to 3D by adding batch dim if input is 2D + x_ndim = x.struct_info.ndim + if x_ndim == 2: + x = relax.op.expand_dims(x, axis=0) + + stride = kernel_size if stride is None else stride + + result = self.block_builder.emit( + relax.op.nn.max_pool1d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + layout="NCW", + ) + ) + + # Remove added batch dim from result + if x_ndim == 2: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _max_pool1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool2d_impl( self, x: relax.Expr, @@ -932,8 +1137,14 @@ def _max_pool2d_impl( dilation: Optional[int] = 1, ceil_mode: Optional[bool] = False, ) -> relax.Var: + # Expand to 4D by adding batch dim if input is 3D + x_ndim = x.struct_info.ndim + if x_ndim == 3: + x = relax.op.expand_dims(x, axis=0) + stride = kernel_size if stride is None else stride - return self.block_builder.emit( + + result = self.block_builder.emit( relax.op.nn.max_pool2d( x, pool_size=kernel_size, @@ -945,6 +1156,11 @@ def _max_pool2d_impl( ) ) + # Remove added batch dim from result + if x_ndim == 3: + result = relax.op.squeeze(result, axis=[0]) + return result + def _max_pool2d(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] @@ -956,6 +1172,49 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool3d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int, int]] = (1, 1, 1), + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + # Expand to 5D by adding batch dim if input is 4D + x_ndim = x.struct_info.ndim + if x_ndim == 4: + x = relax.op.expand_dims(x, axis=0) + + stride = kernel_size if stride is None else stride + + result = self.block_builder.emit( + relax.op.nn.max_pool3d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + layout="NCDHW", + ) + ) + + # Remove added batch dim from result + if x_ndim == 4: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _max_pool3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _pad(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] pad = node.args[1] @@ -1010,8 +1269,7 @@ def _unbind(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) assert isinstance(dim, int), "Expected 2nd argument of unbind as int" selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + ret, split = [], self.block_builder.emit(relax.op.split(x, selections, dim)) for i in range(selections): ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) return self.block_builder.emit(relax.Tuple(ret)) @@ -1260,6 +1518,19 @@ def _meshgrid(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.meshgrid(new_inputs, indexing=indexing)) + def _slice_scatter(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input_tensor = args[0] + src = args[1] + dim = args[2] if len(args) > 2 else node.kwargs.get("dim", 0) + start = args[3] if len(args) > 3 else node.kwargs.get("start", 0) + end = args[4] if len(args) > 4 else node.kwargs.get("end", self.shape_of(input_tensor)[dim]) + step = args[5] if len(args) > 5 else node.kwargs.get("step", 1) + + return self.block_builder.emit( + relax.op.slice_scatter(input_tensor, src, start, end, step, axis=dim) + ) + def _permute(self, node: fx.Node) -> relax.Var: import torch # type: ignore diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 01ed158b4229..6b4396621934 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -66,7 +66,7 @@ def _reciprocal(self, node: fx.Node) -> relax.Var: ########## Neural Network ########## - def _batch_norm(self, node: fx.Node, training) -> relax.Var: + def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var: import numpy as np x = self.env[node.args[0]] @@ -113,6 +113,14 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: training = False return self._batch_norm(node, training) + def _cross_entropy_default(self, node: fx.Node) -> relax.Expr: + preds = self.env[node.args[0]] + targets = self.env[node.args[1]] + weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None + reduction = node.kwargs.get("reduction", "mean") + ignore_index = node.kwargs.get("ignore_index", -100) + return self._cross_entropy_loss(preds, targets, weight, reduction, ignore_index) + def _group_norm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_groups = node.args[1] @@ -200,6 +208,29 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var: align_corners=align_corners, ) + def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None) + ) + if size is not None: + scale_factor = None + else: + scale_arg = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1) + if isinstance(scale_arg, (list, tuple)): + scale_factor = scale_arg[0] + else: + scale_factor = scale_arg + + return self._upsample_impl( + x, + size=size, + scale_factor=scale_factor, + method="cubic", + align_corners=align_corners, + ) + ########## Manipulation ########## def _narrow(self, node: fx.Node) -> relax.Var: @@ -339,6 +370,8 @@ def create_convert_map( "reciprocal.default": self._reciprocal, "relu.default": self._unary_op(relax.op.nn.relu), "relu_.default": self._unary_op(relax.op.nn.relu), + "relu6.default": self._unary_op(relax.op.nn.relu6), + "relu6_.default": self._unary_op(relax.op.nn.relu6), "round.default": self._round, "rsqrt.default": self._unary_op(relax.op.rsqrt), "rsub.Tensor": self._rsub, @@ -396,6 +429,9 @@ def create_convert_map( "mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul), "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne), "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne), + "outer.default": lambda node: self.block_builder.emit( + relax.op.outer(self.env[node.args[0]], self.env[node.args[1]]) + ), "pow.Scalar": self._binary_op(relax.op.power, operator.pow), "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow), "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), @@ -412,9 +448,13 @@ def create_convert_map( "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "batch_norm.default": self._batch_norm_legit_no_training, + "adaptive_avg_pool1d.default": self._adaptive_avg_pool1d, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "adaptive_avg_pool3d.default": self._adaptive_avg_pool3d, "addmm.default": self._addmm, + "avg_pool1d.default": self._avg_pool1d, "avg_pool2d.default": self._avg_pool2d, + "avg_pool3d.default": self._avg_pool3d, "baddbmm.default": self._baddbmm, "bmm.default": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul @@ -424,6 +464,7 @@ def create_convert_map( "conv1d.default": self._conv1d, "conv2d.default": self._conv2d, "conv3d.default": self._conv3d, + "cross_entropy_loss.default": self._cross_entropy_default, "einsum.default": self._einsum, "embedding.default": lambda node: self._embedding_impl( self.env[node.args[1]], self.env[node.args[0]] @@ -432,11 +473,14 @@ def create_convert_map( "instance_norm.default": self._instance_norm, "layer_norm.default": self._layer_norm, "linear.default": self._linear, + "max_pool1d.default": self._max_pool1d, "max_pool2d.default": self._max_pool2d, + "max_pool3d.default": self._max_pool3d, "scaled_dot_product_attention.default": self._scaled_dot_product_attention, "unbind.int": self._unbind, "upsample_bilinear2d.vec": self._upsample_bilinear2d, "upsample_nearest2d.vec": self._upsample_nearest2d, + "upsample_bicubic2d.vec": self._upsample_bicubic2d, # statistical "mean.dim": self._mean, "prod.default": self._prod, @@ -473,6 +517,7 @@ def create_convert_map( "roll.default": self._roll, "select.int": self._select, "slice.Tensor": self._slice, + "slice_scatter.default": self._slice_scatter, "sort.default": self._sort, "split.Tensor": self._split, "split_with_sizes.default": self._split, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 7310ce0af53e..b15e406339c7 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -182,13 +182,62 @@ def call_binary_op(op, lhs, rhs): ########## Neural Network ########## + def _adaptive_avg_pool1d_module(self, node: fx.Node) -> relax.Var: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + # Expand to 3D by adding batch dim if input is 2D + x_ndim = x.struct_info.ndim + if x_ndim == 2: + x = relax.op.expand_dims(x, axis=0) + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool1d(x, output_size, layout="NCW") # (N, C, L) + ) + # Remove added batch dim from result + if x_ndim == 2: + result = relax.op.squeeze(result, axis=[0]) + return result + def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] x = self.env[node.args[0]] output_size = module.output_size - return self.block_builder.emit( + # Expand to 4D by adding batch dim if input is 3D + x_ndim = x.struct_info.ndim + if x_ndim == 3: + x = relax.op.expand_dims(x, axis=0) + result = self.block_builder.emit( relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + # Remove added batch dim from result + if x_ndim == 3: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _adaptive_avg_pool3d_module(self, node: fx.Node) -> relax.Var: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + # Expand to 5D by adding batch dim if input is 4D + x_ndim = x.struct_info.ndim + if x_ndim == 4: + x = relax.op.expand_dims(x, axis=0) + result = self.block_builder.emit( + relax.op.nn.adaptive_avg_pool3d(x, output_size, layout="NCDHW") # (N, C, D, H, W) + ) + # Remove added batch dim from result + if x_ndim == 4: + result = relax.op.squeeze(result, axis=[0]) + return result + + def _avg_pool1d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + ceil_mode = module.ceil_mode + return self._avg_pool1d_impl(x, kernel_size, stride, padding, ceil_mode) def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -199,6 +248,15 @@ def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: ceil_mode = module.ceil_mode return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + def _avg_pool3d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + ceil_mode = module.ceil_mode + return self._avg_pool3d_impl(x, kernel_size, stride, padding, ceil_mode) + def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -266,6 +324,7 @@ def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, + output_padding=module.output_padding, ) def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: @@ -282,6 +341,7 @@ def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, + output_padding=module.output_padding, ) def _conv1d_module(self, node: fx.Node) -> relax.Var: @@ -338,12 +398,7 @@ def _cross_entropy(self, node: fx.Node) -> relax.Expr: weights = self.env.get(node.kwargs["weight"], None) reduction = node.kwargs["reduction"] ignore_index = node.kwargs["ignore_index"] - - return self.block_builder.emit( - relax.op.nn.nll_loss( - relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index - ) - ) + return self._cross_entropy_loss(preds, targets, weights, reduction, ignore_index) def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: preds = self.env[node.args[0]] @@ -360,10 +415,12 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: reduction = module.reduction ignore_index = module.ignore_index - return self.block_builder.emit( - relax.op.nn.nll_loss( - relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index - ) + return self._cross_entropy_loss( + preds, + targets, + weights, + reduction, + ignore_index, ) def _embedding_module(self, node: fx.Node) -> relax.Var: @@ -479,6 +536,17 @@ def _linear_module(self, node: fx.Node) -> relax.Var: bias = self.params.get(module.bias, None) return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + def _max_pool1d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + + return self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -490,6 +558,17 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _max_pool3d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + + return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -669,9 +748,7 @@ def create_convert_map( nn.LogSoftmax: self._log_softmax_module, nn.PReLU: self._prelu_module, nn.ReLU: self._unary_op(relax.op.nn.relu), - nn.ReLU6: lambda node: self.block_builder.emit( - relax.op.clip(self.env[node.args[0]], 0, 6) - ), + nn.ReLU6: self._unary_op(relax.op.nn.relu6), nn.Sigmoid: self._unary_op(relax.op.sigmoid), nn.SELU: self._unary_op(relax.op.nn.selu), nn.SiLU: self._unary_op(relax.op.nn.silu), @@ -679,8 +756,12 @@ def create_convert_map( nn.Softplus: self._softplus_module, nn.Tanh: self._unary_op(relax.op.tanh), # neural network + nn.AdaptiveAvgPool1d: self._adaptive_avg_pool1d_module, nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, + nn.AdaptiveAvgPool3d: self._adaptive_avg_pool3d_module, + nn.AvgPool1d: self._avg_pool1d_module, nn.AvgPool2d: self._avg_pool2d_module, + nn.AvgPool3d: self._avg_pool3d_module, nn.BatchNorm2d: self._batch_norm_2d_module, nn.InstanceNorm1d: self._instance_norm, nn.InstanceNorm2d: self._instance_norm, @@ -694,7 +775,9 @@ def create_convert_map( nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, nn.Linear: self._linear_module, + nn.MaxPool1d: self._max_pool1d_module, nn.MaxPool2d: self._max_pool2d_module, + nn.MaxPool3d: self._max_pool3d_module, nn.modules.sparse.Embedding: self._embedding_module, nn.PixelShuffle: self._pixel_shuffle_module, # tensor manipulation @@ -740,6 +823,7 @@ def create_convert_map( "prelu": self._prelu, "reciprocal": self._reciprocal, "relu": self._unary_op(relax.op.nn.relu), + "relu6": self._unary_op(relax.op.nn.relu6), "round": self._round, "rsqrt": self._unary_op(relax.op.rsqrt), "selu": self._unary_op(relax.op.nn.selu), @@ -782,6 +866,9 @@ def create_convert_map( "mod": self._binary_op(relax.op.floor_mod, operator.mod), "mul": self._binary_op(relax.op.multiply, operator.mul), "ne": self._binary_op(relax.op.not_equal, operator.ne), + "outer": lambda node: self.block_builder.emit( + relax.op.outer(self.env[node.args[0]], self.env[node.args[1]]) + ), "pow": self._binary_op(relax.op.power, operator.pow), "or_": self._binary_op(relax.op.bitwise_or, operator.or_), "rshift": self._binary_op(relax.op.right_shift, operator.rshift), @@ -790,9 +877,13 @@ def create_convert_map( "truediv": self._binary_op(relax.op.divide, operator.truediv), "xor": self._binary_op(relax.op.bitwise_xor, operator.xor), # neural network + "adaptive_avg_pool1d": self._adaptive_avg_pool1d, "adaptive_avg_pool2d": self._adaptive_avg_pool2d, + "adaptive_avg_pool3d": self._adaptive_avg_pool3d, "addmm": self._addmm, + "avg_pool1d": self._avg_pool1d, "avg_pool2d": self._avg_pool2d, + "avg_pool3d": self._avg_pool3d, "baddbmm": self._baddbmm, "bmm": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul @@ -807,7 +898,9 @@ def create_convert_map( "interpolate": self._interpolate, "layer_norm": self._layer_norm, "linear": self._linear, + "max_pool1d": self._max_pool1d, "max_pool2d": self._max_pool2d, + "max_pool3d": self._max_pool3d, "scaled_dot_product_attention": self._scaled_dot_product_attention, "stochastic_depth": lambda node: self.env[node.args[0]], "unbind": self._unbind, @@ -849,6 +942,7 @@ def create_convert_map( "scatter": self._scatter, "select": self._select, "size": self._size, + "slice_scatter": self._slice_scatter, "sort": self._sort, "split": self._split, "squeeze": self._squeeze, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index bfc0a997dfc8..c4a5d2fd2329 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -84,7 +84,7 @@ ) from .datatype import astype, wrap_param from .index import dynamic_strided_slice, strided_slice, take -from .linear_algebra import einsum, linear, matmul +from .linear_algebra import einsum, linear, matmul, outer from .manipulate import ( broadcast_to, collapse_sum_like, @@ -105,6 +105,7 @@ reshape, scatter_elements, scatter_nd, + slice_scatter, split, squeeze, stack, diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py index 8dc6a1b4fbb0..1d16a024d1d4 100644 --- a/python/tvm/relax/op/_ffi_api.py +++ b/python/tvm/relax/op/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op", __name__) +tvm.ffi._init_api("relax.op", __name__) diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index 6878f9733163..41eaa5de5008 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -21,7 +21,7 @@ from typing import List from tvm import relax -from tvm._ffi.base import TVMError +from tvm.base import TVMError from tvm.arith import Analyzer from tvm.relax.struct_info import ShapeStructInfo @@ -1090,7 +1090,7 @@ def log_softmax_grad( Returns `[y_grad - sum(y_grad, axis, keepdims=True) * exp(y)]` """ y_exp = exp(orig_var) - return [(output_grad - sum(output_grad, orig_call.attrs.axis, True) * y_exp)] + return [output_grad - sum(output_grad, orig_call.attrs.axis, True) * y_exp] @register_gradient("relax.nn.cross_entropy_with_logits") diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py index 42fe8cb65234..a7f48af57697 100644 --- a/python/tvm/relax/op/builtin/_ffi_api.py +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.builtin""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.builtin", __name__) +tvm.ffi._init_api("relax.op.builtin", __name__) diff --git a/python/tvm/relax/op/ccl/_ffi_api.py b/python/tvm/relax/op/ccl/_ffi_api.py index cdf468781061..bf605aae6ab0 100644 --- a/python/tvm/relax/op/ccl/_ffi_api.py +++ b/python/tvm/relax/op/ccl/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Operators serving for Collective Communications Library (CCL) operators""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.ccl", __name__) +tvm.ffi._init_api("relax.op.ccl", __name__) diff --git a/python/tvm/relax/op/distributed/_ffi_api.py b/python/tvm/relax/op/distributed/_ffi_api.py index 9b1b4d68d6da..394cb8c262b2 100644 --- a/python/tvm/relax/op/distributed/_ffi_api.py +++ b/python/tvm/relax/op/distributed/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.op.distributed""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.dist", __name__) +tvm.ffi._init_api("relax.op.dist", __name__) diff --git a/python/tvm/relax/op/grad/_ffi_api.py b/python/tvm/relax/op/grad/_ffi_api.py index 9b819dd4df29..415d590f01f0 100644 --- a/python/tvm/relax/op/grad/_ffi_api.py +++ b/python/tvm/relax/op/grad/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.op.grad""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.grad", __name__) +tvm.ffi._init_api("relax.op.grad", __name__) diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py index e666203ae7ff..8c813231f9a0 100644 --- a/python/tvm/relax/op/image/_ffi_api.py +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.image", __name__) +tvm.ffi._init_api("relax.op.image", __name__) diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py index e314e9b49af5..6bec22161dbc 100644 --- a/python/tvm/relax/op/image/image.py +++ b/python/tvm/relax/op/image/image.py @@ -35,7 +35,7 @@ def resize2d( method: str = "linear", coordinate_transformation_mode: str = "half_pixel", rounding_method: str = "round", - cubic_alpha: float = -0.5, + cubic_alpha: float = -0.75, cubic_exclude: int = 0, extrapolation_value: float = 0.0, out_dtype: Optional[Union[str, DataType]] = None, diff --git a/python/tvm/relax/op/linear_algebra.py b/python/tvm/relax/op/linear_algebra.py index efb5085c7882..9b091195763e 100644 --- a/python/tvm/relax/op/linear_algebra.py +++ b/python/tvm/relax/op/linear_algebra.py @@ -110,3 +110,30 @@ def einsum(operands, subscripts): operands = RxTuple(operands) return _ffi_api.einsum(operands, subscripts) # type: ignore + + +def outer(x1: Expr, x2: Expr) -> Expr: + """ + Computes the outer product of two input expressions. + + Parameters + ---------- + x1 : relax.Expr + The first input expression. + + x2 : relax.Expr + The second input expression. + + Notes + ----- + This operation computes the outer product between two expressions, + resulting in a tensor where each element is the product of elements + from `x1` and `x2`. It is commonly used in tensor and matrix operations + to expand lower-dimensional inputs into higher-dimensional representations. + + Returns + ------- + result : relax.Expr + The resulting expression representing the outer product. + """ + return _ffi_api.outer(x1, x2) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index b52aced59ae9..c71b19494a41 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -786,6 +786,44 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "updat return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore +def slice_scatter(input_tensor: Expr, src: Expr, start, end, step, axis=0): + """Embeds the values of the src tensor into input at the given dimension. + + Parameters + ---------- + input_tensor: relax.Expr + The input tensor to be updated. + + src: relax.Expr + The tensor to embed into input. + + axis: int + The dimension to insert the slice into. + + start: + The start index of where to insert the slice. + + end: + The end index of where to insert the slice. + + step: + The how many elements to skip in. + + Returns + ------- + result : relax.Expr + The computed result tensor with the same shape as `data`. + + """ + if not isinstance(start, PrimValue): + start = PrimValue(start) + if not isinstance(end, PrimValue): + end = PrimValue(end) + if not isinstance(step, PrimValue): + step = PrimValue(step) + return _ffi_api.slice_scatter(input_tensor, src, axis, start, end, step) + + def one_hot( indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1 ) -> Expr: diff --git a/python/tvm/relax/op/memory/_ffi_api.py b/python/tvm/relax/op/memory/_ffi_api.py index 475de481b22e..fb829b7db953 100644 --- a/python/tvm/relax/op/memory/_ffi_api.py +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.memory""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.memory", __name__) +tvm.ffi._init_api("relax.op.memory", __name__) diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 0b90d0cca831..d12a3ee3636f 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -47,6 +47,7 @@ pixel_shuffle, prelu, relu, + relu6, rms_norm, selu, silu, diff --git a/python/tvm/relax/op/nn/_ffi_api.py b/python/tvm/relax/op/nn/_ffi_api.py index 1785345ac1b1..b5f735127ec2 100644 --- a/python/tvm/relax/op/nn/_ffi_api.py +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.nn", __name__) +tvm.ffi._init_api("relax.op.nn", __name__) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index bf4b02c963ef..b193f93c0f85 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -17,7 +17,7 @@ """Relax Neural Network (NN) operators""" from typing import List, Optional, Tuple, Union -from tvm import DataType +from tvm import DataType, relax from tvm.tir import FloatImm from ...expr import Expr @@ -840,7 +840,7 @@ def avg_pool1d( padding: Union[int, Tuple[int, ...]] = (0, 0), dilation: Union[int, Tuple[int, int]] = (1,), ceil_mode: bool = False, - count_include_pad: bool = False, + count_include_pad: bool = True, layout: str = "NCW", out_layout: Optional[str] = None, ) -> Expr: @@ -1008,7 +1008,7 @@ def avg_pool3d( padding: Union[int, Tuple[int, ...]] = (0, 0, 0), dilation: Union[int, Tuple[int, int]] = (1, 1, 1), ceil_mode: bool = False, - count_include_pad: bool = False, + count_include_pad: bool = True, layout: str = "NCDHW", out_layout: Optional[str] = None, ) -> Expr: @@ -1267,6 +1267,25 @@ def relu(data: Expr) -> Expr: return _ffi_api.relu(data) # type: ignore +def relu6(data: Expr) -> Expr: + r"""ReLU6 activation function. + + .. math:: + \text{ReLU6}(x) = \min(\max(x, 0), 6) + + Parameters + ---------- + data : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return relax.op.clip(data, 0, 6) + + def leakyrelu(data: Expr, alpha: float = 0.01) -> Expr: """Rectified linear unit. diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 24eece70a941..dd30215ef654 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -16,184 +16,184 @@ # under the License. """The attributes node used for Relax operators""" from tvm.ir import Attrs -import tvm._ffi +import tvm.ffi -@tvm._ffi.register_object("relax.attrs.CallTIRWithGradAttrs") +@tvm.ffi.register_object("relax.attrs.CallTIRWithGradAttrs") class CallTIRWithGradAttrs(Attrs): """Attributes used in call_tir_with_grad operator""" -@tvm._ffi.register_object("relax.attrs.InitAttrs") +@tvm.ffi.register_object("relax.attrs.InitAttrs") class InitAttrs(Attrs): """Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator""" -@tvm._ffi.register_object("relax.attrs.TriluAttrs") +@tvm.ffi.register_object("relax.attrs.TriluAttrs") class TriluAttrs(Attrs): """Attributes used in tril and triu operator""" -@tvm._ffi.register_object("relax.attrs.AstypeAttrs") +@tvm.ffi.register_object("relax.attrs.AstypeAttrs") class AstypeAttrs(Attrs): """Attributes used in astype operator""" -@tvm._ffi.register_object("relax.attrs.TakeAttrs") +@tvm.ffi.register_object("relax.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes used in take operator""" -@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs") +@tvm.ffi.register_object("relax.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" -@tvm._ffi.register_object("relax.attrs.MatmulAttrs") +@tvm.ffi.register_object("relax.attrs.MatmulAttrs") class MatmulAttrs(Attrs): """Attributes for matmul operator""" -@tvm._ffi.register_object("relax.attrs.Conv2DAttrs") +@tvm.ffi.register_object("relax.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" -@tvm._ffi.register_object("relax.attrs.Conv3DAttrs") +@tvm.ffi.register_object("relax.attrs.Conv3DAttrs") class Conv3DAttrs(Attrs): """Attributes for nn.conv3d""" -@tvm._ffi.register_object("relax.attrs.Conv2DTransposeAttrs") +@tvm.ffi.register_object("relax.attrs.Conv2DTransposeAttrs") class Conv2DTransposeAttrs(Attrs): """Attributes for nn.conv2d_transpose""" -@tvm._ffi.register_object("relax.attrs.Pool2DAttrs") +@tvm.ffi.register_object("relax.attrs.Pool2DAttrs") class Pool2DAttrs(Attrs): """Attributes for nn.max_pool2d""" -@tvm._ffi.register_object("relax.attrs.AdaptivePool2DAttrs") +@tvm.ffi.register_object("relax.attrs.AdaptivePool2DAttrs") class AdaptivePool2DAttrs(Attrs): """Attributes for 2d adaptive pool operator""" -@tvm._ffi.register_object("relax.attrs.SoftmaxAttrs") +@tvm.ffi.register_object("relax.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" -@tvm._ffi.register_object("relax.attrs.BatchNormAttrs") +@tvm.ffi.register_object("relax.attrs.BatchNormAttrs") class BatchNormAttrs(Attrs): """Attributes used in batch_norm operator""" -@tvm._ffi.register_object("relax.attrs.LayerNormAttrs") +@tvm.ffi.register_object("relax.attrs.LayerNormAttrs") class LayerNormAttrs(Attrs): """Attributes used in layer_norm operator""" -@tvm._ffi.register_object("relax.attrs.InstanceNormAttrs") +@tvm.ffi.register_object("relax.attrs.InstanceNormAttrs") class InstanceNormAttrs(Attrs): """Attributes used in instance_norm operator""" -@tvm._ffi.register_object("relax.attrs.DropoutAttrs") +@tvm.ffi.register_object("relax.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for dropout operator""" -@tvm._ffi.register_object("relax.attrs.StatisticalAttrs") +@tvm.ffi.register_object("relax.attrs.StatisticalAttrs") class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" -@tvm._ffi.register_object("relax.attrs.ConcatAttrs") +@tvm.ffi.register_object("relax.attrs.ConcatAttrs") class ConcatAttrs(Attrs): """Attributes for concat operator""" -@tvm._ffi.register_object("relax.attrs.ExpandDimsAttrs") +@tvm.ffi.register_object("relax.attrs.ExpandDimsAttrs") class ExpandDimsAttrs(Attrs): """Attributes for expand_dims operator""" -@tvm._ffi.register_object("relax.attrs.PermuteDimsAttrs") +@tvm.ffi.register_object("relax.attrs.PermuteDimsAttrs") class PermuteDimsAttrs(Attrs): """Attributes for permute_dims operator""" -@tvm._ffi.register_object("relax.attrs.SortAttrs") +@tvm.ffi.register_object("relax.attrs.SortAttrs") class SortAttrs(Attrs): """Attributes for sort operator""" -@tvm._ffi.register_object("relax.attrs.ArgsortAttrs") +@tvm.ffi.register_object("relax.attrs.ArgsortAttrs") class ArgsortAttrs(Attrs): """Attributes for argsort operator""" -@tvm._ffi.register_object("relax.attrs.SplitAttrs") +@tvm.ffi.register_object("relax.attrs.SplitAttrs") class SplitAttrs(Attrs): """Attributes used in split operator""" -@tvm._ffi.register_object("relax.attrs.SqueezeAttrs") +@tvm.ffi.register_object("relax.attrs.SqueezeAttrs") class SqueezeAttrs(Attrs): """Attributes for squeeze operator""" -@tvm._ffi.register_object("relax.attrs.StackAttrs") +@tvm.ffi.register_object("relax.attrs.StackAttrs") class StackAttrs(Attrs): """Attributes for concat operator""" -@tvm._ffi.register_object("relax.attrs.IndexPutAttrs") +@tvm.ffi.register_object("relax.attrs.IndexPutAttrs") class IndexPutAttrs(Attrs): """Attributes for index_put operator""" -@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs") +@tvm.ffi.register_object("relax.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes used in layout_transform operator""" -@tvm._ffi.register_object("relax.attrs.Resize2DAttrs") +@tvm.ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" -@tvm._ffi.register_object("relax.attrs.ArgmaxArgminAttrs") +@tvm.ffi.register_object("relax.attrs.ArgmaxArgminAttrs") class ArgmaxArgminAttrs(Attrs): """Attributes for argmax/argmin operator""" -@tvm._ffi.register_object("relax.attrs.RepeatAttrs") +@tvm.ffi.register_object("relax.attrs.RepeatAttrs") class RepeatAttrs(Attrs): """Attributes for repeat operator""" -@tvm._ffi.register_object("relax.attrs.TileAttrs") +@tvm.ffi.register_object("relax.attrs.TileAttrs") class TileAttrs(Attrs): """Attributes for tile operator""" -@tvm._ffi.register_object("relax.attrs.ScanopAttrs") +@tvm.ffi.register_object("relax.attrs.ScanopAttrs") class ScanopAttrs(Attrs): """Attributes for scan operators""" -@tvm._ffi.register_object("relax.attrs.TopKAttrs") +@tvm.ffi.register_object("relax.attrs.TopKAttrs") class TopKAttrs(Attrs): """Attributes for topk operators""" -@tvm._ffi.register_object("relax.attrs.EinsumAttrs") +@tvm.ffi.register_object("relax.attrs.EinsumAttrs") class EinsumAttrs(Attrs): """Attributes for einsum operator""" -@tvm._ffi.register_object("relax.attrs.FlipAttrs") +@tvm.ffi.register_object("relax.attrs.FlipAttrs") class FlipAttrs(Attrs): """Attributes for flip operator""" diff --git a/python/tvm/relax/op/vm/_ffi_api.py b/python/tvm/relax/op/vm/_ffi_api.py index 786b73c76c64..f3b6cea13b67 100644 --- a/python/tvm/relax/op/vm/_ffi_api.py +++ b/python/tvm/relax/op/vm/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.vm""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.op.vm", __name__) +tvm.ffi._init_api("relax.op.vm", __name__) diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py index de1b1ac3bfc3..c143f098328c 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/struct_info.py @@ -18,7 +18,7 @@ """The struct info nodes of the Relax language.""" from typing import List, Optional, Union -import tvm._ffi +import tvm.ffi import tvm from tvm.ir import Span, EnvFunc, Array, VDevice @@ -29,7 +29,7 @@ from . import _ffi_api, ty, expr -@tvm._ffi.register_object("relax.ObjectStructInfo") +@tvm.ffi.register_object("relax.ObjectStructInfo") class ObjectStructInfo(StructInfo): """StructInfo of an Object.""" @@ -37,7 +37,7 @@ def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore -@tvm._ffi.register_object("relax.PrimStructInfo") +@tvm.ffi.register_object("relax.PrimStructInfo") class PrimStructInfo(StructInfo): """StructInfo of a primitive POD value. @@ -107,7 +107,7 @@ def __init__( ) # type: ignore -@tvm._ffi.register_object("relax.ShapeStructInfo") +@tvm.ffi.register_object("relax.ShapeStructInfo") class ShapeStructInfo(StructInfo): """StructInfo of a shape value. @@ -136,7 +136,7 @@ def __init__( ) -@tvm._ffi.register_object("relax.TensorStructInfo") +@tvm.ffi.register_object("relax.TensorStructInfo") class TensorStructInfo(StructInfo): """StructInfo of a Tensor value. @@ -180,7 +180,7 @@ def __init__( ) -@tvm._ffi.register_object("relax.TupleStructInfo") +@tvm.ffi.register_object("relax.TupleStructInfo") class TupleStructInfo(StructInfo): """StructInfo of a Tuple value. @@ -197,7 +197,7 @@ def __init__(self, fields: List[StructInfo], span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.TupleStructInfo, fields, span) # type: ignore -@tvm._ffi.register_object("relax.FuncStructInfo") +@tvm.ffi.register_object("relax.FuncStructInfo") class FuncStructInfo(StructInfo): """StructInfo of a function value. diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 02c79bd4fa6e..198b07e51ea7 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -70,7 +70,7 @@ def dataflow_alias_analysis( return res_alias_sets, res_tuple_map # type: ignore -@tvm._ffi.register_object("relax.transform.InplaceOpportunity") +@tvm.ffi.register_object("relax.transform.InplaceOpportunity") class InplaceOpportunity(Object): """ Represents an opportunity to make a binding in-place. Exposed only for testing; diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py index 70cb83fc0e6f..9b7dbcdee748 100644 --- a/python/tvm/relax/training/_ffi_api.py +++ b/python/tvm/relax/training/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.training""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.training", __name__) +tvm.ffi._init_api("relax.training", __name__) diff --git a/python/tvm/relax/training/utils.py b/python/tvm/relax/training/utils.py index 4d1a32177227..dd433435e278 100644 --- a/python/tvm/relax/training/utils.py +++ b/python/tvm/relax/training/utils.py @@ -21,7 +21,7 @@ import tvm from tvm import relax -from tvm._ffi.registry import register_func +from tvm.ffi.registry import register_func from tvm.relax.block_builder import BlockBuilder from ..expr import Function, Var, Call diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py index 667aa62c2c95..3c4387a3cbb8 100644 --- a/python/tvm/relax/transform/_ffi_api.py +++ b/python/tvm/relax/transform/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.transform""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.transform", __name__) +tvm.ffi._init_api("relax.transform", __name__) diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py index 318c9521f31a..154afa9dffca 100644 --- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -115,3 +115,22 @@ def _einsum(bb: BlockBuilder, call: Call) -> Expr: t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] ) return bb.call_te(topi.einsum, call.attrs.subscripts, *fields) + + +@register_legalize("relax.outer") +def _outer(bb: BlockBuilder, call: Call) -> Expr: + def te_outer(a: te.Tensor, b: te.Tensor) -> te.Tensor: + a_shape = list(a.shape) + b_shape = list(b.shape) + assert len(a_shape) == 1 and len(b_shape) == 1, "outer requires 1D tensors" + + n = a_shape[0] + m = b_shape[0] + + def compute_fn(i, j): + return a[i] * b[j] + + return te.compute((n, m), compute_fn, name="outer") + + lhs, rhs = call.args + return bb.call_te(te_outer, lhs, rhs, primfunc_name_hint="outer") diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 835be4bd4ef8..58abe434a23a 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -263,6 +263,20 @@ def scatter_nd(data, indices, updates, reduction): ) +@register_legalize("relax.slice_scatter") +def _slice_scatter(bb: BlockBuilder, call: Call) -> Expr: + + return bb.call_te( + topi.slice_scatter, + call.args[0], + call.args[1], + call.args[2], + call.args[3], + call.args[4], + call.attrs.axis, + ) + + @register_legalize("relax.one_hot") def _one_hot(bb: BlockBuilder, call: Call) -> Expr: indices, on_value, off_value = call.args diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 4ee53b032918..8e74f0897720 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -36,14 +36,14 @@ from ..expr import Var -@tvm._ffi.register_object("relax.FunctionPass") +@tvm.ffi.register_object("relax.FunctionPass") class FunctionPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.Function in a module. A function pass class should be created through `function_pass`. """ -@tvm._ffi.register_object("relax.DataflowBlockPass") +@tvm.ffi.register_object("relax.DataflowBlockPass") class DataflowBlockPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.DataflowBlock in a module.""" @@ -820,7 +820,7 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() # type: ignore -@tvm._ffi.register_object("relax.transform.PatternCheckContext") +@tvm.ffi.register_object("relax.transform.PatternCheckContext") class PatternCheckContext(Object): """ The input of check function `FusionPattern.check`. @@ -854,7 +854,7 @@ class PatternCheckContext(Object): value_to_bound_var: Mapping[Expr, Var] -@tvm._ffi.register_object("relax.transform.FusionPattern") +@tvm.ffi.register_object("relax.transform.FusionPattern") class FusionPattern(Object): """ The pattern used by `FuseOpsByPattern`. It's mainly DFPattern but with other diff --git a/python/tvm/relax/transform/tuning_api/_ffi_api.py b/python/tvm/relax/transform/tuning_api/_ffi_api.py index f31522d02595..54caece700ef 100644 --- a/python/tvm/relax/transform/tuning_api/_ffi_api.py +++ b/python/tvm/relax/transform/tuning_api/_ffi_api.py @@ -14,6 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for relax.tuning_api""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("relax.tuning_api", __name__) +tvm.ffi._init_api("relax.tuning_api", __name__) diff --git a/python/tvm/relax/transform/tuning_api/database.py b/python/tvm/relax/transform/tuning_api/database.py index 9477e142bad4..cbc103423b0f 100644 --- a/python/tvm/relax/transform/tuning_api/database.py +++ b/python/tvm/relax/transform/tuning_api/database.py @@ -24,7 +24,7 @@ from tvm.meta_schedule.database import Workload from tvm.tir.schedule.trace import JSON_TYPE from tvm.target import Target -from tvm._ffi import register_object +from tvm.ffi import register_object from .primitives import Trace from . import _ffi_api diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py index bfb73e11a06b..cbd71a06e608 100644 --- a/python/tvm/relax/transform/tuning_api/default_functions.py +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -33,7 +33,7 @@ LocalRunner, RunnerInput, ) -from tvm._ffi.registry import register_func +from tvm.ffi.registry import register_func from .primitives import Knob, Trace logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py index 67b81ba7e99c..fdc3769f3e5a 100644 --- a/python/tvm/relax/transform/tuning_api/primitives.py +++ b/python/tvm/relax/transform/tuning_api/primitives.py @@ -23,7 +23,7 @@ from tvm.ir.module import IRModule from tvm.relax import Expr from tvm.tir.schedule.trace import JSON_TYPE, _json_from_tvm -from tvm._ffi import register_object +from tvm.ffi import register_object from . import _ffi_api logger = logging.getLogger("TuningAPI") # pylint: disable=invalid-name diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index b0afb069435a..426695c9f1fe 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -16,13 +16,13 @@ # under the License. # pylint: disable=invalid-name, unused-import """The type nodes of the Relax language.""" -import tvm._ffi +import tvm.ffi from tvm.ir import Type, TupleType, FuncType, Span from . import _ffi_api -@tvm._ffi.register_object("relax.ShapeType") +@tvm.ffi.register_object("relax.ShapeType") class ShapeType(Type): """The type of shape in Relax. @@ -37,7 +37,7 @@ def __init__(self, ndim: int = -1, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore -@tvm._ffi.register_object("relax.ObjectType") +@tvm.ffi.register_object("relax.ObjectType") class ObjectType(Type): """A type that corresponds to tvm::runtime::Object, is base of all possible object values in TVM.""" @@ -46,7 +46,7 @@ def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore -@tvm._ffi.register_object("relax.DynTensorType") +@tvm.ffi.register_object("relax.DynTensorType") class TensorType(Type): """A dynamic tensor type in Relax. @@ -65,7 +65,7 @@ def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.TensorType, ndim, dtype, span) # type: ignore -@tvm._ffi.register_object("relax.PackedFuncType") +@tvm.ffi.register_object("relax.PackedFuncType") class PackedFuncType(Type): """The type of ExternFunc in Relax.""" diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py index 1a7cc739b5c1..3b77e7a552e3 100644 --- a/python/tvm/rpc/_ffi_api.py +++ b/python/tvm/rpc/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.rpc""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("rpc", __name__) +tvm.ffi._init_api("rpc", __name__) diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index 8be74fe88c96..235934212c0e 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -25,7 +25,7 @@ import random import logging -from .._ffi.base import py_str +from ..base import py_str # Magic header for RPC data plane RPC_MAGIC = 0xFF271 diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index f9e677e49e98..ea78b0d7d418 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -22,8 +22,8 @@ import struct import time -import tvm._ffi -from tvm._ffi.base import TVMError +import tvm.ffi +from tvm.base import TVMError from tvm.contrib import utils from tvm.runtime import ndarray as nd from tvm.runtime import Device @@ -263,7 +263,7 @@ def __init__(self): RPCSession.__init__(self, _ffi_api.LocalSession()) -@tvm._ffi.register_func("rpc.PopenSession") +@tvm.ffi.register_func("rpc.PopenSession") def _popen_session(binary): temp = utils.tempdir() diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py index 842ba7e49814..2e46965a2050 100644 --- a/python/tvm/rpc/minrpc.py +++ b/python/tvm/rpc/minrpc.py @@ -16,7 +16,7 @@ # under the License. """Utils to path.""" import os -from tvm._ffi import libinfo +from tvm import libinfo from tvm.contrib import cc diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index 7997274d3f5a..5987d3f101ba 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -47,7 +47,7 @@ from . import base from .base import TrackerCode from .server import _server_env -from .._ffi.base import py_str +from ..base import py_str class ForwardHandler(object): diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index d334fe0cf7ba..eb345260e300 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -36,10 +36,10 @@ import time import errno import sys -import tvm._ffi +import tvm.ffi -from tvm._ffi.base import py_str -from tvm._ffi.libinfo import find_lib_path +from tvm.base import py_str +from tvm.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import utils from tvm.contrib.popen_pool import PopenWorker @@ -70,11 +70,11 @@ def _server_env(load_library, work_path=None): temp = utils.tempdir() # pylint: disable=unused-variable - @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True) + @tvm.ffi.register_func("tvm.rpc.server.workpath", override=True) def get_workpath(path): return temp.relpath(path) - @tvm._ffi.register_func("tvm.rpc.server.load_module", override=True) + @tvm.ffi.register_func("tvm.rpc.server.load_module", override=True) def load_module(file_name): """Load module from remote side.""" path = temp.relpath(file_name) @@ -82,7 +82,7 @@ def load_module(file_name): logger.info("load_module %s", path) return m - @tvm._ffi.register_func("tvm.rpc.server.download_linked_module", override=True) + @tvm.ffi.register_func("tvm.rpc.server.download_linked_module", override=True) def download_linked_module(file_name): """Load module from remote side.""" # pylint: disable=import-outside-toplevel diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py index cd32c4a8221c..4ceb338404f7 100644 --- a/python/tvm/rpc/tracker.py +++ b/python/tvm/rpc/tracker.py @@ -60,7 +60,7 @@ f"RPCTracker module requires tornado package {error_msg}. Try 'pip install tornado'." ) -from .._ffi.base import py_str +from ..base import py_str from . import base from .base import RPC_TRACKER_MAGIC, TrackerCode diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py index a07193ea9852..71f96983ee18 100644 --- a/python/tvm/runtime/_ffi_api.py +++ b/python/tvm/runtime/_ffi_api.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.runtime""" -import tvm._ffi +import tvm.ffi -# Exports functions registered via TVM_REGISTER_GLOBAL with the "runtime" prefix. -# e.g. TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") -tvm._ffi._init_api("runtime", __name__) +# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "runtime" prefix. +# e.g. TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") +tvm.ffi._init_api("runtime", __name__) diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 395496d16be7..493dfceab59e 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -17,14 +17,14 @@ # pylint: disable=invalid-name, unused-argument """FFI for tvm.node""" -import tvm._ffi +import tvm.ffi import tvm.ffi.core # The implementations below are default ones when the corresponding # functions are not available in the runtime only mode. # They will be overriden via _init_api to the ones registered -# via TVM_REGISTER_GLOBAL in the compiler mode. +# via TVM_FFI_REGISTER_GLOBAL in the compiler mode. def AsRepr(obj): return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")" @@ -45,6 +45,6 @@ def LoadJSON(json_str): raise RuntimeError("Do not support object serialization in runtime only mode") -# Exports functions registered via TVM_REGISTER_GLOBAL with the "node" prefix. -# e.g. TVM_REGISTER_GLOBAL("node.AsRepr") -tvm._ffi._init_api("node", __name__) +# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the "node" prefix. +# e.g. TVM_FFI_REGISTER_GLOBAL("node.AsRepr") +tvm.ffi._init_api("node", __name__) diff --git a/python/tvm/runtime/disco/_ffi_api.py b/python/tvm/runtime/disco/_ffi_api.py index 340be86708db..79e1a52ad44e 100644 --- a/python/tvm/runtime/disco/_ffi_api.py +++ b/python/tvm/runtime/disco/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs from C++""" -from ..._ffi import _init_api +from ...ffi import _init_api _init_api("runtime.disco", __name__) diff --git a/python/tvm/runtime/disco/process_pool.py b/python/tvm/runtime/disco/process_pool.py index 95969e038e0f..8f05f28e9158 100644 --- a/python/tvm/runtime/disco/process_pool.py +++ b/python/tvm/runtime/disco/process_pool.py @@ -20,7 +20,7 @@ import subprocess import sys -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.runtime import ShapeTuple diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 3ebb6dbdb611..c551eac428b7 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -25,7 +25,7 @@ import numpy as np -from ..._ffi import get_global_func, register_func, register_object +from ...ffi import get_global_func, register_func, register_object from ..device import Device from ..container import ShapeTuple from ..ndarray import NDArray @@ -271,7 +271,7 @@ def copy_to_worker_0(self, host_array: NDArray, remote_array: Optional[DRef] = N output_array: DRef The DRef containing the copied data on worker0, and - NullOpt on all other workers. If `remote_array` was + std::nullopt on all other workers. If `remote_array` was provided, this return value is the same as `remote_array`. Otherwise, it is the newly allocated space. diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index bb1fbb5fe3c2..3dd4de5da0a8 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -24,8 +24,8 @@ import numpy as np import tvm.ffi -from tvm._ffi.base import _RUNTIME_ONLY -from tvm._ffi.libinfo import find_include_path +from tvm.base import _RUNTIME_ONLY +from tvm.libinfo import find_include_path from . import _ffi_api diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 7c298050929f..9a026707cb48 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -70,7 +70,7 @@ def from_dlpack(ext_tensor): ) -@tvm._ffi.register_object("object.NDArray") +@tvm.ffi.register_object("object.NDArray") class NDArray(tvm.ffi.core.NDArray): """Lightweight NDArray class of TVM runtime. diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py index ff223b75998c..45e4925a3e28 100644 --- a/python/tvm/runtime/object_path.py +++ b/python/tvm/runtime/object_path.py @@ -22,7 +22,7 @@ from typing import Optional -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from . import _ffi_node_api @@ -40,7 +40,7 @@ ) -@tvm._ffi.register_object("ObjectPath") +@tvm.ffi.register_object("ObjectPath") class ObjectPath(Object): """ Path to an object from some root object. @@ -94,42 +94,42 @@ def missing_map_entry(self) -> "ObjectPath": __hash__ = Object.__hash__ -@tvm._ffi.register_object("RootPath") +@tvm.ffi.register_object("RootPath") class RootPath(ObjectPath): pass -@tvm._ffi.register_object("AttributeAccessPath") +@tvm.ffi.register_object("AttributeAccessPath") class AttributeAccessPath(ObjectPath): pass -@tvm._ffi.register_object("UnknownAttributeAccessPath") +@tvm.ffi.register_object("UnknownAttributeAccessPath") class UnknownAttributeAccessPath(ObjectPath): pass -@tvm._ffi.register_object("ArrayIndexPath") +@tvm.ffi.register_object("ArrayIndexPath") class ArrayIndexPath(ObjectPath): pass -@tvm._ffi.register_object("MissingArrayElementPath") +@tvm.ffi.register_object("MissingArrayElementPath") class MissingArrayElementPath(ObjectPath): pass -@tvm._ffi.register_object("MapValuePath") +@tvm.ffi.register_object("MapValuePath") class MapValuePath(ObjectPath): pass -@tvm._ffi.register_object("MissingMapEntryPath") +@tvm.ffi.register_object("MissingMapEntryPath") class MissingMapEntryPath(ObjectPath): pass -@tvm._ffi.register_object("ObjectPathPair") +@tvm.ffi.register_object("ObjectPathPair") class ObjectPathPair(Object): """ Pair of ObjectPaths, one for each object being tested for structural equality. diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 903cd443876f..45189a008495 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -17,7 +17,7 @@ """Registration of profiling objects in python.""" from typing import Dict, Sequence, Optional -from ... import _ffi +from ... import ffi as _ffi from . import _ffi_api from .. import Object, Device diff --git a/python/tvm/runtime/profiling/_ffi_api.py b/python/tvm/runtime/profiling/_ffi_api.py index d26b847a699f..85e5d4ca020c 100644 --- a/python/tvm/runtime/profiling/_ffi_api.py +++ b/python/tvm/runtime/profiling/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for profiling""" -from ... import _ffi +from ...ffi import _init_api -_ffi._init_api("runtime.profiling", __name__) +_init_api("runtime.profiling", __name__) diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py index 411835cd216f..d69c3308fad4 100644 --- a/python/tvm/runtime/relax_vm.py +++ b/python/tvm/runtime/relax_vm.py @@ -18,12 +18,12 @@ """The Relax virtual machine.""" from enum import IntEnum from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from numbers import Number, Integral import numpy as np # type: ignore import tvm -from tvm._ffi import base as _base -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.runtime import Device, Object, PackedFunc from tvm.runtime.profiling import Report @@ -198,7 +198,7 @@ def _convert(self, arg: Any, cargs: List) -> None: def _gettype(arg): if isinstance(arg, np.float16): return "float16" - elif isinstance(arg, (_base.integer_types, bool)): + elif isinstance(arg, (Integral, bool)): return "int32" else: return "float32" @@ -215,7 +215,7 @@ def _gettype(arg): for field in arg: self._convert(field, field_args) cargs.append(tuple(field_args)) - elif isinstance(arg, (_base.numeric_types, bool)): + elif isinstance(arg, (Number, bool)): dtype = _gettype(arg) value = tvm.nd.array(np.array(arg, dtype=dtype), device=tvm.cpu(0)) cargs.append(value) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index ad3f612c4e29..ade34e1e9b85 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -18,7 +18,7 @@ import os from typing import Dict, List, Optional, Sequence -from tvm._ffi import get_global_func, register_object +from tvm.ffi import get_global_func, register_object from tvm.runtime import Object from . import _ffi_node_api diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index 3716460a2709..149a66ef7b55 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -19,10 +19,10 @@ import re -import tvm._ffi +import tvm.ffi -@tvm._ffi.register_func("tvm.runtime.regex_match") +@tvm.ffi.register_func("tvm.runtime.regex_match") def _regex_match(regex_pattern: str, match_against: str) -> bool: """Check if a pattern matches a regular expression diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/_ffi_api.py index ebc638f3fd35..8ae8f7b7f9a5 100644 --- a/python/tvm/script/_ffi_api.py +++ b/python/tvm/script/_ffi_api.py @@ -14,7 +14,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("script", __name__) +tvm.ffi._init_api("script", __name__) diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py index 68811c9e018c..8ee223051986 100644 --- a/python/tvm/script/ir_builder/_ffi_api.py +++ b/python/tvm/script/ir_builder/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("script.ir_builder", __name__) # pylint: disable=protected-access +tvm.ffi._init_api("script.ir_builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 1d5d050444f7..95b5c5002558 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -17,7 +17,7 @@ """A generic IRBuilder across the TVM stack""" from typing import Any, Callable, List -from tvm._ffi import register_object as _register_object +from tvm.ffi import register_object as _register_object from tvm.runtime import Object as _Object from . import _ffi_api diff --git a/python/tvm/script/ir_builder/ir/_ffi_api.py b/python/tvm/script/ir_builder/ir/_ffi_api.py index 874cc278af83..5b9d801a6ed3 100644 --- a/python/tvm/script/ir_builder/ir/_ffi_api.py +++ b/python/tvm/script/ir_builder/ir/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access +tvm.ffi._init_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/ir/frame.py b/python/tvm/script/ir_builder/ir/frame.py index e16d86dc227e..d2737fde59a6 100644 --- a/python/tvm/script/ir_builder/ir/frame.py +++ b/python/tvm/script/ir_builder/ir/frame.py @@ -16,7 +16,7 @@ # under the License. """Package tvm.script.ir_builder.ir.frame""" -from tvm._ffi import register_object as _register_object +from tvm.ffi import register_object as _register_object from ..base import IRBuilderFrame diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py index 6e2098cf88af..1c767bacc4c5 100644 --- a/python/tvm/script/ir_builder/relax/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder.relax""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access +tvm.ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py index d7121744890a..4d2ba60c2002 100644 --- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder.relax.distributed""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api( +tvm.ffi._init_api( "script.ir_builder.relax.distributed", __name__ ) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py index aa0e2a26d34c..159ad5aea169 100644 --- a/python/tvm/script/ir_builder/relax/distributed/ir.py +++ b/python/tvm/script/ir_builder/relax/distributed/ir.py @@ -18,6 +18,7 @@ """IRBuilder for distributed Relax dialect""" from typing import Union, List, Tuple, Optional +from numbers import Number import numpy as _np # type: ignore import tvm @@ -27,7 +28,7 @@ from tvm.relax.expr import Tuple as RxTuple from tvm.relax.distributed import DTensorStructInfo from tvm.relax.utils import args_converter -from tvm._ffi import base as _base +from tvm import base as _base from tvm.runtime import ndarray as _nd from tvm.relax.op.distributed import ( redistribute as _redistribute, @@ -114,7 +115,7 @@ def const( if not isinstance(struct_info, DTensorStructInfo): raise TypeError("struct_info needs to be an instance of DTensorStructInfo. ") dtype = str(struct_info.tensor_sinfo.dtype) - if isinstance(value, (_base.numeric_types, (bool, list))): + if isinstance(value, (Number, (bool, list))): value = _np.array(value, dtype=dtype) if isinstance(value, (_np.ndarray, _np.generic)): diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py index 97e181fbe4be..181f62ec4f39 100644 --- a/python/tvm/script/ir_builder/relax/frame.py +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """IR Builder Frame for Relax dialect""" -from tvm._ffi import register_object as _register_object +from tvm.ffi import register_object as _register_object from ..base import IRBuilderFrame diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index d1e86cc7f456..92f84ce05cc2 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -138,6 +138,7 @@ ones, ones_like, one_hot, + outer, permute_dims, power, print, @@ -156,6 +157,7 @@ sign, sin, sinh, + slice_scatter, sort, split, sqrt, @@ -826,6 +828,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "one_hot", "opencl", "output", + "outer", "permute_dims", "power", "prim_value", @@ -852,6 +855,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "sign", "sin", "sinh", + "slice_scatter", "sort", "split", "square", diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tir/_ffi_api.py index 876f5f3a35a0..69797f986afd 100644 --- a/python/tvm/script/ir_builder/tir/_ffi_api.py +++ b/python/tvm/script/ir_builder/tir/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access +tvm.ffi._init_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index b2229d503bfb..e3ce2e6e2eb1 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -17,7 +17,7 @@ """IRBuilder for TIR""" from typing import List, Union -from tvm._ffi import register_object as _register_object +from tvm.ffi import register_object as _register_object from tvm.tir import Buffer, Var from ..base import IRBuilderFrame diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index f40b9a7cf6d3..78da15ca1f27 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -24,7 +24,7 @@ import numpy as np -from tvm._ffi.base import TVMError +from tvm.base import TVMError from tvm.error import DiagnosticError from tvm.ir import GlobalVar diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py index 944ecad01e08..9cbf6cfdca22 100644 --- a/python/tvm/script/printer/_ffi_api.py +++ b/python/tvm/script/printer/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.printer""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("script.printer", __name__) # pylint: disable=protected-access +tvm.ffi._init_api("script.printer", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 9a6e7f1b8c8f..02a67e916bc0 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -19,7 +19,7 @@ from enum import IntEnum, unique from typing import Dict, List, Optional, Sequence, Tuple, Union -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object, ObjectPath from tvm.tir import FloatImm, IntImm diff --git a/python/tvm/support.py b/python/tvm/support.py index a50a5e7b5732..7e0ad5875f83 100644 --- a/python/tvm/support.py +++ b/python/tvm/support.py @@ -22,11 +22,11 @@ import sys import tvm -import tvm._ffi +import tvm.ffi from .runtime.module import Module from . import get_global_func -tvm._ffi._init_api("support", __name__) +tvm.ffi._init_api("support", __name__) def libinfo(): diff --git a/python/tvm/target/_ffi_api.py b/python/tvm/target/_ffi_api.py index 3f3c4f2b8e46..489b59b4c6ae 100644 --- a/python/tvm/target/_ffi_api.py +++ b/python/tvm/target/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.target""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("target", __name__) +tvm.ffi._init_api("target", __name__) diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index aaf30afaf535..5003dbdde253 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -26,7 +26,7 @@ BinaryOpExpr as _BinaryOpExpr, ) from tvm.tir.op import call_pure_extern -from tvm._ffi import register_func as _register_func +from tvm.ffi import register_func as _register_func from tvm.tir import call_intrin @@ -214,7 +214,7 @@ class name (e.g. Add, LE, Cast, Call). ) else: lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name - tvm._ffi.register_func(lower_func_name, lower_func) + tvm.ffi.register_func(lower_func_name, lower_func) def register_min_func(func, type_name): diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py index 57e39dfdb0ef..ec1875eb90a1 100644 --- a/python/tvm/target/detect_target.py +++ b/python/tvm/target/detect_target.py @@ -17,7 +17,7 @@ """Detect target.""" from typing import Union -from .._ffi import get_global_func +from ..ffi import get_global_func from ..runtime import Device from ..runtime.ndarray import device from . import Target diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index ba7b93333f1f..eb8bf1f9b807 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -20,8 +20,8 @@ import warnings from typing import Union -import tvm._ffi -from tvm._ffi import register_func as _register_func +import tvm.ffi +from tvm.ffi import register_func as _register_func from tvm.runtime import Device from tvm.runtime import Object, convert from tvm.runtime.container import String @@ -30,7 +30,7 @@ from . import _ffi_api -@tvm._ffi.register_object +@tvm.ffi.register_object class TargetKind(Object): """Kind of a compilation target""" @@ -53,7 +53,7 @@ def __getattr__(self, name: str): return _ffi_api.TargetGetFeature(self.target, name) -@tvm._ffi.register_object +@tvm.ffi.register_object class Target(Object): """Target device information, use through TVM API. diff --git a/python/tvm/target/virtual_device.py b/python/tvm/target/virtual_device.py index b2bc0bbcf5ee..3d923a4623d2 100644 --- a/python/tvm/target/virtual_device.py +++ b/python/tvm/target/virtual_device.py @@ -22,7 +22,7 @@ from . import _ffi_api -@tvm._ffi.register_object +@tvm.ffi.register_object class VirtualDevice(Object): """A compile time representation for where data is to be stored at runtime, and how to compile code to compute it.""" diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index c040eface808..177021f1433f 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Common x86 related utilities""" -from .._ffi import register_func +from ..ffi import register_func from .codegen import target_has_features diff --git a/python/tvm/te/_ffi_api.py b/python/tvm/te/_ffi_api.py index ac814c844724..98e466e9e88c 100644 --- a/python/tvm/te/_ffi_api.py +++ b/python/tvm/te/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.te""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("te", __name__) +tvm.ffi._init_api("te", __name__) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 8d72fc794011..4a5d2425e669 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -14,18 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" Operation class for computation declaration.""" +"""Operation class for computation declaration.""" import inspect # pylint: disable=invalid-name from numbers import Integral as _Integral from typing import List, Optional, Union -import tvm._ffi +import tvm.ffi import tvm.arith._ffi_api import tvm.tir import tvm.tir._ffi_api -from tvm._ffi.base import string_types from tvm.ir import Array from tvm.runtime import convert @@ -516,7 +515,7 @@ def thread_axis(dom=None, tag="", name="", span=None): axis : IterVar The thread itervar. """ - if isinstance(dom, string_types): + if isinstance(dom, str): tag, dom = dom, None if not tag: raise ValueError("tag must be given as Positional or keyword argument") diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index dde92870518d..aad18a8b016c 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -16,7 +16,7 @@ # under the License. """Tensor class for computation declaration.""" # pylint: disable=invalid-name -import tvm._ffi +import tvm.ffi from tvm.runtime import Object, ObjectGeneric from tvm.tir import expr as _expr, DataProducer @@ -48,7 +48,7 @@ def dtype(self): return self.tensor.dtype -@tvm._ffi.register_object("te.Tensor") +@tvm.ffi.register_object("te.Tensor") class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" @@ -141,12 +141,12 @@ def input_tensors(self): return _ffi_api.OpInputTensors(self) -@tvm._ffi.register_object +@tvm.ffi.register_object class PlaceholderOp(Operation): """Placeholder operation.""" -@tvm._ffi.register_object +@tvm.ffi.register_object class BaseComputeOp(Operation): """Compute operation.""" @@ -161,12 +161,12 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") -@tvm._ffi.register_object +@tvm.ffi.register_object class ComputeOp(BaseComputeOp): """Scalar operation.""" -@tvm._ffi.register_object +@tvm.ffi.register_object class ScanOp(Operation): """Scan operation.""" @@ -176,6 +176,6 @@ def scan_axis(self): return self.__getattr__("scan_axis") -@tvm._ffi.register_object +@tvm.ffi.register_object class ExternOp(Operation): """External operation.""" diff --git a/python/tvm/testing/_ffi_api.py b/python/tvm/testing/_ffi_api.py index 56a77223b767..e3c30d1299a1 100644 --- a/python/tvm/testing/_ffi_api.py +++ b/python/tvm/testing/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.testing""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("testing", __name__) +tvm.ffi._init_api("testing", __name__) diff --git a/python/tvm/testing/popen_pool.py b/python/tvm/testing/popen_pool.py index 42a34ccc61da..0fc3ce219030 100644 --- a/python/tvm/testing/popen_pool.py +++ b/python/tvm/testing/popen_pool.py @@ -36,19 +36,19 @@ def after_initializer(): return TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3 -@tvm._ffi.register_func("testing.identity_py") +@tvm.ffi.register_func("testing.identity_py") def identity_py(arg): return arg def register_ffi(): - @tvm._ffi.register_func("testing.nested_identity_py") + @tvm.ffi.register_func("testing.nested_identity_py") def _identity_py(arg): # pylint: disable=unused-variable return arg def call_py_ffi(arg): - _identity_py = tvm._ffi.get_global_func("testing.nested_identity_py") + _identity_py = tvm.ffi.get_global_func("testing.nested_identity_py") return _identity_py(arg) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index ec3fc28b1d8a..6b047de4460a 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -87,7 +87,7 @@ def test_something(): import tvm.arith import tvm.tir import tvm.te -import tvm._ffi +import tvm.ffi from tvm.target import codegen from tvm.contrib import nvcc, cudnn, rocm diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tir/_ffi_api.py index 1b60b8c81c6d..8c438557c8c1 100644 --- a/python/tvm/tir/_ffi_api.py +++ b/python/tvm/tir/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.tir""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("tir", __name__) +tvm.ffi._init_api("tir", __name__) diff --git a/python/tvm/tir/analysis/_ffi_api.py b/python/tvm/tir/analysis/_ffi_api.py index 6c1687e8a520..40a7b4caf340 100644 --- a/python/tvm/tir/analysis/_ffi_api.py +++ b/python/tvm/tir/analysis/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.tir.analysis""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("tir.analysis", __name__) +tvm.ffi._init_api("tir.analysis", __name__) diff --git a/python/tvm/tir/block_dependence_info.py b/python/tvm/tir/block_dependence_info.py index 5f1664628890..67a644967e4b 100644 --- a/python/tvm/tir/block_dependence_info.py +++ b/python/tvm/tir/block_dependence_info.py @@ -18,7 +18,7 @@ to store the block level dependences""" from typing import Union, Optional -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.ir.module import IRModule from tvm.runtime import Object from tvm.tir import Block, PrimFunc diff --git a/python/tvm/tir/block_scope.py b/python/tvm/tir/block_scope.py index 30e047b4f78a..b24cca0707a0 100644 --- a/python/tvm/tir/block_scope.py +++ b/python/tvm/tir/block_scope.py @@ -18,7 +18,7 @@ from enum import IntEnum from typing import List, Optional, Union -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.runtime import Object from tvm.tir import Block, For diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 72c2a40fedd2..1f40520e55be 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -17,15 +17,14 @@ """Abstraction for array data structures.""" from numbers import Integral -import tvm._ffi -from tvm._ffi.base import string_types +import tvm.ffi from tvm.ir import PointerType, PrimExpr, PrimType, Range from tvm.runtime import Object, Scriptable, convert from . import _ffi_api -@tvm._ffi.register_object("tir.Buffer") +@tvm.ffi.register_object("tir.Buffer") class Buffer(Object, Scriptable): """Symbolic data buffer in TVM. @@ -85,7 +84,7 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0, # Get access ptr for read with extent buffer.access_ptr("r", extent = 100) """ - if isinstance(access_mask, string_types): + if isinstance(access_mask, str): mask = 0 for value in access_mask: if value == "r": @@ -350,6 +349,6 @@ def decl_buffer( ) -@tvm._ffi.register_object("tir.DataProducer") +@tvm.ffi.register_object("tir.DataProducer") class DataProducer(Object): pass diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index 71cc404ee23b..39874640ff40 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -17,13 +17,13 @@ """Data layout.""" from typing import Union -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from . import _ffi_api -@tvm._ffi.register_object("tir.Layout") +@tvm.ffi.register_object("tir.Layout") class Layout(Object): """Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and @@ -81,7 +81,7 @@ def factor_of(self, axis): return _ffi_api.LayoutFactorOf(self, axis) # type: ignore -@tvm._ffi.register_object("tir.BijectiveLayout") +@tvm.ffi.register_object("tir.BijectiveLayout") class BijectiveLayout(Object): """Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index b293343cae74..e57c01f23afc 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -29,7 +29,7 @@ """ from typing import List, Optional, Union -import tvm._ffi +import tvm.ffi import tvm.ir._ffi_api from tvm import ir from tvm.ir import Op, PrimExpr @@ -349,7 +349,7 @@ class LogicalExpr(PrimExprWithOp): pass -@tvm._ffi.register_object("tir.Var") +@tvm.ffi.register_object("tir.Var") class Var(PrimExprWithOp): """Symbolic variable. @@ -372,7 +372,7 @@ def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) # type: ignore -@tvm._ffi.register_object("tir.SizeVar") +@tvm.ffi.register_object("tir.SizeVar") class SizeVar(Var): """Symbolic variable to represent a tensor index size which is greater or equal to zero. @@ -394,7 +394,7 @@ def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore -@tvm._ffi.register_object("tir.IterVar") +@tvm.ffi.register_object("tir.IterVar") class IterVar(ExprOp, Object, Scriptable): """Represent iteration variable. @@ -467,7 +467,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.CommReducer") +@tvm.ffi.register_object("tir.CommReducer") class CommReducer(Object, Scriptable): """Commutative reduce operator @@ -507,7 +507,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.Reduce") +@tvm.ffi.register_object("tir.Reduce") class Reduce(PrimExprWithOp): """Reduce node. @@ -558,7 +558,7 @@ def __init__( ) -@tvm._ffi.register_object +@tvm.ffi.register_object class FloatImm(ConstExpr): """Float constant. @@ -585,7 +585,7 @@ def __float__(self) -> float: return self.value -@tvm._ffi.register_object +@tvm.ffi.register_object class IntImm(ConstExpr): """Int constant. @@ -627,7 +627,7 @@ def __bool__(self) -> bool: return self.__nonzero__() -@tvm._ffi.register_object("tir.StringImm") # type: ignore +@tvm.ffi.register_object("tir.StringImm") # type: ignore class StringImm(ConstExpr): """String constant. @@ -659,7 +659,7 @@ def __hash__(self) -> int: return PrimExpr.__hash__(self) -@tvm._ffi.register_object("tir.Cast") +@tvm.ffi.register_object("tir.Cast") class Cast(PrimExprWithOp): """Cast expression. @@ -681,7 +681,7 @@ def __init__(self, dtype, value, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore -@tvm._ffi.register_object("tir.Add") +@tvm.ffi.register_object("tir.Add") class Add(BinaryOpExpr): """Add node. @@ -701,7 +701,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Sub") +@tvm.ffi.register_object("tir.Sub") class Sub(BinaryOpExpr): """Sub node. @@ -721,7 +721,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Mul") +@tvm.ffi.register_object("tir.Mul") class Mul(BinaryOpExpr): """Mul node. @@ -741,7 +741,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Div") +@tvm.ffi.register_object("tir.Div") class Div(BinaryOpExpr): """Div node. @@ -761,7 +761,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Mod") +@tvm.ffi.register_object("tir.Mod") class Mod(BinaryOpExpr): """Mod node. @@ -781,7 +781,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.FloorDiv") +@tvm.ffi.register_object("tir.FloorDiv") class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -801,7 +801,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.FloorMod") +@tvm.ffi.register_object("tir.FloorMod") class FloorMod(BinaryOpExpr): """FloorMod node. @@ -821,7 +821,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Min") +@tvm.ffi.register_object("tir.Min") class Min(BinaryOpExpr): """Min node. @@ -841,7 +841,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Max") +@tvm.ffi.register_object("tir.Max") class Max(BinaryOpExpr): """Max node. @@ -861,7 +861,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.EQ") +@tvm.ffi.register_object("tir.EQ") class EQ(CmpExpr): """EQ node. @@ -881,7 +881,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.NE") +@tvm.ffi.register_object("tir.NE") class NE(CmpExpr): """NE node. @@ -901,7 +901,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.LT") +@tvm.ffi.register_object("tir.LT") class LT(CmpExpr): """LT node. @@ -921,7 +921,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.LE") +@tvm.ffi.register_object("tir.LE") class LE(CmpExpr): """LE node. @@ -941,7 +941,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.GT") +@tvm.ffi.register_object("tir.GT") class GT(CmpExpr): """GT node. @@ -961,7 +961,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.GE") +@tvm.ffi.register_object("tir.GE") class GE(CmpExpr): """GE node. @@ -981,7 +981,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.And") +@tvm.ffi.register_object("tir.And") class And(LogicalExpr): """And node. @@ -1001,7 +1001,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Or") +@tvm.ffi.register_object("tir.Or") class Or(LogicalExpr): """Or node. @@ -1024,7 +1024,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> Non self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore -@tvm._ffi.register_object("tir.Not") +@tvm.ffi.register_object("tir.Not") class Not(LogicalExpr): """Not node. @@ -1043,7 +1043,7 @@ def __init__(self, a: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore -@tvm._ffi.register_object("tir.Select") +@tvm.ffi.register_object("tir.Select") class Select(PrimExprWithOp): """Select node. @@ -1087,7 +1087,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.BufferLoad") +@tvm.ffi.register_object("tir.BufferLoad") class BufferLoad(PrimExprWithOp): """Buffer load node. @@ -1122,7 +1122,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.ProducerLoad") +@tvm.ffi.register_object("tir.ProducerLoad") class ProducerLoad(PrimExprWithOp): """Producer load node. @@ -1149,7 +1149,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.Ramp") +@tvm.ffi.register_object("tir.Ramp") class Ramp(PrimExprWithOp): """Ramp node. @@ -1180,7 +1180,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.Broadcast") +@tvm.ffi.register_object("tir.Broadcast") class Broadcast(PrimExprWithOp): """Broadcast node. @@ -1203,7 +1203,7 @@ def __init__(self, value: PrimExpr, lanes: PrimExpr, span: Optional[Span] = None self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore -@tvm._ffi.register_object("tir.Shuffle") +@tvm.ffi.register_object("tir.Shuffle") class Shuffle(PrimExprWithOp): """Shuffle node. @@ -1241,7 +1241,7 @@ class CallEffectKind: Opaque = UpdateState -@tvm._ffi.register_object("tir.Call") +@tvm.ffi.register_object("tir.Call") class Call(PrimExprWithOp): """Call node. @@ -1281,7 +1281,7 @@ def __init__( self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) # type: ignore -@tvm._ffi.register_object("tir.Let") +@tvm.ffi.register_object("tir.Let") class Let(PrimExprWithOp): """Let node. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index eb3c50b409c8..55bae37809f0 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -22,7 +22,7 @@ from typing import Callable, List, Mapping, Optional, Tuple, Union import tvm -import tvm._ffi +import tvm.ffi import tvm.runtime from tvm.ir import BaseFunc, Range from tvm.runtime import Object, Scriptable @@ -33,7 +33,7 @@ from .expr import PrimExpr, Var -@tvm._ffi.register_object("tir.PrimFunc") +@tvm.ffi.register_object("tir.PrimFunc") class PrimFunc(BaseFunc, Scriptable): """A function declaration expression. @@ -174,7 +174,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: return _ffi_api.Specialize(self, param_map) # type: ignore -@tvm._ffi.register_object("tir.TensorIntrin") +@tvm.ffi.register_object("tir.TensorIntrin") class TensorIntrin(Object): """A tensor intrinsic. @@ -230,7 +230,7 @@ def get(name: str, allow_missing: bool = False) -> Optional["TensorIntrin"]: return _ffi_api.TensorIntrinGet(name, allow_missing) # pylint: type: ignore -@tvm._ffi.register_object("tir.IndexMap") +@tvm.ffi.register_object("tir.IndexMap") class IndexMap(Object): """A mapping from multi-dimensional indices to another set of multi-dimensional indices diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 777d46ec7b0d..7a9708848ab4 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -16,7 +16,6 @@ # under the License. """Developer API of IR node builder make function.""" import tvm -from tvm._ffi.base import string_types from tvm.runtime import ObjectGeneric, const from tvm.ir import container as _container @@ -194,9 +193,9 @@ def scope_attr(self, node, attr_key, value): ib.scope_attr(x, "storage_scope", "global") x[i] = x[i - 1] + 1 """ - if isinstance(node, string_types): + if isinstance(node, str): node = _expr.StringImm(node) - if isinstance(value, string_types): + if isinstance(value, str): value = _expr.StringImm(value) # thread_extent could be zero for dynamic workloads if attr_key == "thread_extent": diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4c0b51280e3e..57aa060cd0c5 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -18,7 +18,7 @@ """Operators used in TIR expression.""" from typing import Any, Optional, Union -import tvm._ffi +import tvm.ffi from tvm import tir from tvm.ir import Array, Op, PrimExpr from tvm.ir.base import Span @@ -1927,7 +1927,7 @@ def all(*args, span=None): return val -@tvm._ffi.register_func("tvm.default_trace_action") +@tvm.ffi.register_func("tvm.default_trace_action") def _tvm_default_trace_action(*args): print(list(args)) diff --git a/python/tvm/tir/schedule/_ffi_api.py b/python/tvm/tir/schedule/_ffi_api.py index ae8bdfde54bf..b854145beb6a 100644 --- a/python/tvm/tir/schedule/_ffi_api.py +++ b/python/tvm/tir/schedule/_ffi_api.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.tir.schedule""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("tir.schedule", __name__) # pylint: disable=protected-access +tvm.ffi._init_api("tir.schedule", __name__) # pylint: disable=protected-access diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 15748a99c81a..491a689c9309 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -17,7 +17,7 @@ """Analysis used in TensorIR scheduling""" from typing import List, Optional -import tvm._ffi +import tvm.ffi from tvm.runtime import Object from ..buffer import Buffer @@ -62,7 +62,7 @@ def suggest_index_map( ) -@tvm._ffi.register_object("tir.schedule.TensorizeInfo") +@tvm.ffi.register_object("tir.schedule.TensorizeInfo") class TensorizeInfo(Object): """Necessary information used for tensorization.""" @@ -90,7 +90,7 @@ def get_tensorize_loop_mapping( return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func, allow_padding) # type: ignore -@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo") +@tvm.ffi.register_object("tir.schedule.AutoTensorizeMappingInfo") class AutoTensorizeMappingInfo(Object): """Necessary information used to perform transformations for tensorization.""" diff --git a/python/tvm/tir/schedule/instruction.py b/python/tvm/tir/schedule/instruction.py index 09b2d70dc321..5a8563e652b6 100644 --- a/python/tvm/tir/schedule/instruction.py +++ b/python/tvm/tir/schedule/instruction.py @@ -17,7 +17,7 @@ """Schedule instructions each corresponds to a schedule primitive""" from typing import TYPE_CHECKING, Any, List, Union -from tvm._ffi import register_object as _register_object +from tvm.ffi import register_object as _register_object from tvm.runtime import Object from . import _ffi_api diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d0dd427e6d5e..5325ecdc16c4 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -18,7 +18,7 @@ import inspect from typing import Callable, Dict, List, Literal, Optional, Tuple, Union -from tvm._ffi import register_object as _register_object +from tvm.ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import GlobalVar, IRModule, PrimExpr from tvm.runtime import Object diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index df2eb534e633..f082a9e92ea7 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -20,7 +20,7 @@ from enum import IntEnum from typing import Dict, Optional, Union -from tvm._ffi import register_object +from tvm.ffi import register_object from tvm.ir import IRModule from tvm.runtime import Object from tvm.tir import Block, BlockRealize, For, PrimFunc diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index f1f6bdaa5743..15bb201ae641 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -18,7 +18,7 @@ import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional -from tvm._ffi import register_object as _register_object +from tvm.ffi import register_object as _register_object from tvm.runtime import Object from ...ir import Array, Map, save_json @@ -139,7 +139,7 @@ def pop(self) -> Optional[Instruction]: Returns ------- popped_inst : Instruction - Returns the instruction removed; NullOpt if the trace is empty + Returns the instruction removed; std::nullopt if the trace is empty """ return _ffi_api.TracePop(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/transform.py b/python/tvm/tir/schedule/transform.py index e40b55d4d6b2..fbaca81197e5 100644 --- a/python/tvm/tir/schedule/transform.py +++ b/python/tvm/tir/schedule/transform.py @@ -41,6 +41,6 @@ def tile_with_tensor_intrin( ------- tiled_loop_rv : Optional[LoopRV] LoopRV corresponding to the outermost loop of a block tiled according to the given intrin - NullOpt if no valid loop mapping is found + std::nullopt if no valid loop mapping is found """ return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name, allow_padding) # type: ignore diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index aa3b17a7a12f..a04f80b55e7a 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -29,7 +29,7 @@ from enum import IntEnum from typing import List, Mapping, Optional, Union -import tvm._ffi +import tvm.ffi from tvm.ir import PrimExpr, Range, Span from tvm.runtime import Object, Scriptable, const, NDArray @@ -42,7 +42,7 @@ class Stmt(Object, Scriptable): """Base class of all the statements.""" -@tvm._ffi.register_object("tir.LetStmt") +@tvm.ffi.register_object("tir.LetStmt") class LetStmt(Stmt): """LetStmt node. @@ -72,7 +72,7 @@ def __init__(self, var: Var, value: PrimExpr, body: Stmt, span: Optional[Span] = ) -@tvm._ffi.register_object("tir.AssertStmt") +@tvm.ffi.register_object("tir.AssertStmt") class AssertStmt(Stmt): """AssertStmt node. @@ -120,7 +120,7 @@ class ForKind(IntEnum): THREAD_BINDING = 4 # pylint: disable=invalid-name -@tvm._ffi.register_object("tir.For") +@tvm.ffi.register_object("tir.For") class For(Stmt): """For node. @@ -185,7 +185,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.While") +@tvm.ffi.register_object("tir.While") class While(Stmt): """While node. @@ -209,7 +209,7 @@ def __init__(self, condition: PrimExpr, body: Stmt, span: Optional[Span] = None) self.__init_handle_by_constructor__(_ffi_api.While, condition, body, span) # type: ignore -@tvm._ffi.register_object("tir.BufferStore") +@tvm.ffi.register_object("tir.BufferStore") class BufferStore(Stmt): """Buffer store node. @@ -252,7 +252,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.BufferRealize") +@tvm.ffi.register_object("tir.BufferRealize") class BufferRealize(Stmt): """Buffer realize node. @@ -293,7 +293,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.ProducerStore") +@tvm.ffi.register_object("tir.ProducerStore") class ProducerStore(Stmt): """ProducerStore node. @@ -329,7 +329,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.Allocate") +@tvm.ffi.register_object("tir.Allocate") class Allocate(Stmt): """Allocate node. @@ -389,7 +389,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.AllocateConst") +@tvm.ffi.register_object("tir.AllocateConst") class AllocateConst(Stmt): """Allocate constant node. @@ -451,7 +451,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.DeclBuffer") +@tvm.ffi.register_object("tir.DeclBuffer") class DeclBuffer(Stmt): """DeclBuffer node. @@ -475,7 +475,7 @@ def __init__(self, buffer: Buffer, body: Stmt, span: Optional[Span] = None) -> N self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body, span) -@tvm._ffi.register_object("tir.AttrStmt") +@tvm.ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. @@ -511,7 +511,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.ProducerRealize") +@tvm.ffi.register_object("tir.ProducerRealize") class ProducerRealize(Stmt): """ProducerRealize node. @@ -563,7 +563,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.SeqStmt") +@tvm.ffi.register_object("tir.SeqStmt") class SeqStmt(Stmt): """Sequence of statements. @@ -589,7 +589,7 @@ def __len__(self): return len(self.seq) -@tvm._ffi.register_object("tir.IfThenElse") +@tvm.ffi.register_object("tir.IfThenElse") class IfThenElse(Stmt): """IfThenElse node. @@ -624,7 +624,7 @@ def __init__( ) -@tvm._ffi.register_object("tir.Evaluate") +@tvm.ffi.register_object("tir.Evaluate") class Evaluate(Stmt): """Evaluate node. @@ -644,7 +644,7 @@ def __init__(self, value: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore -@tvm._ffi.register_object("tir.Prefetch") +@tvm.ffi.register_object("tir.Prefetch") class Prefetch(Stmt): """Prefetch node. @@ -668,7 +668,7 @@ def __init__(self, buffer: Buffer, bounds: List[Range], span: Optional[Span] = N self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) # type: ignore -@tvm._ffi.register_object("tir.BufferRegion") +@tvm.ffi.register_object("tir.BufferRegion") class BufferRegion(Object, Scriptable): """BufferRegion node. @@ -688,7 +688,7 @@ def __init__(self, buffer: Buffer, region: List[Range]) -> None: self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) # type: ignore -@tvm._ffi.register_object("tir.MatchBufferRegion") +@tvm.ffi.register_object("tir.MatchBufferRegion") class MatchBufferRegion(Object, Scriptable): """MatchBufferRegion node. @@ -710,7 +710,7 @@ def __init__(self, buffer: Buffer, source: BufferRegion) -> None: ) -@tvm._ffi.register_object("tir.Block") +@tvm.ffi.register_object("tir.Block") class Block(Stmt): """Block node. @@ -792,7 +792,7 @@ def __init__( ) # type: ignore -@tvm._ffi.register_object("tir.BlockRealize") +@tvm.ffi.register_object("tir.BlockRealize") class BlockRealize(Stmt): """BlockRealize node. diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 57b1c3b873d7..6f964c94370d 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -18,7 +18,7 @@ """Intrinsics for tensorization on NVIDIA GPU.""" from typing import Dict, Literal, Optional, Tuple -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.runtime import convert from tvm.script import tir as T from tvm.tir import Cast, IntImm, TensorIntrin diff --git a/python/tvm/tir/transform/_ffi_api.py b/python/tvm/tir/transform/_ffi_api.py index 86f7bdf5dac3..8a6607c11af0 100644 --- a/python/tvm/tir/transform/_ffi_api.py +++ b/python/tvm/tir/transform/_ffi_api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.tir.transform""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("tir.transform", __name__) +tvm.ffi._init_api("tir.transform", __name__) diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index c6c825699e9a..b679d4ab16ce 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -19,13 +19,13 @@ import functools from typing import Callable, List, Optional, Union -import tvm._ffi +import tvm.ffi from tvm.ir.transform import Pass, PassInfo from . import _ffi_api -@tvm._ffi.register_object("tir.PrimFuncPass") +@tvm.ffi.register_object("tir.PrimFuncPass") class PrimFuncPass(Pass): """A pass that works on each :py:func:`tvm.tir.PrimFunc` in a module. A function pass class should be created through py:func:`tvm.tir.transform.function_pass`. diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index fa4e98a89a42..9503aea0cd2f 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -24,7 +24,7 @@ Some of the schedule function may have been specially optimized for a specific workload. """ -from tvm._ffi.libinfo import __version__ +from tvm.libinfo import __version__ # Ensure C++ schedules get registered first, so python schedules can # override them. @@ -40,6 +40,7 @@ from .sort import * from .scatter import * from .scatter_elements import * +from .slice_scatter import * from .sparse_reshape import * from .scan import * from .einsum import * diff --git a/python/tvm/topi/cpp/cuda.py b/python/tvm/topi/cpp/cuda.py index ce2efa929824..22f97293d38d 100644 --- a/python/tvm/topi/cpp/cuda.py +++ b/python/tvm/topi/cpp/cuda.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for CUDA TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi.cuda", "tvm.topi.cpp.cuda") +tvm.ffi._init_api("topi.cuda", "tvm.topi.cpp.cuda") diff --git a/python/tvm/topi/cpp/generic.py b/python/tvm/topi/cpp/generic.py index d314eca8b22d..3230d5428bb2 100644 --- a/python/tvm/topi/cpp/generic.py +++ b/python/tvm/topi/cpp/generic.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for generic TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi.generic", "tvm.topi.cpp.generic") +tvm.ffi._init_api("topi.generic", "tvm.topi.cpp.generic") diff --git a/python/tvm/topi/cpp/impl.py b/python/tvm/topi/cpp/impl.py index 2c877c300dc9..e5473a7e6602 100644 --- a/python/tvm/topi/cpp/impl.py +++ b/python/tvm/topi/cpp/impl.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """Load Lib for C++ TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi", "tvm.topi.cpp") +tvm.ffi._init_api("topi", "tvm.topi.cpp") diff --git a/python/tvm/topi/cpp/nn.py b/python/tvm/topi/cpp/nn.py index 0e3cee703de0..2ea1fc371404 100644 --- a/python/tvm/topi/cpp/nn.py +++ b/python/tvm/topi/cpp/nn.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for NN TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi.nn", "tvm.topi.cpp.nn") +tvm.ffi._init_api("topi.nn", "tvm.topi.cpp.nn") diff --git a/python/tvm/topi/cpp/rocm.py b/python/tvm/topi/cpp/rocm.py index eab51107beb7..771fc3c3f0f3 100644 --- a/python/tvm/topi/cpp/rocm.py +++ b/python/tvm/topi/cpp/rocm.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for Rocm TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi.rocm", "tvm.topi.cpp.rocm") +tvm.ffi._init_api("topi.rocm", "tvm.topi.cpp.rocm") diff --git a/python/tvm/topi/cpp/utils.py b/python/tvm/topi/cpp/utils.py index 60a2747f9abb..b78a6baa0f01 100644 --- a/python/tvm/topi/cpp/utils.py +++ b/python/tvm/topi/cpp/utils.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for TOPI utility functions""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi.utils", "tvm.topi.cpp.utils") +tvm.ffi._init_api("topi.utils", "tvm.topi.cpp.utils") diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/topi/cpp/vision/__init__.py index 000602fb399d..5fdf1ac4e3a8 100644 --- a/python/tvm/topi/cpp/vision/__init__.py +++ b/python/tvm/topi/cpp/vision/__init__.py @@ -16,8 +16,8 @@ # under the License. """FFI for vision TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi from . import yolo -tvm._ffi._init_api("topi.vision", "tvm.topi.cpp.vision") +tvm.ffi._init_api("topi.vision", "tvm.topi.cpp.vision") diff --git a/python/tvm/topi/cpp/vision/yolo.py b/python/tvm/topi/cpp/vision/yolo.py index 17e2327295d2..5d8bdd99d24c 100644 --- a/python/tvm/topi/cpp/vision/yolo.py +++ b/python/tvm/topi/cpp/vision/yolo.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for Yolo TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi.vision.yolo", "tvm.topi.cpp.vision.yolo") +tvm.ffi._init_api("topi.vision.yolo", "tvm.topi.cpp.vision.yolo") diff --git a/python/tvm/topi/cpp/x86.py b/python/tvm/topi/cpp/x86.py index 0034af02c572..18de30c668a3 100644 --- a/python/tvm/topi/cpp/x86.py +++ b/python/tvm/topi/cpp/x86.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. """FFI for x86 TOPI ops and schedules""" -import tvm._ffi +import tvm.ffi -tvm._ffi._init_api("topi.x86", "tvm.topi.cpp.x86") +tvm.ffi._init_api("topi.x86", "tvm.topi.cpp.x86") diff --git a/python/tvm/topi/generic_op_impl.py b/python/tvm/topi/generic_op_impl.py index 661e24d2a45c..2e0beadbc177 100644 --- a/python/tvm/topi/generic_op_impl.py +++ b/python/tvm/topi/generic_op_impl.py @@ -64,7 +64,7 @@ def _tensor_bop_impl(lhs, rhs): it performs tensor-scalar {op} operation on an element-wise basis. Otherwise, it performs default generic.{op} operation, as defined - in tvm.generic module. + in tvm.tir.generic module. Parameters ---------- diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 5cbc292adbd7..ad2c99fa3ac1 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -376,7 +376,7 @@ def resize1d( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, + bicubic_alpha=-0.75, bicubic_exclude=0, extrapolation_value=0.0, out_dtype=None, @@ -748,7 +748,7 @@ def resize2d( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, + bicubic_alpha=-0.75, bicubic_exclude=0, extrapolation_value=0.0, out_dtype=None, @@ -1217,7 +1217,7 @@ def resize3d( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="", - bicubic_alpha=-0.5, + bicubic_alpha=-0.75, bicubic_exclude=0, extrapolation_value=0.0, out_dtype=None, diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index 865d62683c8b..fb306f9e599b 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -484,7 +484,6 @@ def log2(x): return te.compute(x.shape, lambda *i: te.log2(x(*i))) -@tvm.te.tag_scope(tag=tag.ELEMWISE) def log10(x): """Take logarithm to the base 10 of input x. @@ -498,7 +497,9 @@ def log10(x): y : tvm.te.Tensor The result. """ - return te.compute(x.shape, lambda *i: te.log10(x(*i))) + if x.dtype.startswith("int"): + x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) + return te.compute(x.shape, lambda *i: te.log10(x(*i)), tag=tag.ELEMWISE) @tvm.te.tag_scope(tag=tag.ELEMWISE) @@ -515,6 +516,8 @@ def sqrt(x): y : tvm.te.Tensor The result. """ + if x.dtype.startswith("int"): + x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.sqrt(x(*i))) @@ -532,6 +535,8 @@ def rsqrt(x): y : tvm.te.Tensor The result. """ + if x.dtype.startswith("int"): + x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.rsqrt(x(*i))) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 2965944d61c8..531c0a6c6663 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -743,6 +743,11 @@ def conv( dimensions, kernel_dimensions, dilations, pad_begin, pad_end, strides ) ] + for out_dim in out_dimensions: + if isinstance(out_dim, int) and out_dim <= 0: + raise ValueError( + f"Invalid conv parameters: lead to negative output shape {out_dimensions}. " + ) # compute graph pad_before = list(np.array([0, 0] + pad_begin)[data_permutation_from]) pad_after = list(np.array([0, 0] + pad_end)[data_permutation_from]) diff --git a/python/tvm/topi/slice_scatter.py b/python/tvm/topi/slice_scatter.py new file mode 100644 index 000000000000..d8772d0f5b7e --- /dev/null +++ b/python/tvm/topi/slice_scatter.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""SliceScatter operator""" +from tvm import topi +from . import utils + + +def slice_scatter(input_tensor, src, start, end, step, axis): + """ + Scatters a slice of src into input along the given axis (SSA form). + + Args: + input_tensor (te.Tensor): The input tensor to scatter into. + src (te.Tensor): The source tensor to scatter from. + start (int): The starting index of the slice. + end (int): The ending index of the slice. + step (int): The step size of the slice. + axis (int): The axis to scatter along. + + Returns: + list[te.Tensor]: A list containing the output tensor with the slice scattered. + """ + + dim_size_expr = input_tensor.shape[axis] # Expression for dimension size + dim_size = utils.get_const_int(dim_size_expr) # Dimension size (as constant int) + + if start == 0 and end == dim_size and step == 1: + return topi.identity(src) + + mask = topi.full((dim_size,), "bool", True) + idx = topi.arange(start=0, stop=dim_size, step=1, dtype="int64") + + if start != 0: + mask = topi.logical_and(mask, topi.greater_equal(idx, start)) + + if end != dim_size: + mask = topi.logical_and(mask, topi.less(idx, end)) + + if step != 1: + step_mask = topi.equal(topi.floor_mod(idx - start, step), 0) + mask = topi.logical_and(mask, step_mask) + + mask_shape_base = [1] * len(input_tensor.shape) + mask_shape_base[axis] = dim_size + mask_shape = tuple(mask_shape_base) + + mask_reshaped = topi.reshape(mask, mask_shape) + + idx_new_pre = idx - start + (step - 1) + idx_new_div = topi.floor_divide(idx_new_pre, step) + idx_new = topi.clip(idx_new_div, 0, dim_size - 1) + + temp = topi.take(src, idx_new, axis=axis) + + mask_shape_expanded_base = list(input_tensor.shape) + mask_shape_expanded = tuple(mask_shape_expanded_base) + + mask_expanded = topi.broadcast_to(mask_reshaped, mask_shape_expanded) + + output = topi.where(mask_expanded, temp, input_tensor) + + return [output] diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 1ef65230591b..951944e618ab 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -96,7 +96,8 @@ def _compute(*idxs): axis_index = 0 for i in range(0, len(idxs)): if i not in real_axis: - indices.append(idxs[i]) + dim = tvm.tir.if_then_else(a.shape[len(indices)] != 1, idxs[i], 0) + indices.append(dim) axis_index += 1 return a(*indices) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 3a0441ef84af..d74d5d2a845a 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -262,7 +262,17 @@ def simplify(expr): out : Expr or int The simplified output """ - return tvm.arith.Analyzer().simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr + if isinstance(expr, te.Tensor): + return te.compute( + expr.shape, + lambda *indices: tvm.arith.Analyzer().simplify(expr[indices]), + name="simplify_output", + tag="simplify", + ) + elif isinstance(expr, tvm.tir.PrimExpr): + return tvm.arith.Analyzer().simplify(expr) + else: + return expr def ravel_index(indices, shape): diff --git a/rust/.gitignore b/rust/.gitignore deleted file mode 100644 index 0cc660650780..000000000000 --- a/rust/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -target/ -*.rs.bk -Cargo.lock -c_runtime_api.rs diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml deleted file mode 100644 index 95936dc4dec8..000000000000 --- a/rust/.rustfmt.toml +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -max_width = 100 -hard_tabs = false -tab_spaces = 4 -newline_style = "Auto" -use_small_heuristics = "Default" -reorder_imports = true -reorder_modules = true -remove_nested_parens = true -fn_params_layout = "Tall" -edition = "2018" -merge_derives = true -use_try_shorthand = false -use_field_init_shorthand = false -force_explicit_abi = true diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml deleted file mode 100644 index 4300cb3f1dcb..000000000000 --- a/rust/tvm-macros/Cargo.toml +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "tvm-macros" -version = "0.1.1-alpha" -license = "Apache-2.0" -description = "Procedural macros of the TVM crate." -repository = "https://github.com/apache/tvm" -readme = "README.md" -keywords = ["tvm"] -authors = ["TVM Contributors"] -edition = "2018" - -[lib] -proc-macro = true - -[dependencies] -goblin = "^0.2" -proc-macro2 = "^1.0" -quote = "^1.0" -syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] } -proc-macro-error = "^1.0" diff --git a/rust/tvm-macros/README.md b/rust/tvm-macros/README.md deleted file mode 100644 index 8a7c4b301524..000000000000 --- a/rust/tvm-macros/README.md +++ /dev/null @@ -1,20 +0,0 @@ - - - - - - - - - - - - - - - - - -# tvm-macros - -The procedural macro implementations for TVM crates, see `tvm` crate for more documentation. diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs deleted file mode 100644 index 146f9d4d6bc6..000000000000 --- a/rust/tvm-macros/src/external.rs +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -use proc_macro2::Span; -use proc_macro_error::abort; -use quote::quote; -use syn::parse::{Parse, ParseStream, Result}; - -use syn::{ - token::Semi, Attribute, FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, - Signature, Type, Visibility, -}; - -struct ExternalItem { - attrs: Vec, - visibility: Visibility, - sig: Signature, -} - -impl Parse for ExternalItem { - fn parse(input: ParseStream) -> Result { - let item = ExternalItem { - attrs: input.call(Attribute::parse_outer)?, - visibility: input.parse()?, - sig: input.parse()?, - }; - let _semi: Semi = input.parse()?; - Ok(item) - } -} - -struct External { - visibility: Visibility, - tvm_name: String, - ident: Ident, - generics: Generics, - inputs: Vec, - ret_type: ReturnType, -} - -impl Parse for External { - fn parse(input: ParseStream) -> Result { - let method: ExternalItem = input.parse()?; - let visibility = method.visibility; - assert_eq!(method.attrs.len(), 1); - let sig = method.sig; - let tvm_name = method.attrs[0].parse_meta()?; - let tvm_name = match tvm_name { - Meta::List(meta_list) => { - let name = meta_list.path.get_ident().expect("name"); - assert_eq!(name.to_string(), "name".to_string()); - match meta_list.nested.first() { - Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(), - _ => panic!(), - } - } - _ => panic!(), - }; - - let ident = sig.ident; - let generics = sig.generics; - let inputs = sig - .inputs - .iter() - .cloned() - .map(|param| param.clone()) - .collect(); - let ret_type = sig.output; - - Ok(External { - visibility, - tvm_name, - ident, - generics, - inputs, - ret_type, - }) - } -} - -struct ExternalInput { - externs: Vec, -} - -impl Parse for ExternalInput { - fn parse(input: ParseStream) -> Result { - let mut externs: Vec = Vec::new(); - - loop { - if input.is_empty() { - break; - } - externs.push(input.parse()?); - } - - Ok(ExternalInput { externs }) - } -} - -pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let ext_input = syn::parse_macro_input!(input as ExternalInput); - - let tvm_rt_crate = crate::util::get_tvm_rt_crate(); - - let result_type = quote! { #tvm_rt_crate::function::Result }; - - let mut items = Vec::new(); - - for external in &ext_input.externs { - let visibility = &external.visibility; - let name = &external.ident; - let global_name = format!("global_{}", external.ident); - let global_name = Ident::new(&global_name, Span::call_site()); - let ext_name = &external.tvm_name; - - let ty_params: Vec = external - .generics - .params - .iter() - .map(|ty_param| match ty_param { - syn::GenericParam::Type(param) => param.clone(), - _ => abort! { ty_param, - "Only supports type parameters." - }, - }) - .collect(); - - let args = &external.inputs; - - let (args, tys): (Vec, Vec) = args - .iter() - .map(|arg| match arg { - FnArg::Typed(pat_type) => match &*pat_type.pat { - Pat::Ident(pat_ident) => { - let ident: Ident = pat_ident.ident.clone(); - let ty: Type = *pat_type.ty.clone(); - (ident, ty) - } - _ => abort! { pat_type, - "Only supports type parameters." - }, - }, - pat => abort! { - pat, "invalid pattern type for function"; - - note = "{:?} is not allowed here", pat; - }, - }) - .unzip(); - - let ret_type = match &external.ret_type { - ReturnType::Type(_, rtype) => *rtype.clone(), - ReturnType::Default => syn::parse_str::("()").unwrap(), - }; - - let global = quote! { - #[allow(non_upper_case_globals)] - static #global_name: ::once_cell::sync::Lazy<#tvm_rt_crate::Function> = - ::once_cell::sync::Lazy::new(|| { - #tvm_rt_crate::Function::get(#ext_name) - .expect(concat!("unable to load external function", stringify!(#ext_name), "from TVM registry.")) - }); - }; - - items.push(global); - - let wrapper = quote! { - #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { - let func_ref: #tvm_rt_crate::Function = #global_name.clone(); - let func_ref: Box #result_type<#ret_type>> = func_ref.into(); - let res: #ret_type = func_ref(#(#args),*)?; - Ok(res) - } - }; - - items.push(wrapper); - } - - proc_macro::TokenStream::from(quote! { - #(#items - )* - }) -} diff --git a/rust/tvm-macros/src/import_module.rs b/rust/tvm-macros/src/import_module.rs deleted file mode 100644 index bebf73b2528f..000000000000 --- a/rust/tvm-macros/src/import_module.rs +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -use quote::quote; -use std::{fs::File, io::Read}; -use syn::parse::{Parse, ParseStream, Result}; -use syn::LitStr; - -use std::path::PathBuf; - -struct ImportModule { - importing_file: LitStr, -} - -impl Parse for ImportModule { - fn parse(input: ParseStream) -> Result { - let importing_file: LitStr = input.parse()?; - Ok(ImportModule { importing_file }) - } -} - -pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let import_module_args = syn::parse_macro_input!(input as ImportModule); - - let manifest = - std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); - - let mut path = PathBuf::new(); - path.push(manifest); - path = path.join(import_module_args.importing_file.value()); - - let mut fd = File::open(&path) - .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); - let mut buffer = Vec::new(); - fd.read_to_end(&mut buffer).unwrap(); - - let fn_names = match goblin::Object::parse(&buffer).unwrap() { - goblin::Object::Elf(elf) => elf - .syms - .iter() - .filter_map(|s| { - if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { - return None; - } - match elf.strtab.get(s.st_name) { - Some(Ok(name)) if name != "" => { - Some(syn::Ident::new(name, proc_macro2::Span::call_site())) - } - _ => None, - } - }) - .collect::>(), - goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { - obj.symbols() - .filter_map(|s| match s { - Ok((name, ref nlist)) - if nlist.is_global() - && nlist.n_sect != 0 - && !name.ends_with("tvm_module_ctx") => - { - Some(syn::Ident::new( - if name.starts_with('_') { - // Mach objects prepend a _ to globals. - &name[1..] - } else { - &name - }, - proc_macro2::Span::call_site(), - )) - } - _ => None, - }) - .collect::>() - } - _ => panic!("Unsupported object format."), - }; - - let extern_fns = quote! { - mod ext { - extern "C" { - #( - pub(super) fn #fn_names( - args: *const tvm_graph_rt::ffi::TVMValue, - type_codes: *const std::os::raw::c_int, - num_args: std::os::raw::c_int - ) -> std::os::raw::c_int; - )* - } - } - }; - - let fns = quote! { - use tvm_graph_rt::{ffi::TVMValue, ArgValue, RetValue, FuncCallError}; - #extern_fns - - #( - pub fn #fn_names(args: &[ArgValue]) -> Result { - let (values, type_codes): (Vec, Vec) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = unsafe { - ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) - }; - if exit_code == 0 { - Ok(RetValue::default()) - } else { - Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) - } - } - )* - }; - - proc_macro::TokenStream::from(fns) -} diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs deleted file mode 100644 index e563a57f149e..000000000000 --- a/rust/tvm-macros/src/lib.rs +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use proc_macro::TokenStream; -use proc_macro_error::proc_macro_error; - -mod external; -mod import_module; -mod object; -mod util; - -#[proc_macro] -pub fn import_module(input: TokenStream) -> TokenStream { - import_module::macro_impl(input) -} - -#[proc_macro_error] -#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))] -pub fn macro_impl(input: TokenStream) -> TokenStream { - // let input = proc_macro2::TokenStream::from(input); - TokenStream::from(object::macro_impl(input)) -} - -#[proc_macro_error] -#[proc_macro] -pub fn external(input: TokenStream) -> TokenStream { - external::macro_impl(input) -} diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs deleted file mode 100644 index 4134da5fe6d9..000000000000 --- a/rust/tvm-macros/src/object.rs +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use proc_macro::TokenStream; -use proc_macro2::Span; -use quote::quote; -use syn::DeriveInput; -use syn::Ident; - -use crate::util::*; - -pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { - let tvm_rt_crate = get_tvm_rt_crate(); - let result = quote! { #tvm_rt_crate::function::Result }; - let error = quote! { #tvm_rt_crate::errors::Error }; - let derive_input = syn::parse_macro_input!(input as DeriveInput); - let payload_id = derive_input.ident.clone(); - - let type_key = get_attr(&derive_input, "type_key") - .map(attr_to_str) - .expect("Failed to get type_key"); - - let derive = get_attr(&derive_input, "no_derive") - .map(|_| false) - .unwrap_or(true); - - let ref_id = get_attr(&derive_input, "ref_name") - .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site())) - .unwrap_or_else(|| { - let id = payload_id.to_string(); - let suffixes = ["Node", "Obj"]; - if let Some(suf) = suffixes - .iter() - .find(|&suf| id.len() > suf.len() && id.ends_with(suf)) - { - Ident::new(&id[..id.len() - suf.len()], payload_id.span()) - } else { - panic!( - "Either 'ref_name' must be given, or the struct name must end one of {:?}", - suffixes - ) - } - }); - - let base_tokens = match &derive_input.data { - syn::Data::Struct(s) => s.fields.iter().next().and_then(|f| { - let (base_id, base_ty) = (f.ident.clone()?, f.ty.clone()); - if base_id == "base" { - // The transitive case of subtyping - Some(quote! { - impl AsRef for #payload_id - where #base_ty: AsRef - { - fn as_ref(&self) -> &O { - self.#base_id.as_ref() - } - } - }) - } else { - None - } - }), - _ => panic!("derive only works for structs"), - }; - - let ref_derives = if derive { - quote! { #[derive(Debug, Clone)]} - } else { - quote! { #[derive(Clone)] } - }; - - let mut expanded = quote! { - unsafe impl #tvm_rt_crate::object::IsObject for #payload_id { - const TYPE_KEY: &'static str = #type_key; - } - - // a silly AsRef impl is necessary for subtyping to work - impl AsRef<#payload_id> for #payload_id { - fn as_ref(&self) -> &Self { - self - } - } - - #ref_derives - pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); - - impl #tvm_rt_crate::object::IsObjectRef for #ref_id { - type Object = #payload_id; - - fn as_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr> { - self.0.as_ref() - } - - fn into_ptr(self) -> Option<#tvm_rt_crate::object::ObjectPtr> { - self.0 - } - - fn from_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr>) -> Self { - #ref_id(object_ptr) - } - } - - impl std::ops::Deref for #ref_id { - type Target = #payload_id; - - fn deref(&self) -> &Self::Target { - self.0.as_ref().unwrap() - } - } - - impl std::convert::From<#payload_id> for #ref_id { - fn from(payload: #payload_id) -> Self { - let ptr = #tvm_rt_crate::object::ObjectPtr::new(payload); - #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr)) - } - } - - impl std::convert::From<#tvm_rt_crate::object::ObjectPtr<#payload_id>> for #ref_id { - fn from(ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id>) -> Self { - #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr)) - } - } - - impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id { - type Error = #error; - - fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> { - use std::convert::TryInto; - let ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id> = ret_val.try_into()?; - Ok(ptr.into()) - } - } - - impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { - use std::ffi::c_void; - let object_ptr = &object_ref.0; - match object_ptr { - None => { - #tvm_rt_crate::ArgValue:: - ObjectHandle(std::ptr::null::() as *mut c_void) - } - Some(value) => value.into() - } - } - } - - impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { - type Error = #error; - - fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> { - use std::convert::TryInto; - let optr = arg_value.try_into()?; - Ok(#ref_id(Some(optr))) - } - } - - - impl From<#ref_id> for #tvm_rt_crate::RetValue { - fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue { - use std::ffi::c_void; - let object_ptr = &object_ref.0; - match object_ptr { - None => { - #tvm_rt_crate::RetValue::ObjectHandle(std::ptr::null::() as *mut c_void) - } - Some(value) => value.clone().into() - } - } - } - }; - - expanded.extend(base_tokens); - - if derive { - let derives = quote! { - impl std::hash::Hash for #ref_id { - fn hash(&self, state: &mut H) { - self.0.hash(state) - } - } - - impl std::cmp::PartialEq for #ref_id { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } - } - - impl std::cmp::Eq for #ref_id {} - }; - - expanded.extend(derives); - } - - TokenStream::from(expanded) -} diff --git a/rust/tvm-macros/src/util.rs b/rust/tvm-macros/src/util.rs deleted file mode 100644 index b02e3f69b671..000000000000 --- a/rust/tvm-macros/src/util.rs +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use proc_macro2::TokenStream; -use quote::quote; -use std::env; - -pub fn get_tvm_rt_crate() -> TokenStream { - if env::var("CARGO_PKG_NAME").unwrap() == "tvm-rt" { - quote!(crate) - } else { - quote!(tvm_rt) - } -} - -pub(crate) fn get_attr<'a>( - derive_input: &'a syn::DeriveInput, - name: &str, -) -> Option<&'a syn::Attribute> { - derive_input.attrs.iter().find(|a| a.path.is_ident(name)) -} - -pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr { - match attr.parse_meta() { - Ok(syn::Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(s), - .. - })) => s, - Ok(m) => panic!("Expected a string literal, got {:?}", m), - Err(e) => panic!("{}", e), - } -} diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore deleted file mode 100644 index 2430329c78b6..000000000000 --- a/rust/tvm-rt/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -target -**/*.rs.bk -Cargo.lock -/tests/basics/add_* -/examples/resnet/deploy_* -/examples/resnet/*.png -/examples/resnet/synset.* diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml deleted file mode 100644 index cb8c560c3efa..000000000000 --- a/rust/tvm-rt/Cargo.toml +++ /dev/null @@ -1,95 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "tvm-rt" -version = "0.1.0-alpha" -license = "Apache-2.0" -description = "Rust bindings for the TVM runtime API." -repository = "https://github.com/apache/tvm" -homepage = "https://github.com/apache/tvm" -readme = "README.md" -keywords = ["rust", "tvm"] -categories = ["api-bindings", "science"] -authors = ["TVM Contributors"] -edition = "2018" - -[features] -default = ["dynamic-linking"] -dynamic-linking = ["tvm-sys/dynamic-linking"] -static-linking = ["tvm-sys/static-linking"] -standalone = ["tvm-sys/runtime-only"] -runtime-only = ["tvm-sys/runtime-only"] -blas = ["ndarray/blas"] -# Enabling any of the following features is like setting the value to "ON" in config.cmake. -use-cuda = ["tvm-sys/use-cuda"] -use-opencl = ["tvm-sys/use-opencl"] -use-vulkan = ["tvm-sys/use-vulkan"] -use-metal = ["tvm-sys/use-metal"] -use-rocm = ["tvm-sys/use-rocm"] -use-hexagon-device = ["tvm-sys/use-hexagon-device"] -use-rpc = ["tvm-sys/use-rpc"] -use-threads = ["tvm-sys/use-threads"] -use-llvm = ["tvm-sys/use-llvm"] -use-stackvm-runtime = ["tvm-sys/use-stackvm-runtime"] -use-openmp = ["tvm-sys/use-openmp"] -use-rtti = ["tvm-sys/use-rtti"] -use-mscv-mt = ["tvm-sys/use-mscv-mt"] -use-install-dev = ["tvm-sys/use-install-dev"] -hide-private-symbols = ["tvm-sys/hide-private-symbols"] -use-fallback-stl-map = ["tvm-sys/use-fallback-stl-map"] -use-index-default-i64 = ["tvm-sys/use-index-default-i64"] -use-tf-tvmdsoop = ["tvm-sys/use-tf-tvmdsoop"] -use-byodt-posit = ["tvm-sys/use-byodt-posit"] -use-mkl = ["tvm-sys/use-mkl"] -use-mkldnn = ["tvm-sys/use-mkldnn"] -use-dnnl-codegen = ["tvm-sys/use-dnnl-codegen"] -use-cudnn = ["tvm-sys/use-cudnn"] -use-cublas = ["tvm-sys/use-cublas"] -use-thrust = ["tvm-sys/use-thrust"] -use-miopen = ["tvm-sys/use-miopen"] -use-rocblas = ["tvm-sys/use-rocblas"] -use-sort = ["tvm-sys/use-sort"] -use-nnpack = ["tvm-sys/use-nnpack"] -use-random = ["tvm-sys/use-random"] -use-cpp-rpc = ["tvm-sys/use-cpp-rpc"] -use-tflite = ["tvm-sys/use-tflite"] -use-coreml = ["tvm-sys/use-coreml"] -use-target-onnx = ["tvm-sys/use-target-onnx"] -use-arm-compute-lib = ["tvm-sys/use-arm-compute-lib"] -use-arm-compute-lib-graph-runtime = ["tvm-sys/use-arm-compute-lib-graph-runtime"] -use-tensorrt-codegen = ["tvm-sys/use-tensorrt-codegen"] -use-tensorrt-runtime = ["tvm-sys/use-tensorrt-runtime"] -build-static-runtime = ["tvm-sys/build-static-runtime"] - -[dependencies] -thiserror = "^1.0" -ndarray = "0.12" -num-traits = "0.2" -tvm-macros = { version = "0.1.1-alpha", path = "../tvm-macros" } -paste = "0.1" -mashup = "0.1" -once_cell = "^1.3.1" -memoffset = "0.5.6" - -[dependencies.tvm-sys] -version = "0.1.1-alpha" -default-features = false -path = "../tvm-sys/" - -[dev-dependencies] -anyhow = "^1.0" diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md deleted file mode 100644 index 58b1f8a30a39..000000000000 --- a/rust/tvm-rt/README.md +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - - - - - - - - - - - - -# TVM Runtime Support - -This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime, -see [here](https://github.com/apache/tvm/blob/main/rust/tvm/README.md) for more details. - -## What Does This Crate Offer? - -TVM is an end-to-end deep learning compiler which takes high level machine learning -models or tensor computations and lowers them into executable code for a variety -of heterogenous devices (e.g., CPU, GPU). - -This crate provides access to the APIs for manipulating runtime data structures, -as well as TVM's cross-language Object system which functions similarly to systems -such as COM, enabling cross-language interoperability. - -## Installations - -Please follow TVM [installation](https://tvm.apache.org/docs/install/index.html) instructions, -`export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. - -### Example of registering a cross-language closure. - -One can use `register!` macro to expose a Rust closure with arguments which implement `TryFrom` -and return types which implement `Into`. Once registered with TVM these functions can be -accessed via Python or C++, or any other language which implements the TVM packed function convention -see the offcial documentation for more information. - -```rust -use tvm_rt::{ArgValue, RetValue}; -use tvm_rt::function::{Function, Result, register}; - -fn sum(x: i64, y: i64, z: i64) -> i64 { - x + y + z -} - -fn main() { - register(sum, "mysum".to_owned()).unwrap(); - let func = Function::get("mysum").unwrap(); - let boxed_fn: Box Result> = func.into(); - let ret = boxed_fn(10, 20, 30).unwrap(); - assert_eq!(ret, 60); -} -``` diff --git a/rust/tvm-rt/src/device.rs b/rust/tvm-rt/src/device.rs deleted file mode 100644 index b1cb58cd54cf..000000000000 --- a/rust/tvm-rt/src/device.rs +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::os::raw::c_void; -use std::ptr; - -use crate::errors::Error; - -use tvm_sys::ffi; - -pub use tvm_sys::device::*; - -trait DeviceExt { - /// Checks whether the device exists or not. - fn exist(&self) -> bool; - fn sync(&self) -> Result<(), Error>; - fn max_threads_per_block(&self) -> isize; - fn warp_size(&self) -> isize; - fn max_shared_memory_per_block(&self) -> isize; - fn compute_version(&self) -> isize; - fn device_name(&self) -> isize; - fn max_clock_rate(&self) -> isize; - fn multi_processor_count(&self) -> isize; - fn max_thread_dimensions(&self) -> isize; -} - -macro_rules! impl_device_attrs { - ($(($attr_name:ident, $attr_kind:expr));+) => { - $( - fn $attr_name(&self) -> isize { - get_device_attr(self.device_type as i32, self.device_id as i32, 0) - .expect("should not fail") as isize - } - - )+ - }; -} - -crate::external! { - #[name("runtime.GetDeviceAttr")] - fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32; -} - -impl DeviceExt for Device { - fn exist(&self) -> bool { - let exists = get_device_attr(self.device_type as i32, self.device_id as i32, 0) - .expect("should not fail"); - - exists != 0 - } - - /// Synchronize the device stream. - fn sync(&self) -> Result<(), Error> { - check_call!(ffi::TVMSynchronize( - self.device_type as i32, - self.device_id as i32, - ptr::null_mut() as *mut c_void - )); - Ok(()) - } - - impl_device_attrs!((max_threads_per_block, 1); - (warp_size, 2); - (max_shared_memory_per_block, 3); - (compute_version, 4); - (device_name, 5); - (max_clock_rate, 6); - (multi_processor_count, 7); - (max_thread_dimensions, 8)); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sync() { - let dev = Device::cpu(0); - assert!(dev.sync().is_ok()) - } -} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs deleted file mode 100644 index 31ce385ef662..000000000000 --- a/rust/tvm-rt/src/errors.rs +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::DataType; -use thiserror::Error; - -#[derive(Debug, Error)] -#[error("Function was not set in `function::Builder`")] -pub struct FunctionNotFoundError; - -#[derive(Debug, Error)] -#[error("Expected type `{expected}` but found `{actual}`")] -pub struct TypeMismatchError { - pub expected: String, - pub actual: String, -} - -#[derive(Debug, Error)] -pub enum NDArrayError { - #[error("Cannot convert from an empty array.")] - EmptyArray, - #[error("Invalid datatype when attempting to convert ndarray.")] - InvalidDatatype(#[from] tvm_sys::datatype::ParseDataTypeError), - #[error("a shape error occurred in the Rust ndarray library")] - ShapeError(#[from] ndarray::ShapeError), - #[error("Expected type `{expected}` but found `{actual}`")] - DataTypeMismatch { - expected: DataType, - actual: DataType, - }, -} - -#[derive(Debug, Error)] -pub enum Error { - #[error("{0}")] - Downcast(#[from] tvm_sys::errors::ValueDowncastError), - #[error("raw pointer passed across boundary was null")] - Null, - #[error("failed to load module due to invalid path {0}")] - ModuleLoadPath(String), - #[error("failed to convert String into CString due to embedded nul character")] - ToCString(#[from] std::ffi::NulError), - #[error("failed to convert CString into String")] - FromCString(#[from] std::ffi::IntoStringError), - #[error("Handle `{0}` is null.")] - NullHandle(String), - #[error("{0}")] - NDArray(#[from] NDArrayError), - #[error("{0}")] - CallFailed(String), - #[error("this case will never occur")] - Infallible(#[from] std::convert::Infallible), - #[error("a panic occurred while executing a Rust packed function")] - Panic, - #[error( - "one or more error diagnostics were emitted, please check diagnostic render for output." - )] - DiagnosticError(String), - #[error("{0}")] - Raw(String), -} - -impl Error { - pub fn from_raw_tvm(raw: &str) -> Error { - let err_header = raw.find(":").unwrap_or(0); - let (err_ty, err_content) = raw.split_at(err_header); - match err_ty { - "DiagnosticError" => Error::DiagnosticError((&err_content[1..]).into()), - _ => Error::Raw(raw.into()), - } - } -} - -impl Error { - pub fn downcast(actual_type: String, expected_type: &'static str) -> Error { - Self::Downcast(tvm_sys::errors::ValueDowncastError { - actual_type, - expected_type, - }) - } -} diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs deleted file mode 100644 index 62474e6650d4..000000000000 --- a/rust/tvm-rt/src/function.rs +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module provides an idiomatic Rust API for creating and working with TVM functions. -//! -//! For calling an already registered TVM function use [`function::Builder`] -//! To register a TVM packed function from Rust side either -//! use [`function::register`] or the macro [`register_global_func`]. -//! -//! See the tests and examples repository for more examples. - -use std::convert::{TryFrom, TryInto}; -use std::sync::Arc; -use std::{ - ffi::CString, - os::raw::{c_char, c_int}, - ptr, str, -}; - -use crate::errors::Error; - -pub use super::to_function::{RawArgs, ToFunction, Typed}; -use crate::object::AsArgValue; -pub use tvm_sys::{ffi, ArgValue, RetValue}; - -pub type Result = std::result::Result; - -#[derive(Debug, Hash)] -struct FunctionPtr { - handle: ffi::TVMFunctionHandle, -} - -// NB(@jroesch): I think this is ok, need to double check, -// if not we should mutex the pointer or move to Rc. -unsafe impl Send for FunctionPtr {} -unsafe impl Sync for FunctionPtr {} - -impl FunctionPtr { - fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { - FunctionPtr { handle } - } -} - -impl Drop for FunctionPtr { - fn drop(&mut self) { - check_call!(ffi::TVMFuncFree(self.handle)); - } -} - -/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust. -#[derive(Debug, Hash)] -pub struct Function { - inner: Arc, -} - -impl Function { - pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { - Function { - inner: Arc::new(FunctionPtr::from_raw(handle)), - } - } - - pub unsafe fn null() -> Self { - Function::from_raw(std::ptr::null_mut()) - } - - /// For a given function, it returns a function by name. - pub fn get>(name: S) -> Option { - let name = CString::new(name.as_ref()).unwrap(); - let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; - - check_call!(ffi::TVMFuncGetGlobal( - name.as_ptr() as *const c_char, - &mut handle as *mut _ - )); - - if handle.is_null() { - None - } else { - Some(Function::from_raw(handle)) - } - } - - pub fn get_boxed(name: S) -> Option> - where - S: AsRef, - F: ?Sized, - Self: Into>, - { - Self::get(name).map(|f| f.into()) - } - - /// Returns the underlying TVM function handle. - pub fn handle(&self) -> ffi::TVMFunctionHandle { - self.inner.handle - } - - /// Calls the function that created from `Builder`. - pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { - let num_args = arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = - arg_buf.into_iter().map(|arg| arg.to_tvm_value()).unzip(); - - let mut ret_val = ffi::TVMValue { v_int64: 0 }; - let mut ret_type_code = 0i32; - - let ret_code = unsafe { - ffi::TVMFuncCall( - self.handle(), - values.as_mut_ptr() as *mut ffi::TVMValue, - type_codes.as_mut_ptr() as *mut c_int, - num_args as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _, - ) - }; - - if ret_code != 0 { - let raw_error = crate::get_last_error(); - let error = match Error::from_raw_tvm(raw_error) { - Error::Raw(string) => Error::CallFailed(string), - e => e, - }; - return Err(error); - } - - let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); - - Ok(rv) - } -} - -macro_rules! impl_to_fn { - () => { impl_to_fn!(@impl); }; - ($t:ident, $($ts:ident,)*) => { impl_to_fn!(@impl $t, $($ts,)*); impl_to_fn!($($ts,)*); }; - (@impl $($t:ident,)*) => { - impl From for Box Result> - where - Error: From, - Out: TryFrom, - $($t: for<'a> AsArgValue<'a>),* - { - fn from(func: Function) -> Self { - #[allow(non_snake_case)] - Box::new(move |$($t : $t),*| { - let args = vec![ $((&$t).as_arg_value()),* ]; - Ok(func.invoke(args)?.try_into()?) - }) - } - } - }; -} - -impl_to_fn!(T1, T2, T3, T4, T5, T6,); - -impl Clone for Function { - fn clone(&self) -> Function { - Function { - inner: self.inner.clone(), - } - } -} - -impl From for RetValue { - fn from(func: Function) -> RetValue { - RetValue::FuncHandle(func.handle()) - } -} - -impl TryFrom for Function { - type Error = Error; - - fn try_from(ret_value: RetValue) -> Result { - match ret_value { - RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), - _ => Err(Error::downcast( - format!("{:?}", ret_value), - "FunctionHandle", - )), - } - } -} - -impl<'a> From<&'a Function> for ArgValue<'a> { - fn from(func: &'a Function) -> ArgValue<'a> { - if func.handle().is_null() { - ArgValue::Null - } else { - ArgValue::FuncHandle(func.handle()) - } - } -} - -impl<'a> TryFrom> for Function { - type Error = Error; - - fn try_from(arg_value: ArgValue<'a>) -> Result { - match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), - _ => Err(Error::downcast( - format!("{:?}", arg_value), - "FunctionHandle", - )), - } - } -} - -impl<'a> TryFrom<&ArgValue<'a>> for Function { - type Error = Error; - - fn try_from(arg_value: &ArgValue<'a>) -> Result { - match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)), - _ => Err(Error::downcast( - format!("{:?}", arg_value), - "FunctionHandle", - )), - } - } -} - -/// Registers a Rust function with an arbitrary type signature in -/// the TVM registry. -/// -/// -/// A function is convertible if and only if its arguments and return types are convertible -/// to and from TVM values respectively. -/// -/// Use [`register_override`] if control of overriding existing global TVM function -/// is required, this function will panic if a function is already registered. -/// -/// ## Example -/// -/// ``` -/// # use tvm_rt::{ArgValue, RetValue}; -/// # use tvm_rt::function::{Function, Result, register}; -/// -/// fn sum(x: i64, y: i64, z: i64) -> i64 { -/// x + y + z -/// } -/// -/// register(sum, "mysum".to_owned()).unwrap(); -/// let func = Function::get("mysum").unwrap(); -/// let boxed_fn: Box Result> = func.into(); -/// let ret = boxed_fn(10, 20, 30).unwrap(); -/// assert_eq!(ret, 60); -/// ``` -pub fn register>(f: F, name: S) -> Result<()> -where - F: ToFunction, - F: Typed, -{ - register_override(f, name, false) -} - -/// Register a function with explicit control over whether to override an existing registration or not. -/// -/// See `register` for more details on how to use the registration API. -pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> -where - F: ToFunction, - F: Typed, -{ - let func = f.to_function(); - let name = name.into(); - // Not sure about this code - let handle = func.handle(); - let name = CString::new(name)?; - check_call!(ffi::TVMFuncRegisterGlobal( - name.into_raw(), - handle, - override_ as c_int - )); - - Ok(()) -} - -pub fn register_untyped>( - f: for<'a> fn(Vec>) -> Result, - name: S, - override_: bool, -) -> Result<()> { - //TODO(@jroesch): can we unify the untpyed and typed registration functions. - let func = ToFunction::::to_function(f); - let name = name.into(); - // Not sure about this code - let handle = func.handle(); - let name = CString::new(name)?; - check_call!(ffi::TVMFuncRegisterGlobal( - name.into_raw(), - handle, - override_ as c_int - )); - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::function::Function; - - static CANARY: &str = "runtime.ModuleLoadFromFile"; - - #[test] - fn get_fn() { - assert!(Function::get(CANARY).is_some()); - assert!(Function::get("does not exists!").is_none()); - } - - #[test] - fn register_and_call_closure0() { - use crate::function; - use function::Result; - - fn constfn() -> i64 { - return 10; - } - - function::register_override(constfn, "constfn".to_owned(), true).unwrap(); - - let func = Function::get_boxed:: Result, _>("constfn").unwrap(); - let ret = func().unwrap(); - assert_eq!(ret, 10); - } - - #[test] - fn register_and_call_closure1() { - use crate::function::{self}; - - fn ident(x: i64) -> i64 { - return x; - } - - function::register_override(ident, "ident".to_owned(), true).unwrap(); - let func = Function::get_boxed:: Result, _>("ident").unwrap(); - assert_eq!(func(60).unwrap(), 60); - } -} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs deleted file mode 100644 index 921117abaee4..000000000000 --- a/rust/tvm-rt/src/lib.rs +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! [TVM](https://github.com/apache/tvm) is a compiler stack for deep learning systems. -//! -//! This crate provides an idiomatic Rust API for TVM runtime. -//! -//! The TVM runtime API contains the data structures used by higher-level TVM executors. -//! Specifically it exposes the basic types such as NDArray, as well as the more general object system. -//! The TVM object system enables cross-language interoperability including that of closures for all -//! supported languages including C++, and Python. - -// Macro to check the return call to TVM runtime shared library. - -#[macro_export] -macro_rules! tvm_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - Err($crate::get_last_error().into()) - } else { - Ok(()) - } - }}; -} - -#[macro_export] -macro_rules! check_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - panic!("{}", $crate::get_last_error()); - } - }}; -} - -// Define all sumodules. -pub mod device; -pub mod errors; -pub mod function; -pub mod module; -pub mod ndarray; -pub mod object; -pub mod string; -mod to_function; - -pub use object::*; -pub use string::*; - -use std::{ - ffi::{CStr, CString}, - str, -}; - -pub use crate::{ - device::{Device, DeviceType}, - errors::*, - function::Function, - module::Module, - ndarray::NDArray, -}; - -pub use function::{ArgValue, RetValue}; -pub use tvm_sys::byte_array::ByteArray; -pub use tvm_sys::datatype::DataType; -use tvm_sys::ffi; - -pub use tvm_macros::external; - -/// Gets the last error message. -pub fn get_last_error() -> &'static str { - unsafe { - match CStr::from_ptr(ffi::TVMGetLastError()).to_str() { - Ok(s) => s, - Err(_) => "Invalid UTF-8 message", - } - } -} - -pub(crate) fn set_last_error(err: &E) { - let c_string = CString::new(err.to_string()).unwrap(); - unsafe { - ffi::TVMAPISetLastError(c_string.as_ptr()); - } -} - -/// Outputs the current TVM version. -pub fn version() -> &'static str { - match str::from_utf8(ffi::TVM_VERSION) { - Ok(s) => s, - Err(_) => "Invalid UTF-8 string", - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ByteArray, DataType, Device}; - use std::{convert::TryInto, str::FromStr}; - - #[test] - fn print_version() { - println!("TVM version: {}", version()); - } - - #[test] - fn set_error() { - let err = errors::NDArrayError::EmptyArray; - set_last_error(&err); - assert_eq!( - get_last_error().trim(), - errors::NDArrayError::EmptyArray.to_string() - ); - } - - // todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. - // #[test] - // fn bytearray() { - // let w = vec![1u8, 2, 3, 4, 5]; - // let v = ByteArray::from(w.as_slice()); - // let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - // assert_eq!( - // tvm.data(), - // w.iter().copied().collect::>().as_slice() - // ); - // } - - #[test] - fn ty() { - let t = DataType::from_str("int32").unwrap(); - let tvm: DataType = RetValue::from(t).try_into().unwrap(); - assert_eq!(tvm, t); - } - - #[test] - fn device() { - let c = Device::from_str("cuda").unwrap(); - let tvm: Device = RetValue::from(c).try_into().unwrap(); - assert_eq!(tvm, c); - } -} diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs deleted file mode 100644 index 754ebf44262e..000000000000 --- a/rust/tvm-rt/src/module.rs +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! Provides the [`Module`] type and methods for working with runtime TVM modules. - -use std::{ - ffi::CString, - os::raw::{c_char, c_int}, - path::Path, - ptr, -}; - -use crate::object::Object; -use tvm_macros::Object; -use tvm_sys::ffi; - -use crate::errors::Error; -use crate::String as TString; -use crate::{errors, function::Function}; - -/// Wrapper around TVM module handle which contains an entry function. -/// The entry function can be applied to an imported module through [`entry_func`]. -/// -/// [`entry_func`]:struct.Module.html#method.entry_func -#[repr(C)] -#[derive(Object, Debug)] -#[ref_name = "Module"] -#[type_key = "runtime.Module"] -pub struct ModuleNode { - base: Object, -} - -crate::external! { - #[name("runtime.RuntimeEnabled")] - fn runtime_enabled(target: CString) -> bool; - - #[name("runtime.ModuleLoadFromFile")] - fn load_from_file(file_name: CString, format: CString) -> Module; - - #[name("runtime.ModuleSaveToFile")] - fn save_to_file(module: Module, name: TString, fmt: TString); - - // TODO(@jroesch): we need to refactor this - #[name("tvm.relax.module_export_library")] - fn export_library(module: Module, file_name: TString); -} - -impl Module { - pub fn default_fn(&mut self) -> Result { - self.get_function("default", true) - } - - /// Gets a function by name from a registered module. - pub fn get_function(&self, name: &str, query_import: bool) -> Result { - let name = CString::new(name)?; - let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; - - check_call!(ffi::TVMModGetFunction( - self.handle(), - name.as_ptr() as *const c_char, - query_import as c_int, - &mut fhandle as *mut _ - )); - - if fhandle.is_null() { - return Err(errors::Error::NullHandle(name.into_string()?.to_string())); - } - - Ok(Function::from_raw(fhandle)) - } - - /// Imports a dependent module such as `.ptx` for cuda gpu. - pub fn import_module(&self, dependent_module: Module) { - check_call!(ffi::TVMModImport(self.handle(), dependent_module.handle())) - } - - /// Loads a module shared library from path. - pub fn load>(path: &P) -> Result { - let ext = CString::new( - path.as_ref() - .extension() - .unwrap_or_else(|| std::ffi::OsStr::new("")) - .to_str() - .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, - )?; - - let cpath = CString::new( - path.as_ref() - .to_str() - .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, - )?; - - let module = load_from_file(cpath, ext)?; - Ok(module) - } - - pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> { - save_to_file(self.clone(), name.into(), fmt.into()) - } - - pub fn export_library(&self, name: String) -> Result<(), Error> { - export_library(self.clone(), name.into()) - } - - /// Checks if a target device is enabled for a module. - pub fn enabled(&self, target: &str) -> bool { - let target = CString::new(target).unwrap(); - runtime_enabled(target).unwrap() - } - - /// Returns the underlying module handle. - pub unsafe fn handle(&self) -> ffi::TVMModuleHandle { - self.0.clone().unwrap().into_raw() as *mut _ - } -} diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs deleted file mode 100644 index dd3882a098e2..000000000000 --- a/rust/tvm-rt/src/ndarray.rs +++ /dev/null @@ -1,515 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module implements the [`NDArray`] type for working with *TVM tensors* or -//! coverting from a Rust's ndarray to TVM `NDArray`. -//! -//! One can create an empty NDArray given the shape, device and dtype using [`empty`]. -//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. -//! To copy an NDArray to different device use [`copy_to_device`]. -//! -//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: -//! -//! # Example -//! -//! ``` -//! # use tvm_rt::{NDArray, DataType, Device}; -//! # use ndarray::{Array, ArrayD}; -//! # use std::str::FromStr; -//! use std::convert::TryFrom; -//! -//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) -//! .unwrap() -//! .into_dyn(); // Rust's ndarray -//! let nd = NDArray::from_rust_ndarray(&a, Device::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); -//! assert_eq!(nd.shape(), &[2, 2]); -//! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); -//! assert!(rnd.all_close(&a, 1e-8f32)); -//! ``` -//! -//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ -//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer -//! [`copy_to_device`]:struct.NDArray.html#method.copy_to_device - -use std::ffi::c_void; -use std::{borrow::Cow, convert::TryInto}; -use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; - -use mem::size_of; -use tvm_macros::Object; -use tvm_sys::ffi::DLTensor; -use tvm_sys::{ffi, ByteArray, DataType, Device}; - -use ndarray::{Array, ArrayD}; -use num_traits::Num; - -use crate::errors::NDArrayError; - -use crate::object::{Object, ObjectPtr, ObjectRef}; - -/// See the [`module-level documentation`](../ndarray/index.html) for more details. -#[repr(C)] -#[derive(Object, Debug)] -#[ref_name = "NDArray"] -#[type_key = "runtime.NDArray"] -pub struct NDArrayContainer { - base: Object, - // Container Base - dl_tensor: DLTensor, - manager_ctx: *mut c_void, - shape: ObjectRef, -} - -impl NDArrayContainer { - pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Option> { - let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - let base_ptr = unsafe { (handle as *mut i8).offset(-base_offset) }; - let object_ptr = ObjectPtr::from_raw(base_ptr.cast()); - object_ptr.map(|ptr| { - ptr.downcast::() - .expect("we know this is an NDArray container") - }) - } - - pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut NDArrayContainer - where - NDArrayContainer: 'a, - { - let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - unsafe { - &mut *std::mem::ManuallyDrop::new(object_ptr) - .ptr - .as_ptr() - .cast::() - .offset(base_offset) - .cast::() - } - } - - pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer - where - NDArrayContainer: 'a, - { - let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - unsafe { - object_ptr - .ptr - .as_ptr() - .cast::() - .offset(base_offset) - .cast::() - } - } -} - -fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { - if std::mem::size_of::() == 64 { - debug_assert!(slice.iter().all(|&x| x >= 0)); - let shape: &[usize] = unsafe { std::mem::transmute(slice) }; - Cow::Borrowed(shape) - } else { - let shape: Vec = slice - .iter() - .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot fit into usize: {}", x))) - .collect(); - Cow::Owned(shape) - } -} - -impl NDArray { - pub(crate) fn _from_raw(handle: ffi::TVMArrayHandle) -> Self { - let ptr = NDArrayContainer::from_raw(handle); - NDArray(ptr) - } - - // I think these should be marked as unsafe functions? projecting a reference is bad news. - pub fn as_dltensor(&self) -> &DLTensor { - &self.dl_tensor - } - - pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - unsafe { std::mem::transmute(self.as_dltensor()) } - } - - pub fn is_view(&self) -> bool { - false - } - - /// Returns the shape of the NDArray. - pub fn shape(&self) -> &[i64] { - let arr = self.as_dltensor(); - if arr.shape.is_null() || arr.data.is_null() { - &[] - } else { - unsafe { slice::from_raw_parts(arr.shape, self.ndim()) } - } - } - - /// Returns the shape of the NDArray as a &[usize] - /// - /// On 64-bit platforms, this is zero-cost and uses the shape from the DLTensor. - /// On other platforms, this copies into a buffer. - pub fn shape_usize(&self) -> Cow<[usize]> { - cow_usize(self.shape()) - } - - /// Returns the strides of the underlying NDArray. - pub fn strides(&self) -> Option<&[i64]> { - let arr = self.as_dltensor(); - if arr.strides.is_null() { - None - } else { - Some(unsafe { slice::from_raw_parts(arr.strides, self.ndim()) }) - } - } - - /// Returns the strides of the NDArray as a &[usize] - /// - /// On 64-bit platforms, this is zero-cost and uses the strides from the DLTensor. - /// On other platforms, this copies into a buffer. - pub fn strides_usize(&self) -> Option> { - self.strides().map(cow_usize) - } - - /// Returns true if the tensor is empty - pub fn is_empty(&self) -> bool { - self.as_dltensor().data.is_null() - } - - /// Returns the total number of entries of the NDArray. - pub fn len(&self) -> usize { - let len: i64 = self.shape().iter().product(); - usize::try_from(len).unwrap_or_else(|_| panic!("bad len: {}", len)) - } - - /// Returns the total bytes taken up by the data. - /// This is equal to `nd.len() * nd.dtype().itemsize` - pub fn size(&self) -> usize { - self.len() * self.dtype().itemsize - } - - /// Returns the device which the NDArray was defined. - pub fn device(&self) -> Device { - self.as_dltensor().device.into() - } - - /// Returns the type of the entries of the NDArray. - pub fn dtype(&self) -> DataType { - self.as_dltensor().dtype.into() - } - - /// Returns the number of dimensions of the NDArray. - pub fn ndim(&self) -> usize { - self.as_dltensor() - .ndim - .try_into() - .expect("number of dimensions must always be positive") - } - - /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> bool { - match self.strides() { - None => true, - Some(strides) => { - // NDArrayError::MissingShape in case shape is not determined - self.shape() - .iter() - .zip(strides) - .rfold( - (true, 1), - |(is_contig, expected_stride), (shape, stride)| { - ( - is_contig && *stride == expected_stride, - expected_stride * shape, - ) - }, - ) - .0 - } - } - } - - pub fn byte_offset(&self) -> isize { - self.as_dltensor().byte_offset as isize - } - - /// Flattens the NDArray to a `Vec` of the same type in cpu. - /// - /// ## Example - /// - /// ``` - /// # use tvm_rt::{Device, DataType, NDArray}; - /// # use std::str::FromStr; - /// let mut shape = [4]; - /// let mut data = vec![1i32, 2, 3, 4]; - /// let dev = Device::cpu(0); - /// let mut ndarray = NDArray::empty(&mut shape, dev, DataType::from_str("int32").unwrap()); - /// ndarray.copy_from_buffer(&mut data); - /// assert_eq!(ndarray.shape(), shape); - /// assert_eq!(ndarray.to_vec::().unwrap(), data); - /// ``` - pub fn to_vec(&self) -> Result, NDArrayError> { - let n = self.size() / size_of::(); - let mut vec: Vec = Vec::with_capacity(n); - - let ptr = vec.as_mut_ptr(); - let slice = unsafe { slice::from_raw_parts_mut(ptr, n) }; - self.copy_to_buffer(slice); - - unsafe { vec.set_len(n) }; - Ok(vec) - } - - /// Converts the NDArray to [`ByteArray`]. - pub fn to_bytearray(&self) -> Result { - let v = self.to_vec::()?; - Ok(ByteArray::from(v)) - } - - /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. - /// - /// ## Example - /// - /// ``` - /// # use tvm_rt::{Device, DataType, NDArray}; - /// # use std::str::FromStr; - /// let shape = &mut [2]; - /// let mut data = vec![1f32, 2.0]; - /// let dev = Device::cpu(0); - /// let mut ndarray = NDArray::empty(shape, dev, DataType::from_str("int32").unwrap()); - /// ndarray.copy_from_buffer(&mut data); - /// ``` - /// - /// *Note*: if something goes wrong during the copy, it will panic - /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. - pub fn copy_from_buffer(&mut self, data: &[T]) { - check_call!(ffi::TVMArrayCopyFromBytes( - self.as_raw_dltensor(), - data.as_ptr() as *mut _, - (data.len() * mem::size_of::()) as _, - )); - } - - pub fn copy_to_buffer(&self, data: &mut [T]) { - assert_eq!(self.size(), data.len() * size_of::()); - check_call!(ffi::TVMArrayCopyToBytes( - self.as_raw_dltensor(), - data.as_ptr() as *mut _, - self.size() as _, - )); - } - - pub fn fill_from_iter(&mut self, iter: I) - where - T: Num32, - I: ExactSizeIterator, - { - assert!(self.is_contiguous()); - assert_eq!(self.size(), size_of::() * iter.len()); - let mut ptr: *mut T = self.as_dltensor().data.cast(); - iter.for_each(|x| unsafe { - ptr.write(x); - ptr = ptr.add(1); - }) - } - - /// Copies the NDArray to another target NDArray. - pub fn copy_to_ndarray(&self, target: NDArray) -> Result { - if self.dtype() != target.dtype() { - return Err(NDArrayError::DataTypeMismatch { - expected: self.dtype(), - actual: target.dtype(), - }); - } - - check_call!(ffi::TVMArrayCopyFromTo( - self.as_raw_dltensor(), - target.as_raw_dltensor(), - ptr::null_mut() as ffi::TVMStreamHandle - )); - - Ok(target) - } - - /// Copies the NDArray to a target device. - pub fn copy_to_device(&self, target: &Device) -> Result { - let tmp = NDArray::empty(self.shape(), *target, self.dtype()); - let copy = self.copy_to_ndarray(tmp)?; - Ok(copy) - } - - /// Converts a Rust's ndarray to TVM NDArray. - pub fn from_rust_ndarray( - input_nd: &ArrayD, - dev: Device, - dtype: DataType, - ) -> Result { - let shape: Vec = input_nd.shape().iter().map(|&x| x as i64).collect(); - let mut nd = NDArray::empty(&shape, dev, dtype); - nd.fill_from_iter(input_nd.iter().copied()); - Ok(nd) - } - - /// Allocates and creates an empty NDArray given the shape, device and dtype. - pub fn empty(shape: &[i64], dev: Device, dtype: DataType) -> NDArray { - let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; - let dtype: tvm_sys::ffi::DLDataType = dtype.into(); - check_call!(ffi::TVMArrayAlloc( - shape.as_ptr(), - shape.len() as c_int, - i32::from(dtype.code) as c_int, - i32::from(dtype.bits) as c_int, - i32::from(dtype.lanes) as c_int, - dev.device_type as c_int, - dev.device_id as c_int, - &mut handle as *mut _, - )); - let ptr = NDArrayContainer::from_raw(handle) - .map(|o| o.downcast().expect("this should never fail")); - NDArray(ptr) - } - - pub fn zeroed(self) -> NDArray { - unsafe { - let dltensor = self.as_raw_dltensor(); - let bytes_ptr: *mut u8 = std::mem::transmute((*dltensor).data); - println!("size {}", self.size()); - std::ptr::write_bytes(bytes_ptr, 0, self.size()); - self - } - } -} - -macro_rules! impl_from_ndarray_rustndarray { - ($type:ty, $type_name:tt) => { - impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { - type Error = NDArrayError; - - fn try_from(nd: &NDArray) -> Result, Self::Error> { - assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape_usize(), - nd.to_vec::<$type>()?, - )?) - } - } - - impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { - type Error = NDArrayError; - - fn try_from(nd: &mut NDArray) -> Result, Self::Error> { - assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape_usize(), - nd.to_vec::<$type>()?, - )?) - } - } - }; -} - -impl_from_ndarray_rustndarray!(i32, "int"); -impl_from_ndarray_rustndarray!(u32, "uint"); -impl_from_ndarray_rustndarray!(f32, "float"); - -mod sealed { - /// Private trait to prevent other traits from being implemeneted in downstream crates. - pub trait Sealed {} -} - -/// A trait for the supported 32-bits numerical types in frontend. -pub trait Num32: Num + sealed::Sealed { - const BITS: u8 = 32; -} - -macro_rules! impl_num32 { - ($($type:ty),+) => { - $( - impl sealed::Sealed for $type {} - impl Num32 for $type {} - )+ - }; -} - -impl_num32!(i32, u32, f32); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn basics() { - let shape = &[1, 2, 3]; - let dev = Device::cpu(0); - println!("before empty"); - let ndarray = NDArray::empty(shape, dev, DataType::from_str("int32").unwrap()); - println!("after empty"); - assert_eq!(ndarray.shape(), shape); - assert_eq!(ndarray.len(), shape.iter().product::() as usize); - assert_eq!(ndarray.ndim(), 3); - assert!(ndarray.strides().is_none()); - assert_eq!(ndarray.byte_offset(), 0); - } - - #[test] - fn copy() { - let shape = &[4]; - let data = vec![1i32, 2, 3, 4]; - let dev = Device::cpu(0); - let mut ndarray = NDArray::empty(shape, dev, DataType::int(32, 1)).zeroed(); - assert_eq!(ndarray.to_vec::().unwrap(), vec![0, 0, 0, 0]); - ndarray.copy_from_buffer(&data); - assert_eq!(ndarray.shape(), shape); - assert_eq!(ndarray.to_vec::().unwrap(), data); - assert_eq!(ndarray.ndim(), 1); - assert!(ndarray.is_contiguous()); - assert_eq!(ndarray.byte_offset(), 0); - let shape = vec![4]; - let e = NDArray::empty(&shape, Device::cpu(0), DataType::from_str("int32").unwrap()); - let nd = ndarray.copy_to_ndarray(e); - assert!(nd.is_ok()); - assert_eq!(nd.unwrap().to_vec::().unwrap(), data); - } - - /// This occasionally panics on macOS: https://github.com/rust-lang/rust/issues/71397 - #[test] - #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - fn copy_wrong_dtype() { - let shape = vec![4]; - let mut data = vec![1f32, 2., 3., 4.]; - let dev = Device::cpu(0); - let mut nd_float = NDArray::empty(&shape, dev, DataType::from_str("float32").unwrap()); - nd_float.copy_from_buffer(&mut data); - let empty_int = NDArray::empty(&shape, dev, DataType::from_str("int32").unwrap()); - nd_float.copy_to_ndarray(empty_int).unwrap(); - } - - #[test] - fn rust_ndarray() { - let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) - .unwrap() - .into_dyn(); - let nd = - NDArray::from_rust_ndarray(&a, Device::cpu(0), DataType::from_str("float32").unwrap()) - .unwrap(); - assert_eq!(nd.shape(), &[2, 2]); - let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); - assert!(rnd.all_close(&a, 1e-8f32)); - } -} diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs deleted file mode 100644 index f5832fcb3ab8..000000000000 --- a/rust/tvm-rt/src/object/mod.rs +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::convert::TryFrom; -use std::ffi::CString; - -use crate::errors::Error; -use crate::external; - -use tvm_sys::{ArgValue, RetValue}; - -mod object_ptr; - -pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; - -pub trait AsArgValue<'a> { - fn as_arg_value(&'a self) -> ArgValue<'a>; -} - -impl<'a, T: 'static> AsArgValue<'a> for T -where - &'a T: Into>, -{ - fn as_arg_value(&'a self) -> ArgValue<'a> { - self.into() - } -} - -// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we -// can't because of coherence rules. Instead, we generate them in the macro, and -// add what we can (including Into instead of From) as subtraits. -// We also add named conversions for clarity -pub trait IsObjectRef: - Sized - + Clone - + Into - + for<'a> AsArgValue<'a> - + TryFrom - + for<'a> TryFrom, Error = Error> - + std::fmt::Debug -{ - type Object: IsObject; - fn as_ptr(&self) -> Option<&ObjectPtr>; - fn into_ptr(self) -> Option>; - fn from_ptr(object_ptr: Option>) -> Self; - - fn null() -> Self { - Self::from_ptr(None) - } - - fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { - self.as_arg_value() - } - - fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { - Self::try_from(arg_value) - } - - fn into_ret_value<'a>(self) -> RetValue { - self.into() - } - - fn from_ret_value<'a>(ret_value: RetValue) -> Result { - Self::try_from(ret_value) - } - - fn upcast(self) -> U - where - U: IsObjectRef, - Self::Object: AsRef, - { - let ptr = self.into_ptr().map(ObjectPtr::upcast); - U::from_ptr(ptr) - } - - fn downcast(self) -> Result - where - U: IsObjectRef, - U::Object: AsRef, - { - let ptr = self.into_ptr().map(ObjectPtr::downcast); - let ptr = ptr.transpose()?; - Ok(U::from_ptr(ptr)) - } -} - -external! { - #[name("ir.DebugPrint")] - pub fn debug_print(object: ObjectRef) -> CString; - #[name("node.StructuralHash")] - fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64; - #[name("node.StructuralEqual")] - fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool; -} diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs deleted file mode 100644 index 09d6068f1a88..000000000000 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ /dev/null @@ -1,555 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::convert::TryFrom; -use std::ffi::CString; -use std::fmt; -use std::os::raw::c_char; -use std::ptr::NonNull; -use std::sync::atomic::AtomicI32; - -use tvm_macros::Object; -use tvm_sys::ffi::{ - self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeIndex2Key, TVMObjectTypeKey2Index, -}; -use tvm_sys::{ArgValue, RetValue}; - -use crate::errors::Error; - -type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); - -/// A TVM intrusive smart pointer header, in TVM all FFI compatible types -/// start with an Object as their first field. The base object tracks -/// a type_index which is an index into the runtime type information -/// table, an atomic reference count, and a customized deleter which -/// will be invoked when the reference count is zero. -/// -#[derive(Debug, Object)] -#[ref_name = "ObjectRef"] -#[type_key = "runtime.Object"] -#[repr(C)] -pub struct Object { - /// The index into TVM's runtime type information table. - pub(self) type_index: u32, - // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure. - // NB: in general we should not touch this in Rust. - /// The reference count of the smart pointer. - pub(self) ref_count: AtomicI32, - /// The deleter function which is used to deallocate the underlying data - /// when the reference count is zero. This field must always be set for - /// all objects. - /// - /// The common use case is ensuring that the allocator which allocated the - /// data is also the one that deletes it. - pub(self) fdeleter: Deleter, -} - -/// The default deleter for objects allocated in Rust, we use a bit of -/// trait magic here to get a monomorphized deleter for each object -/// "subtype". -/// -/// This function just converts the pointer to the correct type -/// and reconstructs a Box which then is dropped to deallocate -/// the underlying allocation. -unsafe extern "C" fn delete(object: *mut Object) { - let typed_object: *mut T = object as *mut T; - let boxed: Box = Box::from_raw(typed_object); - drop(boxed); -} - -fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { - let mut is_derived = 0; - crate::check_call!(ffi::TVMObjectDerivedFrom( - child_type_index, - parent_type_index, - &mut is_derived - )); - - if is_derived == 0 { - false - } else { - true - } -} - -impl Object { - fn new(type_index: u32, deleter: Deleter) -> Object { - Object { - type_index, - // NB(@jroesch): I believe it is sound to use Rust atomics - // in conjunction with C++ atomics given the memory model - // is nearly identical. - // - // Of course these are famous last words which I may later - // regret. - ref_count: AtomicI32::new(0), - fdeleter: deleter, - } - } - - fn get_type_key(&self) -> String { - let mut cstring: *mut c_char = std::ptr::null_mut(); - unsafe { - if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as *mut _) != 0 { - panic!("{}", crate::get_last_error()); - } - return CString::from_raw(cstring) - .into_string() - .expect("type keys should be valid utf-8"); - } - } - - fn get_type_index() -> u32 { - let type_key = T::TYPE_KEY; - let cstring = CString::new(type_key).expect("type key must not contain null characters"); - - // TODO(@jroesch): look into TVMObjectTypeKey2Index. - if type_key == "runtime.Object" { - return 0; - } else { - let mut index = 0; - unsafe { - if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 { - panic!("{}", crate::get_last_error()) - } - } - return index; - } - } - - pub fn count(&self) -> i32 { - // need to do atomic read in C++ - // ABI compatible atomics is funky/hard. - self.ref_count.load(std::sync::atomic::Ordering::Relaxed) - } - - /// Allocates a base object value for an object subtype of type T. - /// By using associated constants and generics we can provide a - /// type indexed abstraction over allocating objects with the - /// correct index and deleter. - pub fn base() -> Object { - let index = Object::get_type_index::(); - Object::new(index, delete::) - } - - /// Increases the object's reference count by one. - pub(self) fn inc_ref(&self) { - let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void; - unsafe { - assert_eq!(TVMObjectRetain(raw_ptr), 0); - } - } - - /// Decreases the object's reference count by one. - pub(self) fn dec_ref(&self) { - let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void; - unsafe { - assert_eq!(TVMObjectFree(raw_ptr), 0); - } - } -} - -/// An unsafe trait which should be implemented for an object -/// subtype. -/// -/// The trait contains the type key needed to compute the type -/// index, a method for accessing the base object given the -/// subtype, and a typed delete method which is specialized -/// to the subtype. -pub unsafe trait IsObject: AsRef + std::fmt::Debug { - const TYPE_KEY: &'static str; -} - -/// A smart pointer for types which implement IsObject. -/// This type directly corresponds to TVM's C++ type ObjectPtr. -/// -/// See object.h for more details. -#[repr(C)] -pub struct ObjectPtr { - pub ptr: NonNull, -} - -impl ObjectPtr { - pub fn from_raw(object_ptr: *mut Object) -> Option> { - let non_null = NonNull::new(object_ptr); - non_null.map(|ptr| { - debug_assert!(unsafe { ptr.as_ref().count() } >= 0); - ObjectPtr { ptr } - }) - } -} - -impl Clone for ObjectPtr { - fn clone(&self) -> Self { - unsafe { self.ptr.as_ref().as_ref().inc_ref() } - ObjectPtr { ptr: self.ptr } - } -} - -impl Drop for ObjectPtr { - fn drop(&mut self) { - unsafe { self.ptr.as_ref().as_ref().dec_ref() } - } -} - -impl ObjectPtr { - pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut T - where - T: 'a, - { - unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr() } - } - - pub fn new(object: T) -> ObjectPtr { - object.as_ref().inc_ref(); - let object_ptr = Box::new(object); - let object_ptr = Box::leak(object_ptr); - let ptr = NonNull::from(object_ptr); - ObjectPtr { ptr } - } - - pub fn count(&self) -> i32 { - // need to do atomic read in C++ - // ABI compatible atomics is funky/hard. - self.as_ref() - .ref_count - .load(std::sync::atomic::Ordering::Relaxed) - } - - /// This method avoid running the destructor on self once it's dropped, so we don't accidentally release the memory - unsafe fn cast(self) -> ObjectPtr { - let ptr = self.ptr.cast(); - std::mem::forget(self); - ObjectPtr { ptr } - } - - pub fn upcast(self) -> ObjectPtr - where - U: IsObject, - T: AsRef, - { - unsafe { self.cast() } - } - - pub fn downcast(self) -> Result, Error> - where - U: IsObject + AsRef, - { - let child_index = Object::get_type_index::(); - let object_index = self.as_ref().type_index; - - let is_derived = if child_index == object_index { - true - } else { - // TODO(@jroesch): write tests - derived_from(object_index, child_index) - }; - - if is_derived { - Ok(unsafe { self.cast() }) - } else { - let type_key = self.as_ref().get_type_key(); - Err(Error::downcast(type_key.into(), U::TYPE_KEY)) - } - } - - pub unsafe fn into_raw(self) -> *mut T { - self.ptr.as_ptr() - } - - pub unsafe fn as_ptr(&self) -> *mut T { - self.ptr.as_ptr() - } -} - -impl std::ops::Deref for ObjectPtr { - type Target = T; - - fn deref(&self) -> &Self::Target { - unsafe { self.ptr.as_ref() } - } -} - -impl fmt::Debug for ObjectPtr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use std::ops::Deref; - write!(f, "{:?}", self.deref()) - } -} - -impl<'a, T: IsObject> From> for RetValue { - fn from(object_ptr: ObjectPtr) -> RetValue { - let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; - assert!(!raw_object_ptr.is_null()); - RetValue::ObjectHandle(raw_object_ptr) - } -} - -impl<'a, T: IsObject> TryFrom for ObjectPtr { - type Error = Error; - - fn try_from(ret_value: RetValue) -> Result, Self::Error> { - use crate::ffi::DLTensor; - use crate::ndarray::NDArrayContainer; - - match ret_value { - RetValue::ObjectHandle(handle) | RetValue::ModuleHandle(handle) => { - let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.downcast() - } - RetValue::NDArrayHandle(handle) => { - let optr: ObjectPtr = - NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.upcast::().downcast() - } - _ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)), - } - } -} - -impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { - fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { - debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.clone().upcast::(); - match T::TYPE_KEY { - "runtime.NDArray" => { - use crate::ndarray::NDArrayContainer; - let dcast_ptr = object_ptr.downcast().unwrap(); - let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); - ArgValue::NDArrayHandle(raw_ptr) - } - "runtime.Module" => { - let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); - ArgValue::ModuleHandle(raw_ptr) - } - _ => { - let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); - ArgValue::ObjectHandle(raw_ptr) - } - } - } -} - -impl<'a, T: IsObject> TryFrom> for ObjectPtr { - type Error = Error; - - fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { - use crate::ffi::DLTensor; - use crate::ndarray::NDArrayContainer; - - match arg_value { - ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { - let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - optr.inc_ref(); - // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must - // bump the reference count by one. - assert!(optr.count() >= 1); - optr.downcast() - } - ArgValue::NDArrayHandle(handle) => { - let optr = - NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must - // bump the reference count by one. - assert!(optr.count() >= 1); - // TODO(@jroesch): figure out if there is a more optimal way to do this - let object = optr.upcast::(); - object.inc_ref(); - object.downcast() - } - _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), - } - } -} - -impl std::hash::Hash for ObjectPtr { - fn hash(&self, state: &mut H) { - state.write_i64( - super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap(), - ) - } -} - -impl PartialEq for ObjectPtr { - fn eq(&self, other: &Self) -> bool { - let lhs = ObjectRef(Some(self.clone().upcast())); - let rhs = ObjectRef(Some(other.clone().upcast())); - super::structural_equal(lhs, rhs, false, false).unwrap() - } -} - -impl Eq for ObjectPtr {} - -#[cfg(test)] -mod tests { - use super::{Object, ObjectPtr}; - use anyhow::{ensure, Result}; - use std::convert::TryInto; - use tvm_sys::{ArgValue, RetValue}; - - #[test] - fn test_new_object() -> anyhow::Result<()> { - let object = Object::base::(); - let ptr = ObjectPtr::new(object); - assert_eq!(ptr.count(), 1); - Ok(()) - } - - #[test] - fn test_leak() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let object = ObjectPtr::leak(ptr); - assert_eq!(object.count(), 1); - Ok(()) - } - - #[test] - fn test_clone() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let ptr2 = ptr.clone(); - assert_eq!(ptr2.count(), 2); - drop(ptr); - assert_eq!(ptr2.count(), 1); - Ok(()) - } - - #[test] - fn roundtrip_retvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let ret_value: RetValue = ptr.clone().into(); - let ptr2: ObjectPtr = ret_value.try_into()?; - assert_eq!(ptr.count(), ptr2.count()); - assert_eq!(ptr.count(), 2); - ensure!( - ptr.type_index == ptr2.type_index, - "type indices do not match" - ); - ensure!( - ptr.fdeleter == ptr2.fdeleter, - "objects have different deleters" - ); - // After dropping the second pointer we should only see only refcount. - drop(ptr2); - assert_eq!(ptr.count(), 1); - Ok(()) - } - - #[test] - fn roundtrip_argvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let ptr_clone = ptr.clone(); - assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = (&ptr_clone).into(); - assert_eq!(ptr.count(), 2); - let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 3); - assert_eq!(ptr.count(), ptr2.count()); - drop(ptr_clone); - assert_eq!(ptr.count(), 2); - ensure!( - ptr.type_index == ptr2.type_index, - "type indices do not match" - ); - ensure!( - ptr.fdeleter == ptr2.fdeleter, - "objects have different deleters" - ); - // After dropping the second pointer we should only see only refcount. - drop(ptr2); - assert_eq!(ptr.count(), 1); - Ok(()) - } - - fn test_fn_raw<'a>( - mut args: crate::to_function::ArgList<'a>, - ) -> crate::function::Result { - let v: ArgValue = args.remove(0); - let v2: ArgValue = args.remove(0); - // assert_eq!(o.count(), 2); - let o: ObjectPtr = v.try_into().unwrap(); - assert_eq!(o.count(), 2); - let o2: ObjectPtr = v2.try_into().unwrap(); - assert_eq!(o2.count(), 3); - drop(o2); - assert_eq!(o.count(), 2); - Ok(o.into()) - } - - #[test] - fn test_ref_count_raw_fn() { - use super::*; - use crate::function::{register_untyped, Function}; - let ptr = ObjectPtr::new(Object::base::()); - // Call the function without the wrapping for TVM. - assert_eq!(ptr.count(), 1); - let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let output: ObjectPtr = same.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - - register_untyped(test_fn_raw, "test_fn_raw", true).unwrap(); - let raw_func = Function::get("test_fn_raw").unwrap(); - let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let output: ObjectPtr = output.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - } - - fn test_fn_typed(o: ObjectPtr, o2: ObjectPtr) -> ObjectPtr { - assert_eq!(o.count(), 3); - assert_eq!(o2.count(), 3); - drop(o2); - assert_eq!(o.count(), 2); - return o; - } - - #[test] - fn test_ref_count_typed() { - use super::*; - use crate::function::{register, Function}; - let ptr = ObjectPtr::new(Object::base::()); - // Call the function without the wrapping for TVM. - assert_eq!(ptr.count(), 1); - let output = test_fn_typed(ptr.clone(), ptr.clone()); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - - register(test_fn_typed, "test_fn_typed").unwrap(); - let typed_func = Function::get("test_fn_typed").unwrap(); - let output = typed_func - .invoke(vec![(&ptr).into(), (&ptr).into()]) - .unwrap(); - let output: ObjectPtr = output.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - } -} diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs deleted file mode 100644 index e61afaf7399b..000000000000 --- a/rust/tvm-rt/src/string.rs +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::cmp::{Ordering, PartialEq}; -use std::hash::{Hash, Hasher}; - -use super::Object; - -use tvm_macros::Object; - -#[repr(C)] -#[derive(Object, Debug)] -#[ref_name = "String"] -#[type_key = "runtime.String"] -#[no_derive] -pub struct StringObj { - base: Object, - data: *const u8, - size: u64, -} - -impl From for String { - fn from(s: std::string::String) -> Self { - let size = s.len() as u64; - let data = Box::into_raw(s.into_boxed_str()).cast(); - let base = Object::base::(); - StringObj { base, data, size }.into() - } -} - -impl From<&'static str> for String { - fn from(s: &'static str) -> Self { - let size = s.len() as u64; - let data = s.as_bytes().as_ptr(); - let base = Object::base::(); - StringObj { base, data, size }.into() - } -} - -impl AsRef<[u8]> for String { - fn as_ref(&self) -> &[u8] { - self.as_bytes() - } -} - -impl std::fmt::Display for String { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.to_string_lossy().fmt(f) - } -} - -impl String { - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn len(&self) -> usize { - self.size as usize - } - - pub fn as_bytes(&self) -> &[u8] { - unsafe { std::slice::from_raw_parts(self.data, self.len()) } - } - - pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> { - std::str::from_utf8(self.as_bytes()) - } - - pub fn to_string_lossy(&self) -> std::borrow::Cow { - std::string::String::from_utf8_lossy(self.as_bytes()) - } -} - -impl> PartialEq for String { - fn eq(&self, other: &T) -> bool { - self.as_bytes() == other.as_ref() - } -} - -impl> PartialOrd for String { - fn partial_cmp(&self, other: &T) -> Option { - self.as_bytes().partial_cmp(other.as_ref()) - } -} - -impl Eq for String {} - -impl Ord for String { - fn cmp(&self, other: &Self) -> Ordering { - self.as_bytes().cmp(other.as_bytes()) - } -} - -impl Hash for String { - fn hash(&self, state: &mut H) { - self.as_bytes().hash(state); - } -} - -impl std::fmt::Debug for String { - fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_fmt(format_args!("{:?}", self.to_string_lossy())) - } -} - -#[cfg(test)] -mod tests { - use super::String; - use crate::object::debug_print; - use crate::IsObjectRef; - use anyhow::{ensure, Result}; - - #[test] - fn test_string_debug() -> Result<()> { - let s = String::from("foo"); - let object_ref = s.upcast(); - println!("about to call"); - let string = debug_print(object_ref)?; - println!("after call"); - ensure!( - string.into_string().expect("is cstring").contains("foo"), - "string content is invalid" - ); - Ok(()) - } -} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs deleted file mode 100644 index 67fbfc996af0..000000000000 --- a/rust/tvm-rt/src/to_function.rs +++ /dev/null @@ -1,337 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module provides an idiomatic Rust API for creating and working with TVM functions. -//! -//! For calling an already registered TVM function use [`function::Builder`] -//! To register a TVM packed function from Rust side either -//! use [`function::register`] or the macro [`register_global_func`]. -//! -//! See the tests and examples repository for more examples. - -use std::convert::{TryFrom, TryInto}; -use std::{ - os::raw::{c_int, c_void}, - ptr, slice, -}; - -use super::{function::Result, Function}; -use crate::errors::Error; - -pub use tvm_sys::{ffi, ArgValue, RetValue}; - -/// A trait representing whether the function arguments -/// and return type can be assigned to a TVM packed function. -/// -/// By splitting the conversion to function into two traits -/// we are able to improve error reporting, by splitting the -/// conversion of inputs and outputs to this trait. -/// -/// And the implementation of it to `ToFunction`. - -pub type ArgList<'a> = Vec>; - -pub enum Args<'a, I> { - Typed(I), - Raw(ArgList<'a>), -} - -pub trait Typed { - fn args<'arg>(i: Vec>) -> Result>; - fn ret(o: O) -> Result; -} - -pub trait ToFunction: Sized { - type Handle; - - fn into_raw(self) -> *mut Self::Handle; - - fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result - where - Self: Typed; - - fn drop(handle: *mut Self::Handle); - - fn to_function(self) -> Function - where - Self: Typed, - { - let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; - let resource_handle = self.into_raw(); - - check_call!(ffi::TVMFuncCreateFromCFunc( - Some(Self::tvm_callback), - resource_handle as *mut _, - Some(Self::tvm_finalizer), - &mut fhandle as *mut ffi::TVMFunctionHandle, - )); - - Function::from_raw(fhandle) - } - - /// The callback function which is wrapped converted by TVM - /// into a packed function stored in fhandle. - unsafe extern "C" fn tvm_callback( - args: *mut ffi::TVMValue, - type_codes: *mut c_int, - num_args: c_int, - ret: ffi::TVMRetValueHandle, - resource_handle: *mut c_void, - ) -> c_int - where - Self: Typed, - { - #![allow(unused_assignments, unused_unsafe)] - let result = std::panic::catch_unwind(|| { - // turning off the incorrect linter complaints - let len = num_args as usize; - let args_list = slice::from_raw_parts_mut(args, len); - let type_codes_list = slice::from_raw_parts_mut(type_codes, len); - let mut local_args: Vec = Vec::new(); - let mut value = ffi::TVMValue { v_int64: 0 }; - let mut tcode = 0; - let resource_handle = resource_handle as *mut Self::Handle; - for i in 0..len { - value = args_list[i]; - tcode = type_codes_list[i]; - // TODO(@jroesch): I believe it is sound to disable this specialized move rule. - // - // This is used in C++ to deal with moving an RValue or reference to a return value - // directly so you can skip copying. - // - // I believe this is not needed as the move directly occurs into the Rust function. - - // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int - // { - // check_call!(ffi::TVMCbArgToReturn( - // &mut value as *mut _, - // &mut tcode as *mut _ - // )); - // } - let arg_value = ArgValue::from_tvm_value(value, tcode as u32); - local_args.push(arg_value); - } - - let rv = match Self::call(resource_handle, local_args) { - Ok(v) => v, - Err(msg) => { - return Err(msg); - } - }; - - // TODO(@jroesch): clean up the handling of the is dec_ref - match rv.clone().try_into() as Result> { - Err(_) => {} - Ok(v) => drop(v), - }; - - let (mut ret_val, ret_tcode) = rv.to_tvm_value(); - let mut ret_type_code = ret_tcode as c_int; - - check_call!(ffi::TVMCFuncSetReturn( - ret, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _, - 1 as c_int - )); - - Ok(()) - }); - - // Here we handle either a panic or true error to isolate - // the unwinding as it will cause issues if we allow Rust - // to unwind over C++ boundary without care. - match result { - Err(_) => { - // TODO(@jroesch): figure out how to improve error here. - crate::set_last_error(&Error::Panic); - return -1; - } - Ok(inner_res) => match inner_res { - Err(err) => { - crate::set_last_error(&err); - return -1; - } - Ok(()) => return 0, - }, - } - } - - /// The finalizer which is invoked when the packed function's - /// reference count is zero. - unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) { - let handle = std::mem::transmute(fhandle); - Self::drop(handle) - } -} - -pub struct RawArgs; - -impl Typed for for<'a> fn(Vec>) -> Result { - fn args<'arg>(args: Vec>) -> Result> { - Ok(Args::Raw(args)) - } - - fn ret(o: RetValue) -> Result { - Ok(o) - } -} - -impl ToFunction for for<'arg> fn(Vec>) -> Result { - type Handle = for<'arg> fn(Vec>) -> Result; - - fn into_raw(self) -> *mut Self::Handle { - let ptr: Box = Box::new(self); - Box::into_raw(ptr) - } - - fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { - let func = *handle; - func(args) - } - } - - fn drop(_: *mut Self::Handle) {} -} - -/// A helper trait which correctly captures the complex conversion and lifetime semantics needed -/// to coerce an ordinary Rust value into `ArgValue`. -pub trait TryFromArgValue: TryFrom { - fn from_arg_value(f: F) -> std::result::Result; -} - -impl<'a, T> TryFromArgValue> for T -where - Self: TryFrom>, - Error: From<>>::Error>, -{ - fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { - Ok(TryFrom::try_from(f)?) - } -} - -macro_rules! impl_typed_and_to_function { - ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for Fun - where - Fun: Fn($($t),*) -> Out, - Out: TryInto, - Error: From, - $( for<'a> $t: TryFromArgValue>, )* - { - #[allow(non_snake_case, unused_variables, unused_mut)] - fn args<'arg>(args: Vec>) -> Result> { - if args.len() != $len { - return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", - std::any::type_name::(), - $len, args.len()))) - } - let mut args = args.into_iter(); - $(let $t = TryFromArgValue::from_arg_value(args.next().unwrap())?;)* - Ok(Args::Typed(($($t,)*))) - } - - fn ret(out: Out) -> Result { - out.try_into().map_err(|e| e.into()) - } - } - - - impl ToFunction<($($t,)*), Out> for Fun - where - Fun: Fn($($t,)*) -> Out + 'static - { - type Handle = Box Out + 'static>; - - fn into_raw(self) -> *mut Self::Handle { - let ptr: Box = Box::new(Box::new(self)); - Box::into_raw(ptr) - } - - #[allow(non_snake_case)] - fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result - where - Fun: Typed<($($t,)*), Out> - { - let ($($t,)*) = match Fun::args(args)? { - Args::Raw(_) => panic!("impossible case"), - Args::Typed(typed) => typed, - }; - - let fn_ptr = unsafe { &*handle }; - let out = fn_ptr($($t),*); - Fun::ret(out) - } - - fn drop(ptr: *mut Self::Handle) { - let bx = unsafe { Box::from_raw(ptr) }; - std::mem::drop(bx) - } - } - } -} - -impl_typed_and_to_function!(0;); -impl_typed_and_to_function!(1; A); -impl_typed_and_to_function!(2; A, B); -impl_typed_and_to_function!(3; A, B, C); -impl_typed_and_to_function!(4; A, B, C, D); -impl_typed_and_to_function!(5; A, B, C, D, E); -impl_typed_and_to_function!(6; A, B, C, D, E, F); -impl_typed_and_to_function!(7; A, B, C, D, E, F, G); -impl_typed_and_to_function!(8; A, B, C, D, E, F, G, H); - -#[cfg(test)] -mod tests { - use super::*; - - fn call<'a, F, I, O>(f: F, args: Vec>) -> Result - where - F: ToFunction, - F: Typed, - { - F::call(f.into_raw(), args) - } - - #[test] - fn test_to_function0() { - fn zero() -> i32 { - 10 - } - let _ = zero.to_function(); - let good = call(zero, vec![]).unwrap(); - assert_eq!(i32::try_from(good).unwrap(), 10); - let bad = call(zero, vec![1.into()]).unwrap_err(); - assert!(matches!(bad, Error::CallFailed(..))); - } - - #[test] - fn test_to_function2() { - fn two_arg(i: i32, j: i32) -> i32 { - i + j - } - let good = call(two_arg, vec![3.into(), 4.into()]).unwrap(); - assert_eq!(i32::try_from(good).unwrap(), 7); - } -} diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml deleted file mode 100644 index 03e1d4e13d55..000000000000 --- a/rust/tvm-sys/Cargo.toml +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "tvm-sys" -version = "0.1.1-alpha" -authors = ["TVM Contributors"] -license = "Apache-2.0" -edition = "2018" -description = "Low level bindings to TVM's cross language API." - -[features] -default = ["dynamic-linking"] -static-linking = [] -dynamic-linking = [] -runtime-only = [] -# Enabling any of the following features is like setting the value to "ON" in config.cmake. -use-cuda = [] -use-opencl = [] -use-vulkan = [] -use-metal = [] -use-rocm = [] -use-hexagon-device = [] -use-rpc = [] -use-threads = [] -use-llvm = [] -use-stackvm-runtime = [] -use-openmp = [] -use-rtti = [] -use-mscv-mt = [] -use-install-dev = [] -hide-private-symbols = [] -use-fallback-stl-map = [] -use-index-default-i64 = [] -use-tf-tvmdsoop = [] -use-byodt-posit = [] -use-mkl = [] -use-mkldnn = [] -use-dnnl-codegen = [] -use-cudnn = [] -use-cublas = [] -use-thrust = [] -use-miopen = [] -use-rocblas = [] -use-sort = [] -use-nnpack = [] -use-random = [] -use-cpp-rpc = [] -use-tflite = [] -use-coreml = [] -use-target-onnx = [] -use-arm-compute-lib = [] -use-arm-compute-lib-graph-runtime = [] -use-tensorrt-codegen = [] -use-tensorrt-runtime = [] -build-static-runtime = [] - -[dependencies] -thiserror = "^1.0" -anyhow = "^1.0" -ndarray = "0.12" -enumn = "^0.1" - -[build-dependencies] -bindgen = { version="0.57", default-features = false, features = ["runtime"] } -anyhow = "^1.0" -tvm-build = "0.2.4" diff --git a/rust/tvm-sys/README.md b/rust/tvm-sys/README.md deleted file mode 100644 index 735a9431aa33..000000000000 --- a/rust/tvm-sys/README.md +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - - - - - -# tvm-sys - -The low level bindings to TVM's C APIs for interacting with the runtime, -the cross-language object system, and packed function API. - -These will generate bindings to TVM, if you set `TVM_HOME` variable before -building it will instruct the bindings to use your source tree, if not the -crate will use `tvm-build` in order to build a sandboxed version of the library. - -This feature is intended to simplify the installation for brand new TVM users -by trying to automate the build process as much as possible. diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs deleted file mode 100644 index 2f30afb4b0ab..000000000000 --- a/rust/tvm-sys/build.rs +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern crate bindgen; - -use std::{ - path::{Path, PathBuf}, - str::FromStr, -}; - -use anyhow::{Context, Result}; -use tvm_build::{BuildConfig, CMakeSetting}; - -/// The necessary information for detecting a TVM installation. -struct TVMInstall { - source_path: PathBuf, - build_path: PathBuf, -} - -/// Find the TVM install using the provided path. -fn find_using_tvm_path>(tvm_path: P) -> Result { - Ok(TVMInstall { - source_path: tvm_path.as_ref().into(), - build_path: tvm_path.as_ref().into(), - }) -} - -#[allow(unused)] -fn if_unset, V: AsRef>(k: K, v: V) -> Result<()> { - match std::env::var(k.as_ref()) { - Ok(other) if other != "" => { - println!( - "cargo:warning=Using existing environment variable setting {:?}={:?}", - k.as_ref(), - v.as_ref() - ); - } - _ => std::env::set_var(k, v), - } - - Ok(()) -} - -/// Find a TVM installation using TVM build by either first installing or detecting. -fn find_using_tvm_build() -> Result { - let mut build_config = BuildConfig::default(); - build_config.repository = Some("https://github.com/apache/tvm".to_string()); - build_config.branch = Some(option_env!("TVM_BRANCH").unwrap_or("main").into()); - - if cfg!(feature = "use-cuda") { - build_config.settings.use_cuda = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-opencl") { - build_config.settings.use_opencl = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-vulkan") { - build_config.settings.use_vulkan = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-rocm") { - build_config.settings.use_rocm = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-metal") { - build_config.settings.use_metal = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-hexagon-device") { - build_config.settings.use_hexagon_device = Some(true); - } - if cfg!(feature = "use-rpc") { - build_config.settings.use_rpc = Some(true); - } - if cfg!(feature = "use-threads") { - build_config.settings.use_threads = Some(true); - } - if cfg!(feature = "use-llvm") { - build_config.settings.use_llvm = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-stackvm-runtime") { - build_config.settings.use_stackvm_runtime = Some(true); - } - if cfg!(feature = "use-graph-runtime") { - build_config.settings.use_graph_runtime = Some(true); - } - if cfg!(feature = "use-graph-runtime-debug") { - build_config.settings.use_graph_runtime_debug = Some(true); - } - if cfg!(feature = "use-openmp") { - build_config.settings.use_openmp = Some(true); - } - if cfg!(feature = "use-rtti") { - build_config.settings.use_rtti = Some(true); - } - if cfg!(feature = "use-mscv-mt") { - build_config.settings.use_mscv_mt = Some(true); - } - if cfg!(feature = "use-install-dev") { - build_config.settings.use_install_dev = Some(true); - } - if cfg!(feature = "hide_private-symbols") { - build_config.settings.hide_private_symbols = Some(true); - } - if cfg!(feature = "use-fallback-stl-map") { - build_config.settings.use_fallback_stl_map = Some(true); - } - if cfg!(feature = "use-index_default-i64") { - build_config.settings.use_index_default_i64 = Some(true); - } - if cfg!(feature = "use-tf-tvmdsoop") { - build_config.settings.use_tf_tvmdsoop = Some(true); - } - if cfg!(feature = "use-byodt-posit") { - build_config.settings.use_byodt_posit = Some(true); - } - if cfg!(feature = "use-mkl") { - build_config.settings.use_mkl = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-mkldnn") { - build_config.settings.use_mkldnn = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-dnnl-codegen") { - build_config.settings.use_dnnl_codegen = Some(true); - } - if cfg!(feature = "use-cudnn") { - build_config.settings.use_cudnn = Some(true); - } - if cfg!(feature = "use-cublas") { - build_config.settings.use_cublas = Some(true); - } - if cfg!(feature = "use-thrust") { - build_config.settings.use_thrust = Some(true); - } - if cfg!(feature = "use-miopen") { - build_config.settings.use_miopen = Some(true); - } - if cfg!(feature = "use-rocblas") { - build_config.settings.use_rocblas = Some(true); - } - if cfg!(feature = "use-sort") { - build_config.settings.use_sort = Some(true); - } - if cfg!(feature = "use-nnpack") { - build_config.settings.use_nnpack = Some(true); - } - if cfg!(feature = "use-random") { - build_config.settings.use_random = Some(true); - } - if cfg!(feature = "use-cpp-rpc") { - build_config.settings.use_cpp_rpc = Some(true); - } - if cfg!(feature = "use-tflite") { - build_config.settings.use_tflite = Some(true); - } - if cfg!(feature = "use-coreml") { - build_config.settings.use_coreml = Some(true); - } - if cfg!(feature = "use-target-onnx") { - build_config.settings.use_target_onnx = Some(true); - } - if cfg!(feature = "use-arm-compute-lib") { - build_config.settings.use_arm_compute_lib = Some(true); - } - if cfg!(feature = "use-arm-compute-lib-graph-runtime") { - build_config.settings.use_arm_compute_lib_graph_runtime = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-tensorrt-codegen") { - build_config.settings.use_tensorrt_codegen = Some(true); - } - if cfg!(feature = "use-tensorrt-runtime") { - build_config.settings.use_tensorrt_runtime = CMakeSetting::from_str("on").ok(); - } - if cfg!(any( - feature = "static-linking", - feature = "build-static-runtime" - )) { - build_config.settings.build_static_runtime = Some(true); - } - - let build_result = tvm_build::build(build_config)?; - let source_path = build_result.revision.source_path(); - let build_path = build_result.revision.build_path(); - Ok(TVMInstall { - source_path, - build_path, - }) -} - -fn main() -> Result<()> { - let TVMInstall { - source_path, - build_path, - } = match option_env!("TVM_HOME") { - Some(tvm_path) if tvm_path != "" => find_using_tvm_path(tvm_path), - _ => find_using_tvm_build(), - }?; - - // If the TVM_HOME environment variable changed, the LLVM_CONFIG_PATH environment variable - // changed or the source headers have changed we need to rebuild the Rust bindings. - println!("cargo:rerun-if-env-changed=TVM_HOME"); - println!("cargo:rerun-if-env-changed=LLVM_CONFIG_PATH"); - println!("cargo:rerun-if-changed={}/include", source_path.display()); - - let library_name = if cfg!(feature = "runtime-only") { - "tvm_runtime" - } else { - "tvm" - }; - - match &std::env::var("CARGO_CFG_TARGET_ARCH") - .expect("CARGO_CFG_TARGET_ARCH must be set by CARGO")[..] - { - "wasm32" => {} - _ => { - if cfg!(feature = "static-linking") { - println!("cargo:rustc-link-lib=static={}", library_name); - // TODO(@jroesch): move this to tvm-build as library_path? - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); - } - - if cfg!(feature = "dynamic-linking") { - println!("cargo:rustc-link-lib=dylib={}", library_name); - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); - } - } - }; - - let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h"); - let backend_api = source_path.join("include/tvm/runtime/c_backend_api.h"); - let source_path = source_path.display().to_string(); - let dlpack_include = format!("-I{}/3rdparty/dlpack/include/", source_path); - let tvm_include = format!("-I{}/include/", source_path); - - let out_file = PathBuf::from(std::env::var("OUT_DIR")?).join("c_runtime_api.rs"); - - // @see rust-bindgen#550 for `blacklist_type` - bindgen::Builder::default() - .header(runtime_api.display().to_string()) - .header(backend_api.display().to_string()) - .clang_arg(dlpack_include) - .clang_arg(tvm_include) - .blacklist_type("max_align_t") - .layout_tests(false) - .derive_partialeq(true) - .derive_eq(true) - .derive_default(true) - .generate() - .map_err(|()| { - anyhow::anyhow!("bindgen failed to generate the Rust bindings for the C API") - })? - .write_to_file(out_file) - .context("failed to write the generated Rust binding to disk")?; - - Ok(()) -} diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs deleted file mode 100644 index 92208303e89c..000000000000 --- a/rust/tvm-sys/src/array.rs +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{ - mem, - os::raw::{c_int, c_void}, -}; - -use crate::ffi::{ - DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLDevice, - DLDeviceType_kDLCPU, DLTensor, -}; - -/// `From` conversions to `DLTensor` for `ndarray::Array`. -/// Takes a reference to the `ndarray` since `DLTensor` is not owned. -macro_rules! impl_dltensor_from_ndarray { - ($type:ty, $typecode:expr) => { - impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { - fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { - DLTensor { - data: arr.as_mut_ptr() as *mut c_void, - device: DLDevice { - device_type: DLDeviceType_kDLCPU, - device_id: 0, - }, - ndim: arr.ndim() as c_int, - dtype: DLDataType { - code: $typecode as u8, - bits: 8 * mem::size_of::<$type>() as u8, - lanes: 1, - }, - shape: arr.shape().as_ptr() as *const i64 as *mut i64, - strides: arr.strides().as_ptr() as *const i64 as *mut i64, - byte_offset: 0, - ..Default::default() - } - } - } - }; -} - -impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); -impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs deleted file mode 100644 index 2903a81d9c36..000000000000 --- a/rust/tvm-sys/src/byte_array.rs +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -use std::convert::TryFrom; - -use crate::errors::ValueDowncastError; -use crate::ffi::{TVMByteArray, TVMByteArrayFree}; -use crate::{ArgValue, RetValue}; - -/// A newtype wrapping a raw TVM byte-array. -/// -/// ## Example -/// -/// ``` -/// let v = b"hello"; -/// let barr = tvm_sys::ByteArray::from(&v); -/// assert_eq!(barr.len(), v.len()); -/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); -/// ``` -pub enum ByteArray { - Rust(TVMByteArray), - External(TVMByteArray), -} - -impl Drop for ByteArray { - fn drop(&mut self) { - match self { - ByteArray::Rust(bytes) => { - let ptr = bytes.data; - let len = bytes.size as _; - let cap = bytes.size as _; - let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; - drop(data); - } - ByteArray::External(byte_array) => unsafe { - if TVMByteArrayFree(byte_array as _) != 0 { - panic!("error"); - } - }, - } - } -} - -impl ByteArray { - /// Gets the underlying byte-array - pub fn data(&self) -> &[u8] { - match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => unsafe { - std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) - }, - } - } - - /// Gets the length of the underlying byte-array - pub fn len(&self) -> usize { - match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, - } - } - - /// Converts the underlying byte-array to `Vec` - pub fn to_vec(&self) -> Vec { - self.data().to_vec() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl>> From for ByteArray { - fn from(arg: T) -> Self { - let mut incoming_bytes: Vec = arg.into(); - let mut bytes = Vec::with_capacity(incoming_bytes.len()); - bytes.append(&mut incoming_bytes); - - let mut bytes = std::mem::ManuallyDrop::new(bytes); - let ptr = bytes.as_mut_ptr(); - assert_eq!(bytes.len(), bytes.capacity()); - ByteArray::Rust(TVMByteArray { - data: ptr as _, - size: bytes.len() as _, - }) - } -} - -impl<'a> From<&'a ByteArray> for ArgValue<'a> { - fn from(val: &'a ByteArray) -> ArgValue<'a> { - match val { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { - ArgValue::Bytes(byte_array) - } - } - } -} - -// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. -// impl From for RetValue { -// fn from(val: ByteArray) -> RetValue { -// match val { -// ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { -// // TODO(@jroesch): This requires a little more work, going to land narratives -// RetValue::Bytes(byte_array) -// } -// } -// } -// } - -impl TryFrom for ByteArray { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - match val { - RetValue::Bytes(array) => Ok(ByteArray::External(array)), - _ => Err(ValueDowncastError { - expected_type: "ByteArray", - actual_type: format!("{:?}", val), - }), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn convert() { - let v = vec![1u8, 2, 3]; - let barr = ByteArray::from(v.to_vec()); - assert_eq!(barr.len(), v.len()); - assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); - let v = b"hello"; - let barr = ByteArray::from(v.to_vec()); - assert_eq!(barr.len(), v.len()); - assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); - } -} diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs deleted file mode 100644 index 5f7e0c3a3b60..000000000000 --- a/rust/tvm-sys/src/datatype.rs +++ /dev/null @@ -1,214 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::any::TypeId; -use std::convert::TryFrom; -use std::str::FromStr; - -use crate::ffi::DLDataType; -use crate::packed_func::RetValue; - -use thiserror::Error; - -const DL_INT_CODE: u8 = 0; -const DL_UINT_CODE: u8 = 1; -const DL_FLOAT_CODE: u8 = 2; -const DL_HANDLE: u8 = 3; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(C)] -pub struct DataType { - code: u8, - bits: u8, - lanes: u16, -} - -impl DataType { - pub const fn new(code: u8, bits: u8, lanes: u16) -> DataType { - DataType { code, bits, lanes } - } - - /// Returns the number of bytes occupied by an element of this `DataType`. - pub fn itemsize(&self) -> usize { - (self.bits as usize * self.lanes as usize) >> 3 - } - - /// Returns whether this `DataType` represents primitive type `T`. - pub fn is_type(&self) -> bool { - if self.lanes != 1 { - return false; - } - let typ = TypeId::of::(); - (typ == TypeId::of::() && self.code == DL_INT_CODE && self.bits == 32) - || (typ == TypeId::of::() && self.code == DL_INT_CODE && self.bits == 64) - || (typ == TypeId::of::() && self.code == DL_UINT_CODE && self.bits == 32) - || (typ == TypeId::of::() && self.code == DL_UINT_CODE && self.bits == 64) - || (typ == TypeId::of::() && self.code == DL_FLOAT_CODE && self.bits == 32) - || (typ == TypeId::of::() && self.code == DL_FLOAT_CODE && self.bits == 64) - } - - pub fn code(&self) -> usize { - self.code as usize - } - - pub fn bits(&self) -> usize { - self.bits as usize - } - - pub fn lanes(&self) -> usize { - self.lanes as usize - } - - pub const fn int(bits: u8, lanes: u16) -> DataType { - DataType::new(DL_INT_CODE, bits, lanes) - } - - pub const fn float(bits: u8, lanes: u16) -> DataType { - DataType::new(DL_FLOAT_CODE, bits, lanes) - } - - pub const fn float32() -> DataType { - Self::float(32, 1) - } - - pub const fn uint(bits: u8, lanes: u16) -> DataType { - DataType::new(DL_UINT_CODE, bits, lanes) - } -} - -impl<'a> From<&'a DataType> for DLDataType { - fn from(dtype: &'a DataType) -> Self { - Self { - code: dtype.code as u8, - bits: dtype.bits as u8, - lanes: dtype.lanes as u16, - } - } -} - -impl From for DataType { - fn from(dtype: DLDataType) -> Self { - Self { - code: dtype.code, - bits: dtype.bits, - lanes: dtype.lanes, - } - } -} - -impl From for DLDataType { - fn from(dtype: DataType) -> Self { - Self { - code: dtype.code, - bits: dtype.bits, - lanes: dtype.lanes, - } - } -} - -#[derive(Debug, Error)] -pub enum ParseDataTypeError { - #[error("invalid number: {0}")] - InvalidNumber(std::num::ParseIntError), - #[error("missing data type specifier (e.g., int32, float64)")] - MissingDataType, - #[error("unknown type: {0}")] - UnknownType(String), -} - -/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` -/// such as "int32", "float32" or with lane "float32x1". -impl FromStr for DataType { - type Err = ParseDataTypeError; - - fn from_str(type_str: &str) -> Result { - use ParseDataTypeError::*; - - if type_str == "bool" { - return Ok(DataType::new(1, 1, 1)); - } - - let mut type_lanes = type_str.split('x'); - let typ = type_lanes.next().ok_or(MissingDataType)?; - let lanes = type_lanes - .next() - .map(|l| ::from_str_radix(l, 10)) - .unwrap_or(Ok(1)) - .map_err(InvalidNumber)?; - let (type_name, bits) = match typ.find(char::is_numeric) { - Some(idx) => { - let (name, bits_str) = typ.split_at(idx); - ( - name, - u8::from_str_radix(bits_str, 10).map_err(InvalidNumber)?, - ) - } - None => (typ, 32), - }; - - let type_code = match type_name { - "int" => DL_INT_CODE, - "uint" => DL_UINT_CODE, - "float" => DL_FLOAT_CODE, - "handle" => DL_HANDLE, - _ => return Err(UnknownType(type_name.to_string())), - }; - - Ok(DataType::new(type_code, bits, lanes)) - } -} - -impl std::fmt::Display for DataType { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - if self.bits == 1 && self.lanes == 1 { - return write!(f, "bool"); - } - let mut type_str = match self.code { - DL_INT_CODE => "int", - DL_UINT_CODE => "uint", - DL_FLOAT_CODE => "float", - DL_HANDLE => "handle", - _ => "unknown", - } - .to_string(); - - type_str += &self.bits.to_string(); - if self.lanes > 1 { - type_str += &format!("x{}", self.lanes); - } - f.write_str(&type_str) - } -} - -impl From for RetValue { - fn from(dt: DataType) -> RetValue { - RetValue::DataType((&dt).into()) - } -} - -impl TryFrom for DataType { - type Error = anyhow::Error; - fn try_from(ret_value: RetValue) -> anyhow::Result { - match ret_value { - RetValue::DataType(dt) => Ok(dt.into()), - // TODO(@jroesch): improve - _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), - } - } -} diff --git a/rust/tvm-sys/src/device.rs b/rust/tvm-sys/src/device.rs deleted file mode 100644 index 0344983c1622..000000000000 --- a/rust/tvm-sys/src/device.rs +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! Provides [`Device`] and related device queries. -//! -//! Create a new device for device type and device id. -//! -//! # Example -//! -//! ``` -//! # use tvm_sys::{DeviceType, Device}; -//! let cpu = DeviceType::from("cpu"); -//! let dev = Device::new(cpu , 0); -//! let cpu0 = Device::cpu(0); -//! assert_eq!(dev, cpu0); -//! ``` -//! -//! Or from a supported device name. -//! -//! ``` -//! use tvm_sys::Device; -//! let cpu0 = Device::from("cpu"); -//! println!("{}", cpu0); -//! ``` - -use std::convert::TryFrom; -use std::fmt::{self, Display, Formatter}; -use std::str::FromStr; - -use crate::ffi::{self, *}; -use crate::packed_func::{ArgValue, RetValue}; - -use anyhow::Result; -use enumn::N; -use thiserror::Error; - -/// Device type represents the set of devices supported by -/// [TVM](https://github.com/apache/tvm). -/// -/// ## Example -/// -/// ``` -/// use tvm_sys::DeviceType; -/// let cpu = DeviceType::from("cpu"); -/// println!("device is: {}", cpu); -///``` - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, N)] -#[repr(i64)] -pub enum DeviceType { - CPU = 1, - CUDA = 2, - CUDAHost = 3, - OpenCL = 4, - Vulkan = 7, - Metal = 8, - VPI = 9, - ROCM = 10, - ExtDev = 12, -} - -impl Default for DeviceType { - /// default device is cpu. - fn default() -> Self { - DeviceType::CPU - } -} - -impl From for ffi::DLDeviceType { - fn from(device_type: DeviceType) -> Self { - device_type as Self - } -} - -impl From for DeviceType { - fn from(device_type: ffi::DLDeviceType) -> Self { - Self::n(device_type as _).expect("invalid enumeration value for ffi::DLDeviceType") - } -} - -impl Display for DeviceType { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!( - f, - "{}", - match self { - DeviceType::CPU => "cpu", - DeviceType::CUDA => "cuda", - DeviceType::CUDAHost => "cuda_host", - DeviceType::OpenCL => "opencl", - DeviceType::Vulkan => "vulkan", - DeviceType::Metal => "metal", - DeviceType::VPI => "vpi", - DeviceType::ROCM => "rocm", - DeviceType::ExtDev => "ext_device", - // DeviceType(_) => "rpc", - } - ) - } -} - -impl<'a> From<&'a str> for DeviceType { - fn from(type_str: &'a str) -> Self { - match type_str { - "cpu" => DeviceType::CPU, - "llvm" => DeviceType::CPU, - "cuda" => DeviceType::CUDA, - "nvptx" => DeviceType::CUDA, - "cl" => DeviceType::OpenCL, - "opencl" => DeviceType::OpenCL, - "metal" => DeviceType::Metal, - "vpi" => DeviceType::VPI, - "rocm" => DeviceType::ROCM, - _ => panic!("{:?} not supported!", type_str), - } - } -} - -impl<'a> From<&DeviceType> for ArgValue<'a> { - fn from(dev: &DeviceType) -> Self { - Self::Int(*dev as _) - } -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct Device { - pub device_type: DeviceType, - pub device_id: usize, -} - -impl Device { - pub fn new(device_type: DeviceType, device_id: usize) -> Device { - Device { - device_type, - device_id, - } - } -} - -impl<'a> From<&'a Device> for DLDevice { - fn from(dev: &'a Device) -> Self { - Self { - device_type: dev.device_type.into(), - device_id: dev.device_id as i32, - } - } -} - -impl Default for Device { - fn default() -> Self { - Self { - device_type: DLDeviceType_kDLCPU.into(), - device_id: 0, - } - } -} - -#[derive(Debug, Error)] -#[error("unsupported device: {0}")] -pub struct UnsupportedDeviceError(String); - -macro_rules! impl_tvm_device { - ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { - /// Creates a Device from a string (e.g., "cpu", "cuda", "ext_dev") - impl FromStr for Device { - type Err = UnsupportedDeviceError; - fn from_str(type_str: &str) -> Result { - Ok(Self { - device_type: match type_str { - $( $( stringify!($dev_name) )|+ => $dev_type.into()),+, - _ => return Err(UnsupportedDeviceError(type_str.to_string())), - }, - device_id: 0, - }) - } - } - - impl Device { - $( - $( - pub fn $dev_name(device_id: usize) -> Self { - Self { - device_type: $dev_type.into(), - device_id: device_id, - } - } - )+ - )+ - } - }; -} - -impl_tvm_device!( - DLDeviceType_kDLCPU: [cpu, llvm], - DLDeviceType_kDLCUDA: [cuda, nvptx], - DLDeviceType_kDLOpenCL: [cl], - DLDeviceType_kDLMetal: [metal], - DLDeviceType_kDLVPI: [vpi], - DLDeviceType_kDLROCM: [rocm], - DLDeviceType_kDLExtDev: [ext_dev] -); - -impl<'a> From<&'a str> for Device { - fn from(target: &str) -> Self { - Device::new(DeviceType::from(target), 0) - } -} - -impl From for Device { - fn from(dev: ffi::DLDevice) -> Self { - Device { - device_type: DeviceType::from(dev.device_type), - device_id: dev.device_id as usize, - } - } -} - -impl From for ffi::DLDevice { - fn from(dev: Device) -> Self { - ffi::DLDevice { - device_type: dev.device_type.into(), - device_id: dev.device_id as i32, - } - } -} - -impl Display for Device { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}({})", self.device_type, self.device_id) - } -} - -impl<'a> From<&'a Device> for ArgValue<'a> { - fn from(dev: &'a Device) -> Self { - DLDevice::from(dev).into() - } -} - -impl<'a> From for ArgValue<'a> { - fn from(dev: Device) -> Self { - DLDevice::from(dev).into() - } -} - -impl From for RetValue { - fn from(ret_value: Device) -> RetValue { - RetValue::Device(ret_value.into()) - } -} - -impl TryFrom for Device { - type Error = anyhow::Error; - fn try_from(ret_value: RetValue) -> anyhow::Result { - match ret_value { - RetValue::Device(dt) => Ok(dt.into()), - // TODO(@jroesch): improve - _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn device() { - let dev = Device::cpu(0); - println!("device: {}", dev); - let default_dev = Device::new(DeviceType::CPU, 0); - assert_eq!(dev.clone(), default_dev); - assert_ne!(dev, Device::cuda(0)); - - let str_dev = Device::new(DeviceType::CUDA, 0); - assert_eq!(str_dev.clone(), str_dev); - assert_ne!(str_dev, Device::new(DeviceType::CPU, 0)); - } -} diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-sys/src/errors.rs deleted file mode 100644 index 54fe261ec37e..000000000000 --- a/rust/tvm-sys/src/errors.rs +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use thiserror::Error; - -#[derive(Error, Debug)] -#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")] -pub struct ValueDowncastError { - pub actual_type: String, - pub expected_type: &'static str, -} - -#[derive(Error, Debug)] -#[error("Function call `{context:?}` returned error: {message:?}")] -pub struct FuncCallError { - context: String, - message: String, -} - -impl FuncCallError { - pub fn get_with_context(context: String) -> Self { - Self { - context, - message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) } - .to_str() - .expect("failed while attempting to retrieve the TVM error message") - .to_owned(), - } - } -} diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs deleted file mode 100644 index f9ac3b461c69..000000000000 --- a/rust/tvm-sys/src/lib.rs +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This crate contains the minimal interface over TVM's -//! C runtime API. -//! -//! These common bindings are useful to both runtimes -//! written in Rust, as well as higher level API bindings. -//! -//! See the `tvm-rt` or `tvm` crates for full bindings to -//! the TVM API. - -/// The low-level C runtime FFI API for TVM. -pub mod ffi { - #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] - - use std::os::raw::{c_char, c_int, c_void}; - - include!(concat!(env!("OUT_DIR"), "/c_runtime_api.rs")); - - pub type BackendPackedCFunc = extern "C" fn( - args: *const TVMValue, - type_codes: *const c_int, - num_args: c_int, - out_ret_value: *mut TVMValue, - out_ret_tcode: *mut u32, - resource_handle: *mut c_void, - ) -> c_int; -} - -pub mod array; -pub mod byte_array; -pub mod datatype; -pub mod device; -pub mod errors; -#[macro_use] -pub mod packed_func; -pub mod value; - -pub use byte_array::ByteArray; -pub use datatype::DataType; -pub use device::{Device, DeviceType}; -pub use errors::*; -pub use packed_func::{ArgValue, RetValue}; - -impl std::convert::TryFrom> for RetValue -where - RetValue: std::convert::TryFrom, - E: From<>::Error>, -{ - type Error = E; - - fn try_from(val: Result) -> Result { - val.and_then(|t| RetValue::try_from(t).map_err(|e| e.into())) - } -} diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs deleted file mode 100644 index 3d78ce52d621..000000000000 --- a/rust/tvm-sys/src/packed_func.rs +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{ - convert::TryFrom, - ffi::{CStr, CString}, - os::raw::{c_char, c_void}, -}; - -use crate::{errors::ValueDowncastError, ffi::*}; - -pub use crate::ffi::TVMValue; - -pub trait PackedFunc: - Fn(&[ArgValue]) -> Result + Send + Sync -{ -} - -impl PackedFunc for T where - T: Fn(&[ArgValue]) -> Result + Send + Sync -{ -} - -/// Calls a packed function and returns a `RetValue`. -/// -/// # Example -/// -/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` -#[macro_export] -macro_rules! call_packed { - ($fn:expr, $($args:expr),+) => { - $fn(&[$($args.into(),)+]) - }; - ($fn:expr) => { - $fn(&Vec::new()) - }; -} - -/// Constructs a derivative of a TVMPodValue. -macro_rules! TVMPODValue { - { - $(#[$m:meta])+ - $name:ident $(<$a:lifetime>)? { - $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)? - }, - match $value:ident { - $($tvm_type:ident => { $from_tvm_type:expr })+ - }, - match &self { - $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+ - } - $(,)? - } => { - $(#[$m])+ - #[derive(Clone, Debug)] - pub enum $name $(<$a>)? { - Int(i64), - UInt(i64), - Float(f64), - Bool(bool), - Null, - DataType(DLDataType), - String(*mut c_char), - Device(DLDevice), - Handle(*mut c_void), - ArrayHandle(TVMArrayHandle), - ObjectHandle(*mut c_void), - ModuleHandle(TVMModuleHandle), - FuncHandle(TVMFunctionHandle), - NDArrayHandle(*mut c_void), - $($extra_variant($variant_type)),+ - } - - impl $(<$a>)? $name $(<$a>)? { - pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self { - use $name::*; - #[allow(non_upper_case_globals)] - unsafe { - match type_code as _ { - DLDataTypeCode_kDLInt => Int($value.v_int64), - DLDataTypeCode_kDLUInt => UInt($value.v_int64), - DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMArgTypeCode_kTVMArgBool => Bool($value.v_int64 != 0), - TVMArgTypeCode_kTVMNullptr => Null, - TVMArgTypeCode_kTVMDataType => DataType($value.v_type), - TVMArgTypeCode_kDLDevice => Device($value.v_device), - TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), - TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), - TVMArgTypeCode_kTVMObjectRValueRefArg => ObjectHandle(*($value.v_handle as *mut *mut c_void)), - TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), - TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), - TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), - $( $tvm_type => { $from_tvm_type } ),+ - _ => unimplemented!("{}", type_code), - } - } - } - - pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { - use $name::*; - match self { - Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), - UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), - Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Bool(val) => (TVMValue { v_int64: *val as i64 }, TVMArgTypeCode_kTVMArgBool), - Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), - DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), - Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), - String(val) => { - ( - TVMValue { v_handle: *val as *mut c_void }, - TVMArgTypeCode_kTVMStr, - ) - } - Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), - ArrayHandle(val) => { - ( - TVMValue { v_handle: *val as *const _ as *mut c_void }, - TVMArgTypeCode_kTVMNDArrayHandle, - ) - }, - ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), - ModuleHandle(val) => - (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), - FuncHandle(val) => ( - TVMValue { v_handle: *val }, - TVMArgTypeCode_kTVMPackedFuncHandle - ), - NDArrayHandle(val) => - (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), - $( $self_type($val) => { $from_self_type } ),+ - } - } - } - } -} - -TVMPODValue! { - /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way - /// to obtain a `ArgValue` is automatically via `call_packed!`. - ArgValue<'a> { - Bytes(&'a TVMByteArray), - Str(&'a CStr), - }, - match value { - TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } - TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } - }, - match &self { - Bytes(val) => { - (TVMValue { v_handle: *val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) - } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } - } -} - -TVMPODValue! { - /// An owned TVMPODValue. Can be converted from a variety of primitive and object types. - /// Can be downcasted using `try_from` if it contains the desired type. - /// - /// # Example - /// - /// ``` - /// use std::convert::{TryFrom, TryInto}; - /// use tvm_sys::RetValue; - /// - /// let a = 42u32; - /// let b: u32 = tvm_sys::RetValue::from(a).try_into().unwrap(); - /// - /// let s = "hello, world!"; - /// let t: RetValue = s.to_string().into(); - /// assert_eq!(String::try_from(t).unwrap(), s); - /// ``` - RetValue { - Bytes(TVMByteArray), - Str(&'static CStr), - }, - match value { - TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } - TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } - }, - match &self { - Bytes(val) => - { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } - Str(val) => - { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } - } -} - -#[macro_export] -macro_rules! try_downcast { - ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => { - match $val { - $( $pat => { Ok($converter) } )+ - _ => Err($crate::errors::ValueDowncastError { - actual_type: format!("{:?}", $val), - expected_type: stringify!($into), - }), - } - }; -} - -/// Creates a conversion to a `ArgValue` for a primitive type and DLDataTypeCode. -macro_rules! impl_pod_value { - ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => { - $( - impl<'a> From<$type> for ArgValue<'a> { - fn from(val: $type) -> Self { - Self::$variant(val as $inner_ty) - } - } - - impl<'a> From<&'a $type> for ArgValue<'a> { - fn from(val: &'a $type) -> Self { - Self::$variant(*val as $inner_ty) - } - } - - impl<'a> TryFrom> for $type { - type Error = $crate::errors::ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { val as $type }) - } - } - - impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { - type Error = $crate::errors::ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { *val as $type }) - } - } - - impl From<$type> for RetValue { - fn from(val: $type) -> Self { - Self::$variant(val as $inner_ty) - } - } - - impl TryFrom for $type { - type Error = $crate::errors::ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> $type, |RetValue::$variant(val)| { val as $type }) - } - } - )+ - }; -} - -impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); -impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); -impl_pod_value!(Float, f64, [f32, f64]); -impl_pod_value!(Bool, bool, [bool]); -impl_pod_value!(DataType, DLDataType, [DLDataType]); -impl_pod_value!(Device, DLDevice, [DLDevice]); - -impl<'a> From<&'a str> for ArgValue<'a> { - fn from(s: &'a str) -> Self { - Self::String(CString::new(s).unwrap().into_raw()) - } -} - -impl<'a> From for ArgValue<'a> { - fn from(s: String) -> Self { - Self::String(CString::new(s).unwrap().into_raw()) - } -} - -impl<'a> From<&'a CStr> for ArgValue<'a> { - fn from(s: &'a CStr) -> Self { - Self::Str(s) - } -} - -impl<'a> From<&'a CString> for ArgValue<'a> { - fn from(s: &'a CString) -> Self { - Self::String(s.as_ptr() as _) - } -} - -impl<'a> From<&'a TVMByteArray> for ArgValue<'a> { - fn from(s: &'a TVMByteArray) -> Self { - Self::Bytes(s) - } -} - -impl<'a> TryFrom> for &'a str { - type Error = ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) - } -} - -impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { - type Error = ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result { - try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) - } -} - -/// Converts an unspecialized handle to a ArgValue. -impl<'a, T> From<*const T> for ArgValue<'a> { - fn from(ptr: *const T) -> Self { - Self::Handle(ptr as *mut c_void) - } -} - -/// Converts an unspecialized mutable handle to a ArgValue. -impl<'a, T> From<*mut T> for ArgValue<'a> { - fn from(ptr: *mut T) -> Self { - Self::Handle(ptr as *mut c_void) - } -} - -impl<'a> From<&'a mut DLTensor> for ArgValue<'a> { - fn from(arr: &'a mut DLTensor) -> Self { - Self::ArrayHandle(arr as *mut DLTensor) - } -} - -impl<'a> From<&'a DLTensor> for ArgValue<'a> { - fn from(arr: &'a DLTensor) -> Self { - Self::ArrayHandle(arr as *const _ as *mut DLTensor) - } -} - -impl TryFrom for String { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!( - val -> String, - |RetValue::String(s)| { unsafe { CString::from_raw(s).into_string().unwrap() }}, - |RetValue::Str(s)| { s.to_str().unwrap().to_string() } - ) - } -} - -impl From for RetValue { - fn from(s: String) -> Self { - Self::String(std::ffi::CString::new(s).unwrap().into_raw()) - } -} - -impl From for RetValue { - fn from(arr: TVMByteArray) -> Self { - Self::Bytes(arr) - } -} - -impl TryFrom for TVMByteArray { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> TVMByteArray, |RetValue::Bytes(val)| { val }) - } -} - -impl Default for RetValue { - fn default() -> Self { - Self::Int(0) - } -} - -impl TryFrom for std::ffi::CString { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> std::ffi::CString, - |RetValue::Str(val)| { val.into() }) - } -} - -impl From<()> for RetValue { - fn from(_: ()) -> Self { - RetValue::Null - } -} - -impl TryFrom for () { - type Error = ValueDowncastError; - - fn try_from(val: RetValue) -> Result<(), Self::Error> { - try_downcast!(val -> bool, - |RetValue::Null| { () }) - } -} diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs deleted file mode 100644 index 9c987af4cef6..000000000000 --- a/rust/tvm-sys/src/value.rs +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::str::FromStr; - -use crate::ffi::*; - -use thiserror::Error; - -macro_rules! impl_pod_tvm_value { - ($field:ident, $field_ty:ty, $( $ty:ty ),+) => { - $( - impl From<$ty> for TVMValue { - fn from(val: $ty) -> Self { - TVMValue { $field: val as $field_ty } - } - } - - impl From for $ty { - fn from(val: TVMValue) -> Self { - unsafe { val.$field as $ty } - } - } - )+ - }; - ($field:ident, $ty:ty) => { - impl_pod_tvm_value!($field, $ty, $ty); - } -} - -impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize); -impl_pod_tvm_value!(v_float64, f64, f32, f64); -impl_pod_tvm_value!(v_type, DLDataType); -impl_pod_tvm_value!(v_device, DLDevice); - -#[derive(Debug, Error)] -#[error("unsupported device: {0}")] -pub struct UnsupportedDeviceError(String); - -macro_rules! impl_tvm_device { - ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { - /// Creates a DLDevice from a string (e.g., "cpu", "cuda", "ext_dev") - impl FromStr for DLDevice { - type Err = UnsupportedDeviceError; - fn from_str(type_str: &str) -> Result { - Ok(Self { - device_type: match type_str { - $( $( stringify!($dev_name) )|+ => $dev_type ),+, - _ => return Err(UnsupportedDeviceError(type_str.to_string())), - }, - device_id: 0, - }) - } - } - - impl DLDevice { - $( - $( - pub fn $dev_name(device_id: usize) -> Self { - Self { - device_type: $dev_type, - device_id: device_id as i32, - } - } - )+ - )+ - } - }; -} - -impl_tvm_device!( - DLDeviceType_kDLCPU: [cpu, llvm], - DLDeviceType_kDLCUDA: [cuda, nvptx], - DLDeviceType_kDLOpenCL: [cl], - DLDeviceType_kDLMetal: [metal], - DLDeviceType_kDLVPI: [vpi], - DLDeviceType_kDLROCM: [rocm], - DLDeviceType_kDLExtDev: [ext_dev] -); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 602a198a2bf6..f0a317659d3a 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -21,7 +21,7 @@ * \file tvm/arith/analyzer.cc */ #include -#include +#include #include #include @@ -231,17 +231,18 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // Current analysis may not be powerful enough to prove expressions containing // the same symbolic value multiple times. However, when the symbolic values are // "T.vscale" and the compile target uses a scalable architecture extension like - // SVE, we can make some assumptions about the value of vscale and iterate over a + // VLA, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. Target curr_target = Target::Current(); if (ContainsVscaleCall(simplified)) { - if (TargetHasSVE(curr_target)) { - return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); + if (TargetHasVLA(curr_target)) { + auto kVScaleValues = GetVScaleValues(curr_target); + return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues); } LOG(WARNING) << "The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. This proof currently only supports " - "AArch64 SVE targets, but the target was " + "VLA targets, but the target was " << curr_target; } return false; @@ -268,7 +269,7 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } -TVM_REGISTER_GLOBAL("arith.CreateAnalyzer") +TVM_FFI_REGISTER_GLOBAL("arith.CreateAnalyzer") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { using ffi::Function; using ffi::TypedFunction; @@ -319,7 +320,7 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer") }); } else if (name == "bind") { return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (auto opt_range = args[1].as()) { + if (auto opt_range = args[1].try_cast()) { self->Bind(args[0].cast(), opt_range.value()); } else { self->Bind(args[0].cast(), args[1].cast()); diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index d52ae7e6fde3..b8b5d6482428 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -22,7 +22,7 @@ * \brief Utility to deduce bound of expression */ #include -#include +#include #include #include @@ -402,7 +402,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, return DeduceBound(v, e, hmap, rmap); } -TVM_REGISTER_GLOBAL("arith.DeduceBound") +TVM_FFI_REGISTER_GLOBAL("arith.DeduceBound") .set_body_typed([](PrimExpr v, PrimExpr cond, const Map hint_map, const Map relax_map) { return DeduceBound(v, cond, hint_map, relax_map); diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index b11708398fe9..1b82e93eacf7 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -921,7 +921,7 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // try eliminate from lhs for (size_t i = 0; i < lhs_prods.size(); ++i) { if (lhs_prods[i].defined() && deep_equal(value, lhs_prods[i].value())) { - lhs_prods.Set(i, NullOpt); + lhs_prods.Set(i, std::nullopt); ++num_elimination; new_common_scale = new_common_scale * value; return; @@ -1391,7 +1391,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { // First convert a < b into a - b < 0 PrimExpr expr = this->CanonicalMutate(op->a - op->b); // Case: x0 * s0 + x1 * s1 + ... + xn + c < 0, let d = gcd(s0, s1, ..., s{n-1}, c) - // 1. if can prove -d < xn < d, then we can simplify + // 1. if can prove 0 <= xn < d, then we can simplify // the expression to x0 * (s0/d) + x1 * (s1/d) + ... + x{n-1} * (s{n-1}/d) < c/d, // e.g. `x * 8 + y < 16` where `y` \in [0, 8), we can simplify it to `x < 2` // 2. if xn is in pattern of yn % m, where m % d == 0, convert it to yn // d % (m/d) @@ -1417,8 +1417,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { ICHECK(extra->dtype == dtype); PrimExpr normal_extra = extra->Normalize(); if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) && - this->analyzer_->CanProve(normal_extra > make_const(dtype, -gcd))) { - // Case 1. -d < xn < d + this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) { + // Case 1. 0 <= xn < d divisible.CopyOnWrite()->DivideBy(gcd); return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype)); } else if (extra->args.size() == 1 && diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 65ac749d45e7..2c905dd563ef 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -24,7 +24,7 @@ #ifndef TVM_ARITH_CONST_FOLD_H_ #define TVM_ARITH_CONST_FOLD_H_ -#include +#include #include #include @@ -45,7 +45,7 @@ namespace arith { * \tparam Op The operator type. * * \note a and b Must already matched data types with each other. - * \return NullOpt if constant fold fails, otherwise return folded result. + * \return std::nullopt if constant fold fails, otherwise return folded result. */ template inline Optional TryConstFold(PrimExpr a, PrimExpr b); @@ -57,7 +57,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b); * \tparam Op The operator type. * * \note a and b Must already matched data types with each other. - * \return NullOpt if constant fold fails, otherwise return folded result. + * \return std::nullopt if constant fold fails, otherwise return folded result. */ template inline Optional TryConstFold(PrimExpr a); @@ -148,7 +148,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (fa && fa->value == 0) return b; if (fb && fb->value == 0) return a; }); - return NullOpt; + return std::nullopt; } template <> @@ -174,7 +174,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } if (fb && fb->value == 0) return a; }); - return NullOpt; + return std::nullopt; } template <> @@ -210,7 +210,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (fb->value == 0) return b; } }); - return NullOpt; + return std::nullopt; } template <> @@ -246,7 +246,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(fb->value, 0) << "Divide by zero"; } }); - return NullOpt; + return std::nullopt; } template <> @@ -266,7 +266,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); - return NullOpt; + return std::nullopt; } template <> @@ -292,7 +292,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { } else if (rtype.bits() == 64) { return FloatImm(rtype, std::floor(fa->value / fb->value)); } else { - return NullOpt; + return std::nullopt; } } if (fa && fa->value == 0) return a; @@ -301,7 +301,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(fb->value, 0) << "Divide by zero"; } }); - return NullOpt; + return std::nullopt; } template <> @@ -321,7 +321,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); - return NullOpt; + return std::nullopt; } template <> @@ -332,7 +332,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; - return NullOpt; + return std::nullopt; } template <> @@ -343,7 +343,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; - return NullOpt; + return std::nullopt; } template <> @@ -352,7 +352,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); }); - return NullOpt; + return std::nullopt; } template <> @@ -361,7 +361,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); }); - return NullOpt; + return std::nullopt; } template <> @@ -370,7 +370,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); }); - return NullOpt; + return std::nullopt; } template <> @@ -379,7 +379,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); }); - return NullOpt; + return std::nullopt; } template <> @@ -388,7 +388,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); }); - return NullOpt; + return std::nullopt; } template <> @@ -397,7 +397,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); }); - return NullOpt; + return std::nullopt; } template <> @@ -408,7 +408,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && !pa->value) return a; if (pb && pb->value) return a; if (pb && !pb->value) return b; - return NullOpt; + return std::nullopt; } template <> @@ -419,7 +419,7 @@ inline Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && !pa->value) return b; if (pb && pb->value) return b; if (pb && !pb->value) return a; - return NullOpt; + return std::nullopt; } template <> @@ -428,7 +428,7 @@ inline Optional TryConstFold(PrimExpr a) { if (pa) { return IntImm(DataType::UInt(1), !(pa->value)); } - return NullOpt; + return std::nullopt; } /*! \brief Helper namespace for symbolic value limits */ diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index ecd3b25bfc67..a440b52074e8 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -21,7 +21,7 @@ * \file tvm/arith/const_int_bound.cc */ #include -#include +#include #include #include @@ -51,7 +51,7 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); +TVM_FFI_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); inline void PrintBoundValue(std::ostream& os, int64_t val) { if (val == ConstIntBound::kPosInf) { @@ -364,15 +364,16 @@ class ConstIntBoundAnalyzer::Impl // only special handle >> and & which can be // used for index calculation. + auto curr_target = Target::Current(); if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); } else if (op->op.same_as(tir::builtin::shift_left())) { return VisitLeftShift(op); } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); - } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) { - unsigned int max_val = - *std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end()); + } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) { + auto kVScaleValues = GetVScaleValues(curr_target); + unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); return MakeBound(1, max_val); } else { return Everything(op->dtype); @@ -751,7 +752,7 @@ class ConstIntBoundAnalyzer::Impl } } } - return NullOpt; + return std::nullopt; } /*! \brief Propagate constraints through ceil(log2(arg)) diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index b496e7fefca5..303360002e03 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -69,6 +69,6 @@ Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { return results; } -TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr); +TVM_FFI_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr); } // namespace arith } // namespace tvm diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 4d3164cbd382..0dcbc7623590 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -22,7 +22,7 @@ * \brief Utility to detect patterns in the expression. */ #include -#include +#include #include #include #include @@ -290,9 +290,9 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { return ret; } -TVM_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); +TVM_FFI_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); -TVM_REGISTER_GLOBAL("arith.DetectClipBound") +TVM_FFI_REGISTER_GLOBAL("arith.DetectClipBound") .set_body_typed([](const PrimExpr& e, const Array& vars) { return DetectClipBound(e, vars); }); diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 8c7c33bcc3ee..5f9d78003001 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -21,7 +21,7 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ -#include +#include #include #include #include @@ -135,9 +135,9 @@ Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores); } -Map> DomainTouchedAccessMap(const PrimFunc& func) { +Map> DomainTouchedAccessMap(const PrimFunc& func) { auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions(); - Map> ret; + Map> ret; auto& buffer_map = func->buffer_map; for (auto& var : func->params) { auto& buffer = buffer_map[var]; @@ -153,7 +153,7 @@ Map> DomainTouchedAccessMap(const PrimFunc& fu combined.push_back(Array(touch)); } - runtime::Array fields; + Array fields; fields.push_back(loads); fields.push_back(stores); fields.push_back(combined); @@ -162,8 +162,8 @@ Map> DomainTouchedAccessMap(const PrimFunc& fu return ret; } -TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); -TVM_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap); +TVM_FFI_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); +TVM_FFI_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap); } // namespace arith } // namespace tvm diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 8c314992ab49..01e7a3096927 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include @@ -195,15 +195,16 @@ Range IntGroupBounds::FindBestRange(const Map& vranges_addl) const { TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); -TVM_REGISTER_GLOBAL("arith.IntGroupBounds") +TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds") .set_body_typed([](PrimExpr coef, Array lower, Array equal, Array upper) { return IntGroupBounds(coef, lower, equal, upper); }); -TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange); +TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds_from_range") + .set_body_typed(IntGroupBounds::FromRange); -TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") +TVM_FFI_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK(args.size() == 1 || args.size() == 2); auto bounds = args[0].cast(); @@ -243,7 +244,7 @@ IntConstraints::IntConstraints(Array variables, Map ranges, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); -TVM_REGISTER_GLOBAL("arith.IntConstraints") +TVM_FFI_REGISTER_GLOBAL("arith.IntConstraints") .set_body_typed([](Array variables, Map ranges, Array relations) { return IntConstraints(variables, ranges, relations); }); @@ -288,7 +289,7 @@ IntConstraintsTransform IntConstraintsTransform::operator+( TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); -TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform") +TVM_FFI_REGISTER_GLOBAL("arith.IntConstraintsTransform") .set_body_typed([](IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src) { return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 587e0121f057..d3b7b30628a1 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include @@ -57,7 +57,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); +TVM_FFI_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); @@ -1080,7 +1080,7 @@ static Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& return IntSet::Nothing(); } if (!analyzer->CanProve(extent >= split->scale)) { - return NullOpt; + return std::nullopt; } const PrimExpr& base = iter_min->base; @@ -1110,7 +1110,7 @@ Optional> EstimateRegionStrictBound(const Array& region, for (const Range& range : region) { if (!is_const_number(range->extent)) { // dynamic extent is not supported yet. - return NullOpt; + return std::nullopt; } affine_indices.push_back(range->min); } @@ -1120,7 +1120,7 @@ Optional> EstimateRegionStrictBound(const Array& region, iter_sum_exprs = res->indices; } if (iter_sum_exprs.empty()) { - return NullOpt; + return std::nullopt; } ICHECK_EQ(iter_sum_exprs.size(), ndim); Array result; @@ -1132,7 +1132,7 @@ Optional> EstimateRegionStrictBound(const Array& region, if (int_set.defined()) { result.push_back(int_set.value()); } else { - return NullOpt; + return std::nullopt; } } return result; @@ -1192,42 +1192,42 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << "[" << op->min_value << ", " << op->max_value << ']'; }); -TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint); +TVM_FFI_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint); -TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector); +TVM_FFI_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector); -TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval); +TVM_FFI_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); +TVM_FFI_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); +TVM_FFI_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); -TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing); +TVM_FFI_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing); -TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything); +TVM_FFI_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything); -TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound") +TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionLowerBound") .set_body_typed([](Array region, Map var_dom, PrimExpr predicate) -> Optional> { Analyzer analyzer; return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); }); -TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound") +TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionStrictBound") .set_body_typed([](Array region, Map var_dom, PrimExpr predicate) -> Optional> { Analyzer analyzer; return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer); }); -TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound") +TVM_FFI_REGISTER_GLOBAL("arith.EstimateRegionUpperBound") .set_body_typed([](Array region, Map var_dom, PrimExpr predicate) -> Optional> { Analyzer analyzer; return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); }); -TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; }); -TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; }); -TVM_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound); +TVM_FFI_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; }); +TVM_FFI_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; }); +TVM_FFI_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound); } // namespace arith } // namespace tvm diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index d24c278f1048..2aa0ca6b6425 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -47,7 +47,7 @@ IterMark::IterMark(PrimExpr source, PrimExpr extent) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { +TVM_FFI_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { return IterMark(source, extent); }); @@ -92,7 +92,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr ex data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterSplitExpr") +TVM_FFI_REGISTER_GLOBAL("arith.IterSplitExpr") .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { return IterSplitExpr(source, lower_factor, extent, scale); }); @@ -114,7 +114,7 @@ IterSumExpr::IterSumExpr(Array args, PrimExpr base) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("arith.IterSumExpr") +TVM_FFI_REGISTER_GLOBAL("arith.IterSumExpr") .set_body_typed([](Array args, PrimExpr base) { return IterSumExpr(args, base); }); @@ -985,7 +985,7 @@ class IterMapRewriter : public ExprMutator { * \return The sum with the fused IterMark and extra offset if succeed. */ Optional TryCombineSplitFromSameSource(IterSumExpr expr) { - if (expr->args.size() <= 1) return NullOpt; + if (expr->args.size() <= 1) return std::nullopt; std::unordered_map hit_count; // most iter map are small n < 5 // so we can afford N^2 complexity @@ -1000,7 +1000,7 @@ class IterMapRewriter : public ExprMutator { hit_count[expr->args[i]->source] = 1; } } - if (!has_overlap) return NullOpt; + if (!has_overlap) return std::nullopt; std::vector visited(expr->args.size(), false); std::vector reverse_flattened_iters; @@ -1096,8 +1096,8 @@ class IterMapRewriter : public ExprMutator { } // select the iterators in order std::vector visited(expr->args.size(), false); - int base_index = FindBaseIter(expr, visited, NullOpt); - if (base_index == -1) return NullOpt; + int base_index = FindBaseIter(expr, visited, std::nullopt); + if (base_index == -1) return std::nullopt; PrimExpr base_scale = expr->args[base_index]->scale; std::vector flattened_iters, grouped_iters; @@ -1114,8 +1114,8 @@ class IterMapRewriter : public ExprMutator { // find position such that expr->args[j] match expected scale // if it is first step, we can simply start with base index int matched_pos = i == 0 ? base_index - : FindIterWithExactScale(expr, visited, expected_scale, NullOpt, -1, - first_possible_unit_extent_pos); + : FindIterWithExactScale(expr, visited, expected_scale, std::nullopt, + -1, first_possible_unit_extent_pos); if (matched_pos != -1) { matched_scale = expected_scale; is_exact_match = true; @@ -1129,7 +1129,7 @@ class IterMapRewriter : public ExprMutator { } } if (matched_pos == -1) { - return NullOpt; + return std::nullopt; } ICHECK(matched_scale.defined()); // look for the longest constrained iter started from expr->args[j] @@ -1163,7 +1163,7 @@ class IterMapRewriter : public ExprMutator { } } if (k == expr->args.size()) { - return NullOpt; + return std::nullopt; } visited[k] = true; flattened_iters.push_back(expr->args[k]); @@ -1209,7 +1209,7 @@ class IterMapRewriter : public ExprMutator { // old iter if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale)) { // the extra offset is not consistent with old - return NullOpt; + return std::nullopt; } return IterSumExpr({IterSplitExpr(it->second.mark, base_scale)}, expr->base + expected_extra_base); @@ -1372,7 +1372,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, lhs_expr = analyzer.Simplify(lhs_expr); rhs_expr = analyzer.Simplify(rhs_expr); } - Optional lower_bound = NullOpt, upper_bound = NullOpt; + Optional lower_bound = std::nullopt, upper_bound = std::nullopt; PrimExpr iter; if (is_greater) { if (bound_at_left) { @@ -1513,7 +1513,7 @@ IterMapResult DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { @@ -1538,7 +1538,7 @@ IterSumExpr NormalizeToIterSum(PrimExpr index, const Map& input_iter return rewriter.RewriteToNormalizedIterSum(index); } -TVM_REGISTER_GLOBAL("arith.NormalizeToIterSum") +TVM_FFI_REGISTER_GLOBAL("arith.NormalizeToIterSum") .set_body_typed([](PrimExpr index, const Map& input_iters) { arith::Analyzer ana; return NormalizeToIterSum(index, input_iters, &ana); @@ -2133,7 +2133,7 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { return normalizer.Convert(expr); } -TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); +TVM_FFI_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr); Array IterMapSimplify(const Array& indices, const Map& input_iters, const PrimExpr& input_pred, IterMapLevel check_level, @@ -2162,7 +2162,7 @@ Array IterMapSimplify(const Array& indices, const Map& indices, const Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { @@ -2495,7 +2495,7 @@ Array> SubspaceDivide(const Array& bindings, return results; } -TVM_REGISTER_GLOBAL("arith.SubspaceDivide") +TVM_FFI_REGISTER_GLOBAL("arith.SubspaceDivide") .set_body_typed([](const Array& bindings, const Map& root_iters, const Array& sub_iters, const PrimExpr& predicate, int check_level, bool simplify_trivial_iterators) { @@ -2634,7 +2634,7 @@ Map InverseAffineIterMap(const Array& iter_map, return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); } -TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); +TVM_FFI_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); TVM_REGISTER_NODE_TYPE(IterMapResultNode); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 197e5ec8b868..fa4891d5a00b 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -22,7 +22,7 @@ * \brief Modular set analysis */ #include -#include +#include #include #include #include @@ -57,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); +TVM_FFI_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); // internal entry for const int bound struct ModularSetAnalyzer::Entry { diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index 40c7ab3c54ac..a1a9768110ed 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -22,7 +22,7 @@ * \brief Utility to deduce bound of expression */ #include -#include +#include #include #include #include @@ -212,7 +212,8 @@ PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameter return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); } -TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression); +TVM_FFI_REGISTER_GLOBAL("arith.NarrowPredicateExpression") + .set_body_typed(NarrowPredicateExpression); } // namespace arith } // namespace tvm diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 4f4d7e18578f..e514ad1b1ad7 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -272,7 +272,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); } -TVM_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet); +TVM_FFI_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet); TVM_REGISTER_NODE_TYPE(PresburgerSetNode); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 3682054e8e4b..c911124700fe 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -446,10 +446,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // mul co-efficient folding TVM_TRY_REWRITE(x + x, x * 2); - TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), x * (y + 1)); + TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), (y + 1) * x); TVM_TRY_REWRITE(matches_one_of(x * y + x * z, y * x + x * z, x * y + z * x, y * x + z * x), - x * (y + z)); + (y + z) * x); // DivMod rules // truc div @@ -563,12 +563,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(matches_one_of(max(x, y) - y, x - min(y, x)), max(x - y, 0)); TVM_TRY_REWRITE(matches_one_of(x - min(x, y), max(y, x) - y), max(0, x - y)); - // mul co-efficient folding + // mul co-efficient folding: pefer co-effiicent to stay at rhs TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x)); - TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), x * (y - 1)); - TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), x * (1 - y)); + TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), (y - 1) * x); + TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), (1 - y) * x); TVM_TRY_REWRITE(matches_one_of(x * y - x * z, y * x - x * z, x * y - z * x, y * x - z * x), - x * (y - z)); + (y - z) * x); // constant cancelation TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); @@ -1662,7 +1662,7 @@ Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint(const Prim return make_const(expr->dtype, false); } } - return NullOpt; + return std::nullopt; } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { @@ -1948,7 +1948,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { auto [lhs, lhs_offset] = ExtractConstantOffset(ret->a); auto [rhs, rhs_offset] = ExtractConstantOffset(ret->b); if (lhs_offset == 0 && rhs_offset == 0) { - return NullOpt; + return std::nullopt; } int64_t diff = rhs_offset - lhs_offset; @@ -1962,7 +1962,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { return lhs < rhs + make_const(rhs.dtype(), diff); } - return NullOpt; + return std::nullopt; }(); if (merge_constants) { return RecursiveRewrite(merge_constants.value()); diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index beb75c1f3e09..1937b9c34e03 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -86,14 +86,41 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } -bool TargetHasSVE(Optional target) { +bool TargetHasVLA(Optional target) { if (!target.defined()) { target = Target::Current(); } + bool has_vla{false}; if (target.defined()) { - return Downcast(target)->GetFeature("has_sve").value_or(Bool(false)); + // aarch64 + has_vla = Downcast(target)->GetFeature("has_sve").value_or(Bool(false)); + // riscv{32,64} + static auto target_has_feature_fn = + tvm::ffi::Function::GetGlobalRequired("target.target_has_feature"); + has_vla |= target_has_feature_fn("v", target).cast(); } - return false; + return has_vla; +} + +const std::vector GetVScaleValues(Optional target) { + unsigned int vector_width = 0; + std::vector kVScaleValues; + if (!target.defined()) { + target = Target::Current(); + } + if (target.defined()) { + static auto llvm_get_vector_width_fn = + tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width"); + vector_width = llvm_get_vector_width_fn(target).cast(); + } + // scale list with powers of two + for (unsigned int i = 0;; ++i) { + auto power = static_cast(std::pow(2, i)); + if (power > (vector_width / 8)) break; + kVScaleValues.push_back(power); + } + + return kVScaleValues; } } // namespace arith diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index d31e81fffc97..2470d5dcd827 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -35,9 +35,6 @@ namespace tvm { namespace arith { -/*! \brief A list of known vscale values to try for an AArch64 SVE target. */ -static const std::vector kAArch64VScaleValues = {1, 2, 4, 8, 16}; - /*! * \brief Check if an expr is a call to the vscale intrinsic. * \param expr The expr to check @@ -80,10 +77,18 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr /*! * \brief Check whether the compilation target supports SVE + * \brief Check whether the compilation target supports VLA + * \param target The target to check. + * \return Whether VLA is supported + */ +bool TargetHasVLA(Optional target = std::nullopt); + +/*! + * \brief Get a list of known vscale values to try for an VLA target. * \param target The target to check. - * \return Whether SVE is supported + * \return A list of vscale values as std::vector */ -bool TargetHasSVE(Optional target = NullOpt); +const std::vector GetVScaleValues(Optional target = std::nullopt); } // namespace arith } // namespace tvm diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index fb6250a778ef..4d90c61ea3cb 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -24,8 +24,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -454,7 +454,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol return transform; } -TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") +TVM_FFI_REGISTER_GLOBAL("arith.SolveLinearEquations") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveLinearEquations(args[0].cast()); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 0e5e6d485e74..62f314d1902f 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -24,8 +24,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -535,7 +535,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ return transform; } -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") +TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { IntConstraints problem; PartialSolvedInequalities ret_ineq; @@ -553,7 +553,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition") *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); }); -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") +TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveInequalitiesToRange(args[0].cast()); @@ -568,7 +568,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange") } }); -TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") +TVM_FFI_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { *ret = SolveInequalitiesDeskewRange(args[0].cast()); diff --git a/src/contrib/msc/core/codegen/code_stack.cc b/src/contrib/msc/core/codegen/code_stack.cc index 4b1c1850e76e..041ffe7091b2 100644 --- a/src/contrib/msc/core/codegen/code_stack.cc +++ b/src/contrib/msc/core/codegen/code_stack.cc @@ -64,7 +64,7 @@ void BaseStack::FuncDef(const String& func_name, const String& ret_type) { PushDoc(FunctionDoc(IdDoc(func_name), Array(), Array(), IdDoc(ret_type), Array())); } else { - PushDoc(FunctionDoc(IdDoc(func_name), Array(), Array(), NullOpt, + PushDoc(FunctionDoc(IdDoc(func_name), Array(), Array(), std::nullopt, Array())); } } @@ -214,13 +214,13 @@ void BaseStack::FuncCall(const String& callee, Optional assign_to, void BaseStack::FuncCall(const String& callee, const String& assign_to, const String& caller) { Optional assign_doc; if (assign_to.size() == 0) { - assign_doc = NullOpt; + assign_doc = std::nullopt; } else { assign_doc = IdDoc(assign_to); } Optional caller_doc; if (caller.size() == 0) { - caller_doc = NullOpt; + caller_doc = std::nullopt; } else { caller_doc = IdDoc(caller); } @@ -231,7 +231,7 @@ void BaseStack::MethodCall(const String& callee, bool new_line) { const auto& host = PopDoc(); if (host->IsInstance()) { const auto& v_callee = callee + (new_line ? DocSymbol::NextLine() : ""); - FuncCall(v_callee, NullOpt, Downcast(host)); + FuncCall(v_callee, std::nullopt, Downcast(host)); } else if (const auto* a_node = host.as()) { ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, Array(), true), @@ -411,7 +411,7 @@ void BaseStack::ScopeStart(const String& scope_def, const String& scope_ref) { if (scope_ref.size() > 0) { PushDoc(ScopeDoc(IdDoc(scope_ref), IdDoc(scope_def), Array())); } else { - PushDoc(ScopeDoc(NullOpt, IdDoc(scope_def), Array())); + PushDoc(ScopeDoc(std::nullopt, IdDoc(scope_def), Array())); } BlockStart(); } diff --git a/src/contrib/msc/core/codegen/code_stack.h b/src/contrib/msc/core/codegen/code_stack.h index e348bcdab1bf..ff4e6b58247a 100644 --- a/src/contrib/msc/core/codegen/code_stack.h +++ b/src/contrib/msc/core/codegen/code_stack.h @@ -157,7 +157,7 @@ class BaseStack { /*! \brief Push call and maybe assign Doc*/ void FuncCall(const String& callee, Optional assign_to, - Optional caller = NullOpt); + Optional caller = std::nullopt); void FuncCall(const String& callee, const String& assign_to = "", const String& caller = ""); /*! \brief Push method call Doc*/ @@ -165,7 +165,7 @@ class BaseStack { /*! \brief Push inplace call and maybe assign Doc*/ void InplaceStart(const String& callee, Optional assign_to, - Optional caller = NullOpt); + Optional caller = std::nullopt); void InplaceStart(const String& callee, const String& assign_to = "", const String& caller = ""); /*! \brief End inplace call*/ @@ -355,7 +355,7 @@ class BaseStack { return *this; \ } \ Stack& func_call(const String& callee, Optional assign_to, \ - Optional caller = NullOpt) { \ + Optional caller = std::nullopt) { \ FuncCall(callee, assign_to, caller); \ return *this; \ } \ @@ -369,7 +369,7 @@ class BaseStack { return *this; \ } \ Stack& inplace_start(const String& callee, Optional assign_to, \ - Optional caller = NullOpt) { \ + Optional caller = std::nullopt) { \ InplaceStart(callee, assign_to, caller); \ return *this; \ } \ diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 14b2114236e5..d38beab5b4ed 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -78,7 +78,7 @@ const JsonMSCTensor MSCTensorNode::ToJson() const { void MSCTensorNode::FromJson(const JsonMSCTensor& j_tensor) { name = j_tensor.name; alias = j_tensor.alias; - dtype = DataType(runtime::StringToDLDataType(j_tensor.dtype)); + dtype = DataType(ffi::StringToDLDataType(j_tensor.dtype)); if (j_tensor.layout.size() > 0) { layout = tvm::tir::Layout(j_tensor.layout); } @@ -1431,14 +1431,14 @@ TVM_REGISTER_NODE_TYPE(MSCGraphNode); TVM_REGISTER_NODE_TYPE(WeightGraphNode); -TVM_REGISTER_GLOBAL("msc.core.MSCTensor") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensor") .set_body_typed([](const String& name, const DataType& dtype, const String& layout, const Array& shape, const String& alias, const Array& prims) -> MSCTensor { return MSCTensor(name, dtype, layout, shape, alias, prims); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorToJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorToJson") .set_body_typed([](const MSCTensor& tensor) -> String { const auto& tensor_json = tensor->ToJson(); std::ostringstream os; @@ -1447,10 +1447,10 @@ TVM_REGISTER_GLOBAL("msc.core.MSCTensorToJson") return os.str(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorFromJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorFromJson") .set_body_typed([](const String& tensor_json) -> MSCTensor { return MSCTensor(tensor_json); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJoint") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJoint") .set_body_typed([](Integer index, const String& name, const String& shared_ref, const String& optype, const Map& attrs, const Array& scope, const Array& parents, @@ -1464,7 +1464,7 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJoint") weights); }); -TVM_REGISTER_GLOBAL("msc.core.MSCPrim") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCPrim") .set_body_typed([](Integer index, const String& name, const String& optype, const Map& attrs, const Array& parents) -> MSCPrim { Array b_parents; @@ -1474,7 +1474,7 @@ TVM_REGISTER_GLOBAL("msc.core.MSCPrim") return MSCPrim(index->value, name, optype, b_parents, attrs); }); -TVM_REGISTER_GLOBAL("msc.core.WeightJoint") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJoint") .set_body_typed([](Integer index, const String& name, const String& shared_ref, const String& weight_type, const MSCTensor& weight, const Array parents, const Map& attrs, @@ -1490,108 +1490,109 @@ TVM_REGISTER_GLOBAL("msc.core.WeightJoint") b_friends); }); -TVM_REGISTER_GLOBAL("msc.core.WeightJointSetAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointSetAttr") .set_body_typed([](const WeightJoint& node, const String& key, const String& value) { node->attrs.Set(key, value); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraph") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraph") .set_body_typed([](const String& name, const Array& nodes, const Array& input_names, const Array& output_names, const Array& prims) -> MSCGraph { return MSCGraph(name, nodes, input_names, output_names, prims); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraph") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraph") .set_body_typed([](const MSCGraph& graph, const Map>& main_wtypes, const Map& relation_wtypes) -> WeightGraph { return WeightGraph(graph, main_wtypes, relation_wtypes); }); // MSC Graph APIS -TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphHasNode") .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { return Bool(graph->HasNode(name)); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindNode") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { return graph->FindNode(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindPrim") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindPrim") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCPrim { return graph->FindPrim(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphHasTensor") .set_body_typed([](const MSCGraph& graph, const String& name) -> Bool { return Bool(graph->HasTensor(name)); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindTensor") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCTensor { return graph->FindTensor(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphSetTensorAlias") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphSetTensorAlias") .set_body_typed([](const MSCGraph& graph, const MSCTensor& tensor, const String& alias) { tensor->alias = alias; graph->tensor_alias.Set(alias, tensor->name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindProducer") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindProducer") .set_body_typed([](const MSCGraph& graph, const String& name) -> MSCJoint { return graph->FindProducer(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFindConsumers") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFindConsumers") .set_body_typed([](const MSCGraph& graph, const String& name) -> Array { return graph->FindConsumers(name); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphInputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphInputAt") .set_body_typed([](const MSCGraph& graph, int index) -> MSCTensor { return graph->InputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphOutputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphOutputAt") .set_body_typed([](const MSCGraph& graph, int index) -> MSCTensor { return graph->OutputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphGetInputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphGetInputs") .set_body_typed([](const MSCGraph& graph) -> Array { return graph->GetInputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphGetOutputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphGetOutputs") .set_body_typed([](const MSCGraph& graph) -> Array { return graph->GetOutputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphToJson").set_body_typed([](const MSCGraph& graph) -> String { - const auto& graph_json = graph->ToJson(); - std::ostringstream os; - dmlc::JSONWriter writer(&os); - graph_json.Save(&writer); - return os.str(); -}); +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphToJson") + .set_body_typed([](const MSCGraph& graph) -> String { + const auto& graph_json = graph->ToJson(); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + graph_json.Save(&writer); + return os.str(); + }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphFromJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphFromJson") .set_body_typed([](const String& graph_json) -> MSCGraph { return MSCGraph(graph_json); }); -TVM_REGISTER_GLOBAL("msc.core.MSCGraphToPrototxt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCGraphToPrototxt") .set_body_typed([](const MSCGraph& graph) -> String { return graph->ToPrototxt(); }); // Weight Graph APIS -TVM_REGISTER_GLOBAL("msc.core.WeightGraphHasNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphHasNode") .set_body_typed([](const WeightGraph& graph, const String& name) -> Bool { return Bool(graph->HasNode(name)); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphFindNode") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphFindNode") .set_body_typed([](const WeightGraph& graph, const String& name) -> WeightJoint { return graph->FindNode(name); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphToJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphToJson") .set_body_typed([](const WeightGraph& graph) -> String { const auto& graph_json = graph->ToJson(); std::ostringstream os; @@ -1600,69 +1601,69 @@ TVM_REGISTER_GLOBAL("msc.core.WeightGraphToJson") return os.str(); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphFromJson") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphFromJson") .set_body_typed([](const String& graph_json) -> WeightGraph { return WeightGraph(graph_json); }); -TVM_REGISTER_GLOBAL("msc.core.WeightGraphToPrototxt") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightGraphToPrototxt") .set_body_typed([](const WeightGraph& graph) -> String { return graph->ToPrototxt(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointInputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointInputAt") .set_body_typed([](const MSCJoint& node, int index) -> MSCTensor { return node->InputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointOutputAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointOutputAt") .set_body_typed([](const MSCJoint& node, int index) -> MSCTensor { return node->OutputAt(index); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointWeightAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointWeightAt") .set_body_typed([](const MSCJoint& node, const String& wtype) -> MSCTensor { return node->WeightAt(wtype); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetInputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetInputs") .set_body_typed([](const MSCJoint& node) -> Array { return node->GetInputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetOutputs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetOutputs") .set_body_typed([](const MSCJoint& node) -> Array { return node->GetOutputs(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetWeights") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetWeights") .set_body_typed([](const MSCJoint& node) -> Map { return node->weights; }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointHasAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointHasAttr") .set_body_typed([](const MSCJoint& node, const String& key) -> Bool { return Bool(node->HasAttr(key)); }); -TVM_REGISTER_GLOBAL("msc.core.MSCJointGetAttrs") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCJointGetAttrs") .set_body_typed([](const MSCJoint& node) -> Map { return node->attrs; }); -TVM_REGISTER_GLOBAL("msc.core.WeightJointHasAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointHasAttr") .set_body_typed([](const WeightJoint& node, const String& key) -> Bool { return Bool(node->HasAttr(key)); }); -TVM_REGISTER_GLOBAL("msc.core.WeightJointGetAttrs") +TVM_FFI_REGISTER_GLOBAL("msc.core.WeightJointGetAttrs") .set_body_typed([](const WeightJoint& node) -> Map { return node->attrs; }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorDTypeName") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorDTypeName") .set_body_typed([](const MSCTensor& tensor) -> String { return tensor->DTypeName(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorDimAt") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorDimAt") .set_body_typed([](const MSCTensor& tensor, const String& axis) -> Integer { return tensor->DimAt(axis); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorGetSize") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorGetSize") .set_body_typed([](const MSCTensor& tensor) -> Integer { return tensor->GetSize(); }); -TVM_REGISTER_GLOBAL("msc.core.MSCTensorSetAlias") +TVM_FFI_REGISTER_GLOBAL("msc.core.MSCTensorSetAlias") .set_body_typed([](const MSCTensor& tensor, const String& alias) { tensor->alias = alias; }); -TVM_REGISTER_GLOBAL("msc.core.PruneWeights") +TVM_FFI_REGISTER_GLOBAL("msc.core.PruneWeights") .set_body_typed([](const MSCGraph& graph, const Map& pruned_tensors) -> MSCGraph { return PruneWeights(graph, pruned_tensors); diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index 853f75216f1c..2550f5652fc7 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -124,7 +124,7 @@ void LayoutsFinder::VisitExpr_(const CallNode* call_node) { func = local_funcs_[call_node->op]; } if (func.defined()) { - const auto& layouts_opt = func->GetAttr>(msc_attr::kInputLayouts); + const auto& layouts_opt = func->GetAttr>(msc_attr::kInputLayouts); if (layouts_opt.defined()) { for (const auto& pair : layouts_opt.value()) { layouts_.Set(pair.first, pair.second); @@ -172,7 +172,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { if (expr_tensor_map_.count(f)) { LOG_INFO << "Replica tuple input " << f; } else if (const auto* f_node = f.as()) { - AddNode(f, NullOpt, f_node->name_hint()); + AddNode(f, std::nullopt, f_node->name_hint()); } else { LOG_FATAL << "Unexpected tuple input " << f << "(" << f->GetTypeKey() << ")"; } @@ -183,7 +183,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } expr_tensor_map_.Set(p, tuple_names); } else { - AddNode(p, NullOpt, p->name_hint()); + AddNode(p, std::nullopt, p->name_hint()); } ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p; for (const auto& name : expr_tensor_map_[p]) { @@ -339,7 +339,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } else if (const auto* call_node = expr.as()) { if (const auto* v_node = call_node->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - const auto& name_opt = func->GetAttr(relax::attr::kComposite); + const auto& name_opt = func->GetAttr(relax::attr::kComposite); if (name_opt.defined()) { attrs = FuncAttrGetter().GetAttrs(func); } @@ -553,7 +553,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional& bin } else if (const auto* s_sinfo = sinfo.as()) { Array shape{s_sinfo->ndim}; const auto& t_name = node_name + ":" + std::to_string(0); - const auto& dtype = DataType(runtime::StringToDLDataType("int32")); + const auto& dtype = DataType(ffi::StringToDLDataType("int32")); outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape)); } else if (const auto* tuple_sinfo = sinfo.as()) { size_t field_size = optype == "nn.batch_norm" ? 1 : num_output; @@ -757,7 +757,7 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVa } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { - const auto& name_opt = val->GetAttr(relax::attr::kComposite); + const auto& name_opt = val->GetAttr(relax::attr::kComposite); ICHECK(name_opt.defined()) << "Unexpected target func without composite"; ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target)) << "Target should be given for target function"; @@ -766,15 +766,15 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNo const std::tuple GraphBuilder::ParseFunc(const Function& func) { String node_name, optype, layout; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); // get node_name if (name_opt.defined()) { node_name = name_opt.value(); } // get optype - const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); - const auto& optype_opt = func->GetAttr(msc_attr::kOptype); - const auto& composite_opt = func->GetAttr(relax::attr::kComposite); + const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + const auto& composite_opt = func->GetAttr(relax::attr::kComposite); if (codegen_opt.defined()) { optype = codegen_opt.value(); } else if (optype_opt.defined()) { @@ -786,7 +786,7 @@ const std::tuple GraphBuilder::ParseFunc(const Function& } } // get layout - const auto& layout_opt = func->GetAttr(msc_attr::kLayout); + const auto& layout_opt = func->GetAttr(msc_attr::kLayout); if (layout_opt.defined()) { layout = layout_opt.value(); } @@ -834,7 +834,7 @@ void WeightsExtractor::VisitExpr_(const CallNode* op) { } } -TVM_REGISTER_GLOBAL("msc.core.BuildFromRelax") +TVM_FFI_REGISTER_GLOBAL("msc.core.BuildFromRelax") .set_body_typed([](const IRModule& module, const String& entry_name, const String& options) -> MSCGraph { auto builder = GraphBuilder(module, entry_name, options); @@ -844,7 +844,7 @@ TVM_REGISTER_GLOBAL("msc.core.BuildFromRelax") return builder.Build(func); }); -TVM_REGISTER_GLOBAL("msc.core.GetRelaxWeights") +TVM_FFI_REGISTER_GLOBAL("msc.core.GetRelaxWeights") .set_body_typed([](const IRModule& module, const String& entry_name) -> Map { const auto& func = Downcast(module->Lookup(entry_name)); diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 00582fab4b00..4eac04349728 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -273,7 +273,7 @@ class GraphBuilder : public ExprVisitor { const MSCRBuildConfig config() { return config_; } /*! \brief Create and add MSCJoint from expr*/ - const MSCJoint AddNode(const Expr& expr, const Optional& binding_var = NullOpt, + const MSCJoint AddNode(const Expr& expr, const Optional& binding_var = std::nullopt, const String& name = ""); /*! \brief Create and add MSCPrim from prim*/ diff --git a/src/contrib/msc/core/ir/plugin.cc b/src/contrib/msc/core/ir/plugin.cc index d34972639a7b..fc6000a20f3d 100644 --- a/src/contrib/msc/core/ir/plugin.cc +++ b/src/contrib/msc/core/ir/plugin.cc @@ -305,20 +305,20 @@ const Plugin GetPlugin(const String& name) { return PluginRegistry::Global()->Ge bool IsPlugin(const String& name) { return PluginRegistry::Global()->Registered(name); } -TVM_REGISTER_GLOBAL("msc.core.RegisterPlugin") +TVM_FFI_REGISTER_GLOBAL("msc.core.RegisterPlugin") .set_body_typed([](const String& name, const String& json_str) { PluginRegistry::Global()->Register(name, json_str); }); -TVM_REGISTER_GLOBAL("msc.core.ListPluginNames").set_body_typed([]() -> Array { +TVM_FFI_REGISTER_GLOBAL("msc.core.ListPluginNames").set_body_typed([]() -> Array { return ListPluginNames(); }); -TVM_REGISTER_GLOBAL("msc.core.GetPlugin").set_body_typed([](const String& name) -> Plugin { +TVM_FFI_REGISTER_GLOBAL("msc.core.GetPlugin").set_body_typed([](const String& name) -> Plugin { return GetPlugin(name); }); -TVM_REGISTER_GLOBAL("msc.core.IsPlugin").set_body_typed([](const String& name) -> Bool { +TVM_FFI_REGISTER_GLOBAL("msc.core.IsPlugin").set_body_typed([](const String& name) -> Bool { return Bool(IsPlugin(name)); }); diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 838d284d131f..0f0b24fd3a28 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -113,7 +113,7 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { } else { output_ << float_imm->value; } - } else if (const auto* string_obj = value.as()) { + } else if (const auto* string_obj = value.as()) { output_ << "\"" << tvm::support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); diff --git a/src/contrib/msc/core/printer/print_utils.cc b/src/contrib/msc/core/printer/print_utils.cc index 95df8da85cb2..234ca3aec9c3 100644 --- a/src/contrib/msc/core/printer/print_utils.cc +++ b/src/contrib/msc/core/printer/print_utils.cc @@ -32,7 +32,7 @@ const String DocSymbol::Empty() { return "::EMPTY"; } const String DocSymbol::NextLine() { return "::NEXT_LINE"; } -const ExprDoc DocUtils::ToDoc(int64_t val) { return LiteralDoc::Int(val, NullOpt); } +const ExprDoc DocUtils::ToDoc(int64_t val) { return LiteralDoc::Int(val, std::nullopt); } const ExprDoc DocUtils::ToDoc(int val) { return ToDoc(static_cast(val)); } @@ -42,7 +42,7 @@ const ExprDoc DocUtils::ToDoc(const IntImm& val) { return ToDoc(val->value); } const ExprDoc DocUtils::ToDoc(const Integer& val) { return ToDoc(val->value); } -const ExprDoc DocUtils::ToDoc(double val) { return LiteralDoc::Float(val, NullOpt); } +const ExprDoc DocUtils::ToDoc(double val) { return LiteralDoc::Float(val, std::nullopt); } const ExprDoc DocUtils::ToDoc(float val) { return ToDoc(static_cast(val)); } @@ -52,11 +52,11 @@ const ExprDoc DocUtils::ToDoc(const char* val) { return IdDoc(std::string(val)); const ExprDoc DocUtils::ToDoc(const String& val) { return IdDoc(val); } -const ExprDoc DocUtils::ToDoc(bool val) { return LiteralDoc::Boolean(val, NullOpt); } +const ExprDoc DocUtils::ToDoc(bool val) { return LiteralDoc::Boolean(val, std::nullopt); } const ExprDoc DocUtils::ToDoc(const ExprDoc& val) { return val; } -const ExprDoc DocUtils::ToStr(const String& val) { return LiteralDoc::Str(val, NullOpt); } +const ExprDoc DocUtils::ToStr(const String& val) { return LiteralDoc::Str(val, std::nullopt); } const PointerDoc DocUtils::ToPtr(const String& val) { return PointerDoc(val); } diff --git a/src/contrib/msc/core/printer/print_utils.h b/src/contrib/msc/core/printer/print_utils.h index 17d4b0b077a6..b3949d54a762 100644 --- a/src/contrib/msc/core/printer/print_utils.h +++ b/src/contrib/msc/core/printer/print_utils.h @@ -83,7 +83,7 @@ class DocUtils { bool use_constructor = true) { Optional type_doc; if (type.size() == 0) { - type_doc = NullOpt; + type_doc = std::nullopt; } else { type_doc = IdDoc(type); } @@ -103,7 +103,7 @@ class DocUtils { TVM_DLL static const AssignDoc ToAssign(const LT& lhs, const RT& rhs, const String& annotation = "") { if (annotation.size() == 0) { - return AssignDoc(ToDoc(lhs), ToDoc(rhs), NullOpt); + return AssignDoc(ToDoc(lhs), ToDoc(rhs), std::nullopt); } return AssignDoc(ToDoc(lhs), ToDoc(rhs), IdDoc(annotation)); } @@ -114,13 +114,13 @@ class DocUtils { if (rhs.size() > 0) { rhs_doc = IdDoc(rhs); } else { - rhs_doc = NullOpt; + rhs_doc = std::nullopt; } Optional annotation_doc; if (annotation.size() > 0) { annotation_doc = IdDoc(annotation); } else { - annotation_doc = NullOpt; + annotation_doc = std::nullopt; } return AssignDoc(ToDoc(lhs), rhs_doc, annotation_doc); } diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index 06a70a5498c1..44a915ae7b8d 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -31,16 +31,16 @@ namespace contrib { namespace msc { LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) { - if (obj.as()) { - return LiteralDoc::Str(Downcast(obj), NullOpt); + if (obj.as()) { + return LiteralDoc::Str(Downcast(obj), std::nullopt); } else if (obj.as()) { - return LiteralDoc::Int(Downcast(obj)->value, NullOpt); + return LiteralDoc::Int(Downcast(obj)->value, std::nullopt); } else if (obj.as()) { - return LiteralDoc::Float(Downcast(obj)->value, NullOpt); + return LiteralDoc::Float(Downcast(obj)->value, std::nullopt); } std::ostringstream obj_des; obj_des << obj; - return LiteralDoc::Str(obj_des.str(), NullOpt); + return LiteralDoc::Str(obj_des.str(), std::nullopt); } DictDoc PrototxtPrinter::ToDictDoc(const Map& dict) { diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 523db32b3a8d..0225ff319097 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -154,7 +154,7 @@ Pass BindNamedParams(String func_name, Map params) { return CreateModulePass(pass_func, 0, "BindNamedParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindNamedParams").set_body_typed(BindNamedParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindNamedParams").set_body_typed(BindNamedParams); } // namespace transform diff --git a/src/contrib/msc/core/transform/bind_shape.cc b/src/contrib/msc/core/transform/bind_shape.cc index ca91b3424e0d..b554e08ab820 100644 --- a/src/contrib/msc/core/transform/bind_shape.cc +++ b/src/contrib/msc/core/transform/bind_shape.cc @@ -132,7 +132,7 @@ Pass BindShape(const String& entry_name) { return CreateModulePass(pass_func, 0, "BindShape", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindShape").set_body_typed(BindShape); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindShape").set_body_typed(BindShape); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index ef7fafbe164c..1eabf3306f36 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -53,7 +53,7 @@ class TupleFuser : public ExprMutator { if (gv->name_hint == entry_name_) { main_var = gv; } else { - const auto& name_opt = func->GetAttr(attr::kComposite); + const auto& name_opt = func->GetAttr(attr::kComposite); if (name_opt.defined() && StringUtils::StartsWith(name_opt.value(), target_)) { target_funcs_.Set(gv, Downcast(func)); } @@ -74,8 +74,7 @@ class TupleFuser : public ExprMutator { const auto& arg = val->args[i]; if (arg->IsInstance()) { String tuple_name; - const auto& name_opt = - target_funcs_[val->op]->GetAttr(msc_attr::kUnique); + const auto& name_opt = target_funcs_[val->op]->GetAttr(msc_attr::kUnique); if (name_opt.defined()) { if (val->args.size() == 1) { tuple_name = name_opt.value() + "_input"; @@ -185,19 +184,19 @@ class TupleFuser : public ExprMutator { func_attrs.Set(attr::kComposite, target_ + func_name); func_attrs.Set(msc_attr::kUnique, SpanUtils::GetAttr(expr_span, msc_attr::kName)); - Function function = Function(/*params=*/params, // - /*body=*/body, // - /*ret_struct_info=*/NullOpt, // - /*is_pure=*/true, // + Function function = Function(/*params=*/params, // + /*body=*/body, // + /*ret_struct_info=*/std::nullopt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(func_attrs)); Array free_vars = FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { params.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); - function = Function(/*params=*/params, // - /*body=*/body, // - /*ret_struct_info=*/NullOpt, // - /*is_pure=*/true, // + function = Function(/*params=*/params, // + /*body=*/body, // + /*ret_struct_info=*/std::nullopt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(func_attrs)); } function = SymbolicVarRenewMutator::Renew(function); @@ -232,7 +231,7 @@ Pass FuseTuple(const String& target, const String& entry_name) { return CreateModulePass(pass_func, 0, "FuseTuple", {}); } -TVM_REGISTER_GLOBAL("relax.transform.FuseTuple").set_body_typed(FuseTuple); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseTuple").set_body_typed(FuseTuple); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc index aa07c4b66631..a91eb590af26 100644 --- a/src/contrib/msc/core/transform/inline_params.cc +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -61,7 +61,7 @@ class ParamsInliner : public ExprMutator { continue; } if (struct_info->IsInstance()) { - const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); ICHECK(optype_opt.defined()) << "Can not find attr " << msc_attr::kOptype << " form extern func"; extern_types_.Set(p, optype_opt.value()); @@ -184,7 +184,7 @@ Pass InlineParams(const String& entry_name) { return CreateModulePass(pass_func, 0, "InlineParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.InlineParams").set_body_typed(InlineParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.InlineParams").set_body_typed(InlineParams); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/rewrite_utils.cc b/src/contrib/msc/core/transform/rewrite_utils.cc index 20e4821e6fa7..9cbc7c1a8c51 100644 --- a/src/contrib/msc/core/transform/rewrite_utils.cc +++ b/src/contrib/msc/core/transform/rewrite_utils.cc @@ -44,7 +44,7 @@ Expr RewriteUtils::MakeConstant(BlockBuilder builder, const String& name, double const DataType& dtype, size_t ndim) { const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value)); Span span = SpanUtils::CreateWithAttr(msc_attr::kName, name); - const auto& constant = Constant(data, NullOpt, span); + const auto& constant = Constant(data, std::nullopt, span); if (ndim == 0) { return constant; } diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index b7bbbcac46c5..4755ebf38960 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -53,7 +53,7 @@ class ByocNameSetter : public ExprMutator { if (gv->name_hint == entry_name_) { continue; } - const auto& name_opt = func->GetAttr(attr::kCodegen); + const auto& name_opt = func->GetAttr(attr::kCodegen); if (name_opt.defined() && name_opt.value() == target_) { const String& func_name = target_ + "_" + std::to_string(func_cnt); const auto& new_func = Downcast(VisitExpr(func)); @@ -73,7 +73,7 @@ class ByocNameSetter : public ExprMutator { ExprMutator::VisitBinding_(binding, val); if (val->op->IsInstance()) { ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; - const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); + const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); if (name_opt.defined()) { val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value()); } @@ -101,7 +101,7 @@ Pass SetBYOCAttrs(const String& target, const String& entry_name) { return CreateModulePass(pass_func, 0, "SetBYOCAttrs", {}); } -TVM_REGISTER_GLOBAL("relax.transform.SetBYOCAttrs").set_body_typed(SetBYOCAttrs); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SetBYOCAttrs").set_body_typed(SetBYOCAttrs); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 80416bafd0f2..dd87e60e7b80 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -1359,7 +1359,7 @@ Pass SetExprLayout(bool allow_missing, const String& entry_name) { return CreateModulePass(pass_func, 0, "SetExprLayout", {}); } -TVM_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SetExprLayout").set_body_typed(SetExprLayout); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index b3b9fbaa2400..4d0cc0314e18 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -158,7 +158,7 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { ExprVisitor::VisitBinding_(binding, val); - const auto& name_opt = val->GetAttr(attr::kComposite); + const auto& name_opt = val->GetAttr(attr::kComposite); if (name_opt.defined()) { local_funcs_.Set(binding->var, GetRef(val)); } @@ -257,8 +257,8 @@ class RelaxExprNameSetter : public ExprVisitor { const String GetFuncType(const Function& func) { String optype; - const auto& comp_opt = func->GetAttr(attr::kComposite); - const auto& code_opt = func->GetAttr(attr::kCodegen); + const auto& comp_opt = func->GetAttr(attr::kComposite); + const auto& code_opt = func->GetAttr(attr::kCodegen); if (comp_opt.defined()) { optype = comp_opt.value(); } else if (code_opt.defined()) { @@ -275,7 +275,7 @@ class RelaxExprNameSetter : public ExprVisitor { const String GetFuncName(const Call& call, const Function& func) { String name; // get from unique - const auto& name_opt = func->GetAttr(msc_attr::kUnique); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); if (name_opt.defined()) { return name_opt.value(); } @@ -324,7 +324,7 @@ Pass SetRelaxExprName(const String& entry_name, const String& target, return CreateModulePass(pass_func, 0, "SetRelaxExprName", {}); } -TVM_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SetRelaxExprName").set_body_typed(SetRelaxExprName); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index ff2cb30c3b6a..d03f3ba82b28 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -263,13 +263,13 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { String obj_string; if (!obj.defined()) { obj_string = ""; - } else if (obj.as()) { + } else if (obj.as()) { obj_string = Downcast(obj); } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); - } else if (const auto* n = obj.as()) { + } else if (const auto* n = obj.as()) { for (size_t i = 0; i < n->size(); i++) { obj_string = obj_string + ToString((*n)[i].cast()); if (n->size() == 1 || i < n->size() - 1) { @@ -523,27 +523,27 @@ const DataType ExprUtils::GetDataType(const Expr& expr) { return Downcast(GetStructInfo(expr))->dtype; } -TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr); -TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs); -TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr") .set_body_typed([](const String& key, const String& value) -> Span { return SpanUtils::CreateWithAttr(key, value); }); -TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr") +TVM_FFI_REGISTER_GLOBAL("msc.core.SpanSetAttr") .set_body_typed([](const Span& span, const String& key, const String& value) -> Span { return SpanUtils::SetAttr(span, key, value); }); -TVM_REGISTER_GLOBAL("msc.core.CompareVersion") +TVM_FFI_REGISTER_GLOBAL("msc.core.CompareVersion") .set_body_typed([](const Array& given_version, const Array& target_version) -> Integer { return Integer(CommonUtils::CompareVersion(given_version, target_version)); }); -TVM_REGISTER_GLOBAL("msc.core.ToAttrKey").set_body_typed([](const String& key) -> String { +TVM_FFI_REGISTER_GLOBAL("msc.core.ToAttrKey").set_body_typed([](const String& key) -> String { return CommonUtils::ToAttrKey(key); }); diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 9506d4eac818..4bceb76d4699 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -150,7 +150,7 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc index 9c8f0aeb860a..570088ee35c2 100644 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc +++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc @@ -173,7 +173,7 @@ class TFV1BatchnormCodeGen : public TFV1OpCode { for (size_t i = 0; i < weight_names.size(); i++) { const auto& w_doc = DocUtils::ToStr(node()->WeightAt(weight_names[i])->name); stack_.inplace_start("tf_v1.constant_initializer", init_names[i] + "_initializer") - .inplace_start("asnumpy", NullOpt, DocUtils::ToIndex("weights", w_doc)) + .inplace_start("asnumpy", std::nullopt, DocUtils::ToIndex("weights", w_doc)) .inplace_end() .inplace_end(); } diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 8992d3713aa6..8b85f2e88f04 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -136,7 +136,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.comment("Mark outputs"); for (const auto& o : graph()->GetOutputs()) { const auto& pair = graph()->FindProducerAndIdx(o); - stack_.func_call("markOutput", NullOpt, DocUtils::ToPtr("network")) + stack_.func_call("markOutput", std::nullopt, DocUtils::ToPtr("network")) .call_arg("*" + IdxOutputBase(pair.first, pair.second)); } // mark batch_size @@ -146,7 +146,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { Array batch_flags{"MIN", "MAX", "OPT"}; for (const auto& i : graph()->GetInputs()) { for (const auto& f : batch_flags) { - stack_.func_call("setDimensions", NullOpt, DocUtils::ToPtr("profile")) + stack_.func_call("setDimensions", std::nullopt, DocUtils::ToPtr("profile")) .call_arg(DocUtils::ToStr(i->name)) .call_arg("OptProfileSelector::k" + f) .call_arg(ToDims(i->shape)); @@ -155,10 +155,10 @@ void TensorRTCodeGen::CodeGenClassDefine() { // set max workspace stack_.comment("Set max worksapce"); if (CompareVersion(6, 0, 0) >= 0) { - stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtr("config")) + stack_.func_call("setMaxWorkspaceSize", std::nullopt, DocUtils::ToPtr("config")) .call_arg(config()->max_workspace); } else { - stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtr("builder")) + stack_.func_call("setMaxWorkspaceSize", std::nullopt, DocUtils::ToPtr("builder")) .call_arg(config()->max_workspace); } // set data type @@ -169,10 +169,10 @@ void TensorRTCodeGen::CodeGenClassDefine() { .call_arg("ILogger::Severity::kINTERNAL_ERROR") .call_arg(DocUtils::ToStr("platform do not support float16, fallback to float32")) .cond_else() - .func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) + .func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kFP16"); if (config()->precision_mode == "strict") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) + stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kSTRICT_TYPES"); } stack_.func_call("log", "", "logger") @@ -186,16 +186,16 @@ void TensorRTCodeGen::CodeGenClassDefine() { .call_arg("ILogger::Severity::kINTERNAL_ERROR") .call_arg(DocUtils::ToStr("platform do not support int8, fallback to float32")) .cond_else() - .func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) + .func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kINT8"); if (config()->precision_mode == "strict") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) + stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kSTRICT_TYPES"); } else if (config()->precision_mode == "prefer") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) + stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kPREFER_PRECISION_CONSTRAINTS"); } else if (config()->precision_mode == "obey") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) + stack_.func_call("setFlag", std::nullopt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kOBEY_PRECISION_CONSTRAINTS"); } stack_.func_call("log", "", "logger") @@ -219,7 +219,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { .func_start(); stack_.comment("Create context") .func_call("TRTPtr", DocUtils::ToDeclare("auto", "context")) - .func_call("createExecutionContext", NullOpt, DocUtils::ToPtr("engine")) + .func_call("createExecutionContext", std::nullopt, DocUtils::ToPtr("engine")) .pop_nest(); ReturnOnFail("context", "Failed to create the context"); // prepare variables @@ -262,7 +262,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.func_call("cudaStreamSynchronize") .call_arg("stream") .comment("enquque with gpu buffers") - .func_call("enqueueV2", NullOpt, DocUtils::ToPtr("context")) + .func_call("enqueueV2", std::nullopt, DocUtils::ToPtr("context")) .call_arg("gpu_buffers") .call_arg("stream") .call_arg("nullptr") @@ -350,18 +350,18 @@ void TensorRTCodeGen::CodeGenMain() { "1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)", "uint32_t") .func_call("TRTPtr", DocUtils::ToDeclare("auto", "network")) - .func_call("createNetworkV2", NullOpt, DocUtils::ToPtr("builder")) + .func_call("createNetworkV2", std::nullopt, DocUtils::ToPtr("builder")) .call_arg("flags") .pop_nest(); } else { stack_.func_call("TRTPtr", DocUtils::ToDeclare("auto", "network")) - .func_call("createNetwork", NullOpt, DocUtils::ToPtr("builder")) + .func_call("createNetwork", std::nullopt, DocUtils::ToPtr("builder")) .pop_nest(); } ReturnOnFail("network", "Failed to create network"); // create config stack_.func_call("TRTPtr", DocUtils::ToDeclare("auto", "config")) - .func_call("createBuilderConfig", NullOpt, DocUtils::ToPtr("builder")) + .func_call("createBuilderConfig", std::nullopt, DocUtils::ToPtr("builder")) .pop_nest(); ReturnOnFail("config", "Failed to create config"); // add codegen before build @@ -395,7 +395,7 @@ void TensorRTCodeGen::CodeGenMain() { .assign("profile_verbose", "ProfilingVerbosity::kNONE") .cond_end() .cond_end() - .func_call("setProfilingVerbosity", NullOpt, DocUtils::ToPtr("config")) + .func_call("setProfilingVerbosity", std::nullopt, DocUtils::ToPtr("config")) .call_arg("profile_verbose"); // Serialize engine stack_.comment("Serialize engine") @@ -422,7 +422,7 @@ void TensorRTCodeGen::CodeGenMain() { stack_.comment("Dump info by inspector") .cond_if("profile_level > 0") .func_call("TRTPtr", DocUtils::ToDeclare("auto", "inspector")) - .func_call("createEngineInspector", NullOpt, DocUtils::ToPtr("engine")) + .func_call("createEngineInspector", std::nullopt, DocUtils::ToPtr("engine")) .pop_nest() .func_call("getEngineInformation", DocUtils::ToDeclare("std::string", "result"), DocUtils::ToPtr("inspector")) @@ -574,7 +574,7 @@ const Map TensorRTCodeGen::GetStepCtx() { return step_ctx; } -TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); @@ -582,7 +582,7 @@ TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") return codegen.GetSources(print_config); }); -TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTRoot").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTRoot").set_body_typed([]() -> String { #ifdef TENSORRT_ROOT_DIR return TENSORRT_ROOT_DIR; #else @@ -601,7 +601,7 @@ Array MSCTensorRTCompiler(Array functions, Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; - const auto& name_opt = func->GetAttr(msc_attr::kUnique); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); ICHECK(name_opt.defined()) << "Can not find " << msc_attr::kUnique << " from attrs"; const auto& name = name_opt.value(); std::string func_name = GetExtSymbol(func); @@ -618,7 +618,7 @@ Array MSCTensorRTCompiler(Array functions, return compiled_functions; } -TVM_REGISTER_GLOBAL("relax.ext.msc_tensorrt").set_body_typed(MSCTensorRTCompiler); +TVM_FFI_REGISTER_GLOBAL("relax.ext.msc_tensorrt").set_body_typed(MSCTensorRTCompiler); } // namespace msc } // namespace contrib diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index d90cdc35d17d..5a63ecbc7d06 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -36,16 +36,16 @@ const Array TensorRTOpCode::GetDocs() { CodeGenBuild(); if (node()->optype == "tuple") { for (size_t i = 0; i < node()->outputs.size(); i++) { - stack_.func_call("setName", NullOpt, DocUtils::ToPtr(IdxOutput(i))) + stack_.func_call("setName", std::nullopt, DocUtils::ToPtr(IdxOutput(i))) .call_arg(DocUtils::ToStr(node()->OutputAt(i)->name)); } } else if (node()->optype == "get_item") { - stack_.func_call("setName", NullOpt, DocUtils::ToPtr(IdxNode())) + stack_.func_call("setName", std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(DocUtils::ToStr(node()->OutputAt(0)->name)); } else if (node()->optype != "input") { SetLayerByValue("Name", DocUtils::ToStr(node()->name)); for (size_t i = 0; i < node()->outputs.size(); i++) { - stack_.func_call("setName", NullOpt, DocUtils::ToPtr(IdxOutput(i))) + stack_.func_call("setName", std::nullopt, DocUtils::ToPtr(IdxOutput(i))) .call_arg(DocUtils::ToStr(node()->OutputAt(i)->name)); } } @@ -157,29 +157,29 @@ const size_t TensorRTOpCode::AttrToAxis(const String& key, size_t ndim) { template void TensorRTOpCode::SetLayerByAttr(const String& method, const String& key) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())).op_arg(key, ""); + stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).op_arg(key, ""); } template void TensorRTOpCode::SetLayerByValue(const String& method, const T& value) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())).call_arg(value); + stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())).call_arg(value); } void TensorRTOpCode::SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())) + stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(AttrToDims(key, use_ndim)); } template void TensorRTOpCode::SetLayerByDimsValue(const String& method, const std::vector& value, bool use_ndim) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())) + stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } void TensorRTOpCode::SetLayerByDimsValue(const String& method, const Array& value, bool use_ndim) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())) + stack_.func_call("set" + method, std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } @@ -269,7 +269,7 @@ class TensorRTAstypeCodeGen : public TensorRTOpCode { void CodeGenBuild() final { stack_.op_call() .op_input_arg() - .func_call("setOutputType", NullOpt, DocUtils::ToPtr(IdxNode())) + .func_call("setOutputType", std::nullopt, DocUtils::ToPtr(IdxNode())) .call_arg(0) .op_dtype_arg(node()->OutputAt(0)->dtype); } diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index 1131b82172a0..67f453268e2a 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -795,7 +795,7 @@ Expr RewriteSplit(BlockBuilder builder, const Var& var, const Call& src_call, split_begins.push_back(i * size); split_ends.push_back(i * size + size); } - } else if (src_attrs->indices_or_sections->IsInstance()) { + } else if (src_attrs->indices_or_sections->IsInstance()) { const auto& indices = Downcast>(src_attrs->indices_or_sections); int64_t last_index = 0; for (size_t i = 0; i < indices.size(); ++i) { @@ -913,7 +913,7 @@ Pass TransformTensorRT(const String& config) { return CreateFunctionPass(pass_func, 0, "TransformTensorRT", {}); } -TVM_REGISTER_GLOBAL("relax.transform.TransformTensorRT").set_body_typed(TransformTensorRT); +TVM_FFI_REGISTER_GLOBAL("relax.transform.TransformTensorRT").set_body_typed(TransformTensorRT); } // namespace transform } // namespace relax diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 547c1c22ba75..228efa4381ee 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -151,7 +151,7 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 1d6d74d7e43a..53d1bc0562fc 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -210,7 +210,7 @@ const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { } } -TVM_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") +TVM_FFI_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, const String& print_config) -> Map { RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); diff --git a/src/contrib/msc/plugin/tensorrt_codegen.cc b/src/contrib/msc/plugin/tensorrt_codegen.cc index e54b9eedfea8..02904c3bd9c8 100644 --- a/src/contrib/msc/plugin/tensorrt_codegen.cc +++ b/src/contrib/msc/plugin/tensorrt_codegen.cc @@ -769,8 +769,8 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b stack_.call_arg(DocUtils::ToAttrAccess("meta_attr", a->name)); } stack_.call_arg("layouts") - .func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin")) - .inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_")) + .func_call("setPluginNamespace", std::nullopt, DocUtils::ToPtr("plugin")) + .inplace_start("c_str", std::nullopt, DocUtils::ToDoc("name_space_")) .inplace_end() .func_end("plugin"); // deserializePlugin @@ -784,8 +784,8 @@ void TensorRTPluginCodeGen::CodegenCreator(const Plugin& plugin, bool dynamic, b .call_arg("name") .call_arg("data") .call_arg("length") - .func_call("setPluginNamespace", NullOpt, DocUtils::ToPtr("plugin")) - .inplace_start("c_str", NullOpt, DocUtils::ToDoc("name_space_")) + .func_call("setPluginNamespace", std::nullopt, DocUtils::ToPtr("plugin")) + .inplace_start("c_str", std::nullopt, DocUtils::ToDoc("name_space_")) .inplace_end() .func_end("plugin"); } @@ -883,7 +883,7 @@ void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { } } -TVM_REGISTER_GLOBAL("msc.plugin.GetTensorRTPluginSources") +TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTensorRTPluginSources") .set_body_typed([](const String& codegen_config, const String& print_config, const String& codegen_type) -> Map { TensorRTPluginCodeGen codegen = TensorRTPluginCodeGen(codegen_config); diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc index 75471d85db0d..59b99f22c7ce 100644 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -430,7 +430,7 @@ void TorchPluginCodeGen::CodeGenMalloc(const Plugin& plugin, const ArrayFindDeviceRefIdx(tensors[i]); if (device_idx >= 0) { const auto& input_doc = DocUtils::ToIndex("input_tensors", device_idx); - stack_.inplace_start("device", NullOpt, input_doc).inplace_end(); + stack_.inplace_start("device", std::nullopt, input_doc).inplace_end(); } else { stack_.inplace_start("TorchUtils::ToTorchDevice") .call_arg(DocUtils::ToStr(tensors[i]->device)) @@ -492,7 +492,7 @@ void TorchPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& devi } } -TVM_REGISTER_GLOBAL("msc.plugin.GetTorchPluginSources") +TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTorchPluginSources") .set_body_typed([](const String& codegen_config, const String& print_config, const String& codegen_type) -> Map { TorchPluginCodeGen codegen = TorchPluginCodeGen(codegen_config); diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index e1d3c9960f6d..610fbc4c3282 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -213,12 +213,12 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { stack_.func_end("infer_output"); // register funcs - stack_.func_call("TVM_REGISTER_GLOBAL") + stack_.func_call("TVM_FFI_REGISTER_GLOBAL") .call_arg(DocUtils::ToStr("msc.plugin.op.InferStructInfo" + plugin->name)) .method_call("set_body_typed") .call_arg("InferStructInfo" + plugin->name) .line() - .func_call("TVM_REGISTER_GLOBAL") + .func_call("TVM_FFI_REGISTER_GLOBAL") .call_arg(DocUtils::ToStr("msc.plugin.op.InferLayout" + plugin->name)) .method_call("set_body_typed") .call_arg("InferLayout" + plugin->name) @@ -260,7 +260,7 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { CodeGenCompute(plugin, "cpu"); stack_.cond_end().func_end(); // register the compute - stack_.func_call("TVM_REGISTER_GLOBAL") + stack_.func_call("TVM_FFI_REGISTER_GLOBAL") .call_arg(DocUtils::ToStr(plugin->name)) .method_call("set_body") .call_arg(func_name) @@ -393,7 +393,7 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device } } -TVM_REGISTER_GLOBAL("msc.plugin.GetTVMPluginSources") +TVM_FFI_REGISTER_GLOBAL("msc.plugin.GetTVMPluginSources") .set_body_typed([](const String& codegen_config, const String& print_config, const String& codegen_type) -> Map { TVMPluginCodeGen codegen = TVMPluginCodeGen(codegen_config); diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc index 9de36b0a28af..3a54085c2290 100644 --- a/src/ir/analysis.cc +++ b/src/ir/analysis.cc @@ -43,7 +43,7 @@ Map> CollectCallMap(const IRModule& mod) { return call_map; } -TVM_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap); +TVM_FFI_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap); } // namespace ir } // namespace tvm diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index 9e43e33a6c4a..877530f4c378 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -21,9 +21,9 @@ * \file src/ir/apply_pass_to_function.cc * \brief Utility transformation that applies an inner pass to a subset of an IRModule */ +#include #include #include -#include #include #include @@ -130,7 +130,7 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex, return CreateModulePass(pass_func, 0, pass_name, {}); } -TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction); +TVM_FFI_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction); } // namespace transform } // namespace tvm diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index ae2675643113..ecabc9bfb2cc 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -75,7 +75,7 @@ class AttrFunctor { } } virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; - virtual R VisitAttr_(const ArrayObj* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ffi::ArrayObj* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -112,7 +112,7 @@ class AttrFunctor { using namespace tir; FType vtable; // Set dispatch - ATTR_FUNCTOR_DISPATCH(ArrayObj); + ATTR_FUNCTOR_DISPATCH(ffi::ArrayObj); ATTR_FUNCTOR_DISPATCH(IntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index fd87c2bc8e0c..52a2cceeaf79 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -20,8 +20,8 @@ /*! * \file attrs.cc */ +#include #include -#include #include "attr_functor.h" @@ -73,11 +73,11 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); -TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { +TVM_FFI_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { return attrs->dict; }); -TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { +TVM_FFI_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { return attrs->ListFieldInfo(); }); diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index ec11f2c04f6c..70197074317d 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -33,7 +33,7 @@ namespace tvm { /* Diagnostic */ TVM_REGISTER_NODE_TYPE(DiagnosticNode); -TVM_REGISTER_GLOBAL("diagnostics.Diagnostic") +TVM_FFI_REGISTER_GLOBAL("diagnostics.Diagnostic") .set_body_typed([](int level, Span span, String message) { return Diagnostic(static_cast(level), span, message); }); @@ -106,7 +106,7 @@ TVM_DLL DiagnosticRenderer::DiagnosticRenderer( data_ = std::move(n); } -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer") .set_body_typed([](ffi::TypedFunction renderer) { return DiagnosticRenderer(renderer); }); @@ -134,7 +134,7 @@ void DiagnosticContext::Render() { } } -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") .set_body_typed([](DiagnosticRenderer renderer, DiagnosticContext ctx) { renderer.Render(ctx); }); @@ -147,7 +147,7 @@ DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRen data_ = std::move(n); } -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContext") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticContext") .set_body_typed([](const IRModule& module, const DiagnosticRenderer& renderer) { return DiagnosticContext(module, renderer); }); @@ -157,12 +157,12 @@ void DiagnosticContext::Emit(const Diagnostic& diagnostic) { (*this)->diagnostics.push_back(diagnostic); } -TVM_REGISTER_GLOBAL("diagnostics.Emit") +TVM_FFI_REGISTER_GLOBAL("diagnostics.Emit") .set_body_typed([](DiagnosticContext ctx, const Diagnostic& diagnostic) { return ctx.Emit(diagnostic); }); -TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContextRender") +TVM_FFI_REGISTER_GLOBAL("diagnostics.DiagnosticContextRender") .set_body_typed([](DiagnosticContext context) { return context.Render(); }); /*! \brief Emit a diagnostic. */ @@ -195,7 +195,7 @@ DiagnosticContext DiagnosticContext::Default(const IRModule& module) { return DiagnosticContext(module, renderer); } -TVM_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) { +TVM_FFI_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) { return DiagnosticContext::Default(module); }); @@ -311,11 +311,13 @@ DiagnosticRenderer TerminalRenderer(std::ostream& out) { }); } -TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cerr); }); +TVM_FFI_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { + return TerminalRenderer(std::cerr); +}); -TVM_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); }); +TVM_FFI_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); }); -TVM_REGISTER_GLOBAL("diagnostics.ClearRenderer").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("diagnostics.ClearRenderer").set_body_typed([]() { tvm::ffi::Function::RemoveGlobal(OVERRIDE_RENDERER); }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 9713f88f7ddd..ce40df21eb9a 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -20,8 +20,8 @@ /*! * \file env_func.cc */ +#include #include -#include #include namespace tvm { @@ -47,15 +47,15 @@ ObjectPtr CreateEnvNode(const std::string& name) { EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); +TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); -TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncCall").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { EnvFunc env = args[0].cast(); ICHECK_GE(args.size(), 1); env->func.CallPacked(args.Slice(1), rv); }); -TVM_REGISTER_GLOBAL("ir.EnvFuncGetFunction").set_body_typed([](const EnvFunc& n) { +TVM_FFI_REGISTER_GLOBAL("ir.EnvFuncGetFunction").set_body_typed([](const EnvFunc& n) { return n->func; }); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index de665dcd22b3..387572f6427b 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -22,9 +22,9 @@ * \brief The expression AST nodes for the common IR infra. */ #include +#include #include #include -#include #include #include @@ -64,7 +64,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value, Span span) { +TVM_FFI_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value, Span span) { return IntImm(dtype, value, span); }); @@ -115,7 +115,7 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value, Span span) { +TVM_FFI_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value, Span span) { return FloatImm(dtype, value, span); }); @@ -128,9 +128,9 @@ Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { return Range(make_object(min, extent, span)); } -TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); +TVM_FFI_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); -TVM_REGISTER_GLOBAL("ir.Range") +TVM_FFI_REGISTER_GLOBAL("ir.Range") .set_body_typed([](PrimExpr begin, Optional end, Span span) -> Range { if (end.defined()) { return Range(begin, end.value(), span); @@ -151,11 +151,11 @@ GlobalVar::GlobalVar(String name_hint, Type type, Span span) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { return GlobalVar(name, type); }); -TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { +TVM_FFI_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { std::stringstream ss; ss << ref; return ss.str(); diff --git a/src/ir/function.cc b/src/ir/function.cc index 8f543b03260c..66d66e3c8133 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -21,19 +21,21 @@ * \file src/ir/function.cc * \brief The function data structure. */ +#include #include #include #include -#include #include namespace tvm { -TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { return func->attrs; }); +TVM_FFI_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { + return func->attrs; +}); -TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); -TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithAttr") .set_body_typed([](ffi::RValueRef func_ref, String key, Any value) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { @@ -45,7 +47,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") } }); -TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") .set_body_typed([](ffi::RValueRef func_ref, Map attr_map) -> BaseFunc { BaseFunc func = *std::move(func_ref); @@ -61,7 +63,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") TVM_FFI_UNREACHABLE(); }); -TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") +TVM_FFI_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") .set_body_typed([](ffi::RValueRef func_ref, String key) -> BaseFunc { BaseFunc func = *std::move(func_ref); if (func->IsInstance()) { diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 6abac574e1b7..3df9ae00fb53 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -25,7 +25,7 @@ #include namespace tvm { TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode); -TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { auto n = DummyGlobalInfo(make_object()); return n; }); @@ -39,7 +39,8 @@ VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) { } TVM_REGISTER_NODE_TYPE(VDeviceNode); -TVM_REGISTER_GLOBAL("ir.VDevice").set_body_typed([](Target tgt, int dev_id, MemoryScope mem_scope) { - return VDevice(tgt, dev_id, mem_scope); -}); +TVM_FFI_REGISTER_GLOBAL("ir.VDevice") + .set_body_typed([](Target tgt, int dev_id, MemoryScope mem_scope) { + return VDevice(tgt, dev_id, mem_scope); + }); } // namespace tvm diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index 3d3b8919916f..1b47c1a89639 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -23,7 +23,7 @@ */ #include "tvm/ir/global_var_supply.h" -#include +#include #include @@ -92,24 +92,23 @@ GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) { TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply") .set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); }); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) { +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) { return GlobalVarSupply(std::move(mod)); }); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array& mods) { - return GlobalVarSupply(mods); -}); +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules") + .set_body_typed([](const Array& mods) { return GlobalVarSupply(mods); }); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal") .set_body_method(&GlobalVarSupplyNode::FreshGlobal); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor") .set_body_method(&GlobalVarSupplyNode::UniqueGlobalFor); -TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar") +TVM_FFI_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar") .set_body_method(&GlobalVarSupplyNode::ReserveGlobalVar); } // namespace tvm diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index ad66f2944891..a273245c1b64 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -22,10 +22,10 @@ * \brief Infrastructure for instrumentation. */ #include +#include #include #include #include -#include #include @@ -175,7 +175,7 @@ void BasePassInstrumentNode::RunAfterPass(const IRModule& ir_module, TVM_REGISTER_NODE_TYPE(BasePassInstrumentNode); -TVM_REGISTER_GLOBAL("instrument.PassInstrument") +TVM_FFI_REGISTER_GLOBAL("instrument.PassInstrument") .set_body_typed( [](String name, ffi::TypedFunction enter_pass_ctx, ffi::TypedFunction exit_pass_ctx, @@ -308,9 +308,9 @@ String RenderPassProfiles() { return os.str(); } -TVM_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); +TVM_FFI_REGISTER_GLOBAL("instrument.RenderTimePassProfiles").set_body_typed(RenderPassProfiles); -TVM_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("instrument.MakePassTimingInstrument").set_body_typed([]() { auto run_before_pass = [](const IRModule&, const transform::PassInfo& pass_info) { PassProfile::EnterPass(pass_info->name); return true; diff --git a/src/ir/module.cc b/src/ir/module.cc index 02e353c8a6c6..3166ffba9787 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -20,13 +20,13 @@ * \file module.cc * \brief The global module in TVM. */ +#include +#include #include #include #include #include #include -#include -#include #include #include @@ -242,7 +242,7 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr, TVM_REGISTER_NODE_TYPE(IRModuleNode); -TVM_REGISTER_GLOBAL("ir.IRModule") +TVM_FFI_REGISTER_GLOBAL("ir.IRModule") .set_body_typed([](tvm::Map funcs, tvm::ObjectRef attrs, Map> global_infos) { auto dict_attrs = [&attrs]() { @@ -250,7 +250,7 @@ TVM_REGISTER_GLOBAL("ir.IRModule") return DictAttrs(); } else if (auto* as_dict_attrs = attrs.as()) { return GetRef(as_dict_attrs); - } else if (attrs.as()) { + } else if (attrs.as()) { return tvm::DictAttrs(Downcast>(attrs)); } else { LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map"; @@ -260,20 +260,20 @@ TVM_REGISTER_GLOBAL("ir.IRModule") return IRModule(funcs, {}, dict_attrs, global_infos); }); -TVM_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> IRModule { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Clone").set_body_typed([](IRModule mod) -> IRModule { IRModule clone = mod; clone.CopyOnWrite(); return clone; }); -TVM_REGISTER_GLOBAL("ir.Module_Add") +TVM_FFI_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { ICHECK(val->IsInstance()); mod->Add(var, Downcast(val), update); return mod; }); -TVM_REGISTER_GLOBAL("ir.Module_Remove") +TVM_FFI_REGISTER_GLOBAL("ir.Module_Remove") .set_body_typed([](IRModule mod, Variant var) -> IRModule { GlobalVar gvar = [&]() { if (auto opt = var.as()) { @@ -289,7 +289,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Remove") return mod; }); -TVM_REGISTER_GLOBAL("ir.Module_Contains") +TVM_FFI_REGISTER_GLOBAL("ir.Module_Contains") .set_body_typed([](IRModule mod, Variant var) -> bool { if (auto opt = var.as()) { return mod->functions.count(opt.value()); @@ -301,55 +301,57 @@ TVM_REGISTER_GLOBAL("ir.Module_Contains") } }); -TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar").set_body_method(&IRModuleNode::GetGlobalVar); +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetGlobalVar").set_body_method(&IRModuleNode::GetGlobalVar); -TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars").set_body_method(&IRModuleNode::GetGlobalVars); +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetGlobalVars").set_body_method(&IRModuleNode::GetGlobalVars); -TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar").set_body_method(&IRModuleNode::ContainGlobalVar); +TVM_FFI_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") + .set_body_method(&IRModuleNode::ContainGlobalVar); -TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr); +TVM_FFI_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr); -TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { +TVM_FFI_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); }); -TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") +TVM_FFI_REGISTER_GLOBAL("ir.Module_UpdateFunction") .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); -TVM_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") +TVM_FFI_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo") .set_body_typed([](IRModule mod, String name, Array global_info) { mod->UpdateGlobalInfo(name, global_info); }); -TVM_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetAttrs").set_body_typed([](IRModule mod) -> ObjectRef { return mod->GetAttrs(); }); -TVM_REGISTER_GLOBAL("ir.Module_WithAttr") +TVM_FFI_REGISTER_GLOBAL("ir.Module_WithAttr") .set_body_typed([](ffi::RValueRef mod, String key, ffi::Any value) -> IRModule { return WithAttr(*std::move(mod), key, value); }); -TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr") +TVM_FFI_REGISTER_GLOBAL("ir.Module_WithoutAttr") .set_body_typed([](ffi::RValueRef mod, String key) -> IRModule { return WithoutAttr(*std::move(mod), key); }); -TVM_REGISTER_GLOBAL("ir.Module_WithAttrs") +TVM_FFI_REGISTER_GLOBAL("ir.Module_WithAttrs") .set_body_typed([](ffi::RValueRef mod, Map attr_map) -> IRModule { return WithAttrs(*std::move(mod), attr_map); }); -TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef { - return mod->GetAttr(key); -}); +TVM_FFI_REGISTER_GLOBAL("ir.Module_GetAttr") + .set_body_typed([](IRModule mod, String key) -> ObjectRef { + return mod->GetAttr(key); + }); } // namespace tvm diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc index 087fc82a50f7..e73b0e63e3d0 100644 --- a/src/ir/name_supply.cc +++ b/src/ir/name_supply.cc @@ -23,7 +23,7 @@ */ #include "tvm/ir/name_supply.h" -#include +#include #include @@ -92,14 +92,15 @@ std::string NameSupplyNode::GetUniqueName(std::string name, bool add_underscore) TVM_REGISTER_NODE_TYPE(NameSupplyNode); -TVM_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) { +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply").set_body_typed([](String prefix) { return NameSupply(prefix); }); -TVM_REGISTER_GLOBAL("ir.NameSupply_FreshName").set_body_method(&NameSupplyNode::FreshName); +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_FreshName").set_body_method(&NameSupplyNode::FreshName); -TVM_REGISTER_GLOBAL("ir.NameSupply_ReserveName").set_body_method(&NameSupplyNode::ReserveName); +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_ReserveName").set_body_method(&NameSupplyNode::ReserveName); -TVM_REGISTER_GLOBAL("ir.NameSupply_ContainsName").set_body_method(&NameSupplyNode::ContainsName); +TVM_FFI_REGISTER_GLOBAL("ir.NameSupply_ContainsName") + .set_body_method(&NameSupplyNode::ContainsName); } // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index 70f7528e5e76..b6d1f39526db 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -21,10 +21,10 @@ * \file src/ir/op.cc * \brief Primitive operators and intrinsics. */ +#include #include #include #include -#include #include #include @@ -75,13 +75,13 @@ void OpRegEntry::UpdateAttr(const String& key, ffi::Any value, int plevel) { } // Frontend APIs -TVM_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { return OpRegistry::Global()->ListAllNames(); }); -TVM_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); +TVM_FFI_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); -TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> ffi::Any { +TVM_FFI_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> ffi::Any { auto op_map = Op::GetAttrMap(attr_name); ffi::Any rv; if (op_map.count(op)) { @@ -90,50 +90,50 @@ TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) - return rv; }); -TVM_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool { +TVM_FFI_REGISTER_GLOBAL("ir.OpHasAttr").set_body_typed([](Op op, String attr_name) -> bool { return Op::HasAttrMap(attr_name); }); -TVM_REGISTER_GLOBAL("ir.OpSetAttr") +TVM_FFI_REGISTER_GLOBAL("ir.OpSetAttr") .set_body_typed([](Op op, String attr_name, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attr(attr_name, value, plevel); }); -TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) { +TVM_FFI_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); reg.reset_attr(attr_name); }); -TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) { +TVM_FFI_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) { const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before"; auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); op.describe(descr); }); -TVM_REGISTER_GLOBAL("ir.OpAddArgument") +TVM_FFI_REGISTER_GLOBAL("ir.OpAddArgument") .set_body_typed([](Op op, String name, String type, String description) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.add_argument(name, type, description); }); -TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) { +TVM_FFI_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_support_level(level); }); -TVM_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) { +TVM_FFI_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_num_inputs(n); }); -TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) { +TVM_FFI_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); reg.set_attrs_type_key(key); }); -TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") +TVM_FFI_REGISTER_GLOBAL("ir.RegisterOpAttr") .set_body_typed([](String op_name, String attr_key, ffi::AnyView value, int plevel) { auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); // enable resgiteration and override of certain properties @@ -146,7 +146,7 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") } }); -TVM_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic") +TVM_FFI_REGISTER_GLOBAL("ir.RegisterOpLowerIntrinsic") .set_body_typed([](String name, ffi::Function f, String target, int plevel) { tvm::OpRegEntry::RegisterOrGet(name).set_attr(target + ".FLowerIntrinsic", f, plevel); diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc index 48e7abe5618a..0dca97302470 100644 --- a/src/ir/replace_global_vars.cc +++ b/src/ir/replace_global_vars.cc @@ -22,8 +22,8 @@ * \brief IRModule transform to replace GlobalVar instances across any IR type. */ +#include #include -#include #include @@ -62,7 +62,7 @@ IRModule ReplaceGlobalVars(IRModule mod, Map replacements) return mod; } -TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars); +TVM_FFI_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars); IRModule ModuleReplaceGlobalVars( IRModule mod, Map, Variant> replacements) { @@ -93,7 +93,7 @@ IRModule ModuleReplaceGlobalVars( return ReplaceGlobalVars(mod, gvar_replacements); } -TVM_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars); +TVM_FFI_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars); } // namespace transform } // namespace tvm diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 8e25b25a4ca4..482e1dfa1018 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -20,9 +20,9 @@ * \file source_map.cc * \brief The implementation of the source map data structure. */ +#include #include #include -#include #include @@ -50,7 +50,7 @@ ObjectPtr GetSourceNameNodeByStr(const std::string& name) { SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } -TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); +TVM_FFI_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { @@ -129,12 +129,12 @@ SequentialSpan::SequentialSpan(std::initializer_list init) { TVM_REGISTER_NODE_TYPE(SequentialSpanNode); -TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, - int column, int end_column) { +TVM_FFI_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source_name, int line, int end_line, + int column, int end_column) { return Span(source_name, line, end_line, column, end_column); }); -TVM_REGISTER_GLOBAL("ir.SequentialSpan").set_body_typed([](tvm::Array spans) { +TVM_FFI_REGISTER_GLOBAL("ir.SequentialSpan").set_body_typed([](tvm::Array spans) { return SequentialSpan(spans); }); @@ -218,11 +218,12 @@ SourceMap::SourceMap(Map source_map) { void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } -TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { - auto src_name = SourceName::Get(name); - Source source(src_name, content); - map.Add(source); - return src_name; -}); +TVM_FFI_REGISTER_GLOBAL("SourceMapAdd") + .set_body_typed([](SourceMap map, String name, String content) { + auto src_name = SourceName::Get(name); + Source source(src_name, content); + map.Add(source); + return src_name; + }); } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 0730faa7b4d7..db4e47ca0d1a 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -22,6 +22,7 @@ * \brief Infrastructure for transformation passes. */ #include +#include #include #include #include @@ -29,14 +30,12 @@ #include #include #include -#include #include #include #include #include -#include "../runtime/object_internal.h" #include "../runtime/regex.h" namespace tvm { @@ -87,7 +86,7 @@ PassContext PassContext::Current() { } // linearly scan the pass array to match pass_name -bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { +bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { for (auto x : pass_array) { if (x == pass_name) return true; } @@ -378,8 +377,7 @@ class ModulePass : public Pass { TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; -PassInfo::PassInfo(int opt_level, String name, tvm::Array required, - bool traceable) { +PassInfo::PassInfo(int opt_level, String name, tvm::Array required, bool traceable) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); @@ -533,12 +531,12 @@ Pass CreateModulePass(std::function pass_func, TVM_REGISTER_NODE_TYPE(PassInfoNode); -TVM_REGISTER_GLOBAL("transform.PassInfo") +TVM_FFI_REGISTER_GLOBAL("transform.PassInfo") .set_body_typed([](int opt_level, String name, tvm::Array required, bool traceable) { return PassInfo(opt_level, name, required, traceable); }); -TVM_REGISTER_GLOBAL("transform.Info").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { +TVM_FFI_REGISTER_GLOBAL("transform.Info").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { Pass pass = args[0].cast(); *ret = pass->Info(); }); @@ -563,7 +561,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); -TVM_REGISTER_GLOBAL("transform.MakeModulePass") +TVM_FFI_REGISTER_GLOBAL("transform.MakeModulePass") .set_body_typed( [](ffi::TypedFunction, PassContext)> pass_func, PassInfo pass_info) { @@ -573,7 +571,7 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass") return ModulePass(wrapped_pass_func, pass_info); }); -TVM_REGISTER_GLOBAL("transform.RunPass") +TVM_FFI_REGISTER_GLOBAL("transform.RunPass") .set_body_typed([](Pass pass, ffi::RValueRef mod) { return pass(*std::move(mod)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -586,12 +584,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_GLOBAL("transform.Sequential") +TVM_FFI_REGISTER_GLOBAL("transform.Sequential") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto passes = args[0].cast>(); int opt_level = args[1].cast(); std::string name = args[2].cast(); - auto required = args[3].cast>(); + auto required = args[3].cast>(); bool traceable = args[4].cast(); PassInfo pass_info = PassInfo(opt_level, name, required, /* traceable */ traceable); *ret = Sequential(passes, pass_info); @@ -613,7 +611,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); -TVM_REGISTER_GLOBAL("transform.PassContext") +TVM_FFI_REGISTER_GLOBAL("transform.PassContext") .set_body_typed([](int opt_level, Array required, Array disabled, Array instruments, Optional> config, Array trace_stack, @@ -659,24 +657,27 @@ class PassContext::Internal { static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("transform.GetTraceStack").set_body_method(&PassContextNode::GetTraceStack); -TVM_REGISTER_GLOBAL("transform.PushTrace").set_body_method(&PassContextNode::PushTrace); -TVM_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); -TVM_REGISTER_GLOBAL("transform.GetTraceStackSize") +TVM_FFI_REGISTER_GLOBAL("transform.GetTraceStack").set_body_method(&PassContextNode::GetTraceStack); +TVM_FFI_REGISTER_GLOBAL("transform.PushTrace").set_body_method(&PassContextNode::PushTrace); +TVM_FFI_REGISTER_GLOBAL("transform.PopTrace").set_body_method(&PassContextNode::PopTrace); +TVM_FFI_REGISTER_GLOBAL("transform.GetTraceStackSize") .set_body_method(&PassContextNode::GetTraceStackSize); -TVM_REGISTER_GLOBAL("transform.GetCurrentTrace").set_body_method(&PassContextNode::GetCurrentTrace); -TVM_REGISTER_GLOBAL("transform.SetNumEvals").set_body_method(&PassContextNode::SetNumEvals); -TVM_REGISTER_GLOBAL("transform.IncNumEvals").set_body_method(&PassContextNode::IncNumEvals); -TVM_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") +TVM_FFI_REGISTER_GLOBAL("transform.GetCurrentTrace") + .set_body_method(&PassContextNode::GetCurrentTrace); +TVM_FFI_REGISTER_GLOBAL("transform.SetNumEvals").set_body_method(&PassContextNode::SetNumEvals); +TVM_FFI_REGISTER_GLOBAL("transform.IncNumEvals").set_body_method(&PassContextNode::IncNumEvals); +TVM_FFI_REGISTER_GLOBAL("transform.GetTuningAPIDatabase") .set_body_method(&PassContextNode::GetTuningAPIDatabase); -TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); +TVM_FFI_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); -TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); +TVM_FFI_REGISTER_GLOBAL("transform.EnterPassContext") + .set_body_typed(PassContext::Internal::EnterScope); -TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); +TVM_FFI_REGISTER_GLOBAL("transform.ExitPassContext") + .set_body_typed(PassContext::Internal::ExitScope); -TVM_REGISTER_GLOBAL("transform.OverrideInstruments") +TVM_FFI_REGISTER_GLOBAL("transform.OverrideInstruments") .set_body_typed([](PassContext pass_ctx, Array instruments) { pass_ctx.InstrumentExitPassContext(); pass_ctx->instruments = instruments; @@ -691,9 +692,9 @@ Pass PrintIR(String header, bool show_meta_data) { return CreateModulePass(pass_func, 0, "PrintIR", {}, /* traceable */ false); } -TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); +TVM_FFI_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); -TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); +TVM_FFI_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); } // namespace transform } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 3c648418c6a9..8bc48a11141f 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -21,8 +21,8 @@ * \file src/ir/type.cc * \brief Common type system AST nodes throughout the IR. */ +#include #include -#include namespace tvm { PrimType::PrimType(runtime::DataType dtype, Span span) { @@ -34,7 +34,7 @@ PrimType::PrimType(runtime::DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { +TVM_FFI_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { return PrimType(dtype); }); @@ -47,7 +47,7 @@ PointerType::PointerType(Type element_type, String storage_scope) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); -TVM_REGISTER_GLOBAL("ir.PointerType") +TVM_FFI_REGISTER_GLOBAL("ir.PointerType") .set_body_typed([](Type element_type, String storage_scope = "") { return PointerType(element_type, storage_scope); }); @@ -62,9 +62,10 @@ FuncType::FuncType(tvm::Array arg_types, Type ret_type, Span span) { TVM_REGISTER_NODE_TYPE(FuncTypeNode); -TVM_REGISTER_GLOBAL("ir.FuncType").set_body_typed([](tvm::Array arg_types, Type ret_type) { - return FuncType(arg_types, ret_type); -}); +TVM_FFI_REGISTER_GLOBAL("ir.FuncType") + .set_body_typed([](tvm::Array arg_types, Type ret_type) { + return FuncType(arg_types, ret_type); + }); TupleType::TupleType(Array fields, Span span) { ObjectPtr n = make_object(); @@ -77,7 +78,7 @@ TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { +TVM_FFI_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { return TupleType(fields); }); diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index ad3966ed1e72..58c4a7b33c4f 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -67,9 +67,9 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { // The JSON object is always an array whose first element is a tag. For example: // `['TENSOR', 'float32', [1, 224, 224, 3]] // Step 1. Extract the tag - String tag{runtime::ObjectPtr(nullptr)}; + String tag{ffi::ObjectPtr(nullptr)}; try { - const ArrayObj* json_array = json_obj.as(); + const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() >= 1); tag = json_array->at(0).cast(); } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error @@ -111,7 +111,7 @@ Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) /******** TensorInfo ********/ -TensorInfo::TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape) { +TensorInfo::TensorInfo(runtime::DataType dtype, ffi::Shape shape) { ObjectPtr n = make_object(); n->dtype = dtype; n->shape = shape; @@ -129,12 +129,12 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { DLDataType dtype; Array shape; try { - const ArrayObj* json_array = json_obj.as(); + const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 3); // Load json[1] => dtype { String dtype_str = json_array->at(1).cast(); - dtype = runtime::StringToDLDataType(dtype_str); + dtype = StringToDLDataType(dtype_str); } // Load json[2] => shape shape = AsIntArray(json_array->at(2).cast()); @@ -145,7 +145,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { std::vector s; std::transform(shape.begin(), shape.end(), std::back_inserter(s), [](Integer i) { return i.IntValue(); }); - return TensorInfo(DataType(dtype), ShapeTuple(s.begin(), s.end())); + return TensorInfo(DataType(dtype), ffi::Shape(s.begin(), s.end())); } /******** Repr ********/ @@ -162,12 +162,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(ArgInfoNode); TVM_REGISTER_NODE_TYPE(TensorInfoNode); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method(&ArgInfoNode::AsJSON); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc").set_body_typed(ArgInfo::FromEntryFunc); -TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON); -TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo") - .set_body_typed([](runtime::DataType dtype, runtime::ShapeTuple shape) -> TensorInfo { +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method(&ArgInfoNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromEntryFunc") + .set_body_typed(ArgInfo::FromEntryFunc); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TensorInfo") + .set_body_typed([](runtime::DataType dtype, ffi::Shape shape) -> TensorInfo { return TensorInfo(dtype, shape); }); diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc index 9d725e91e247..85e189e73228 100644 --- a/src/meta_schedule/builder/builder.cc +++ b/src/meta_schedule/builder/builder.cc @@ -52,21 +52,21 @@ TVM_REGISTER_NODE_TYPE(BuilderResultNode); TVM_REGISTER_OBJECT_TYPE(BuilderNode); TVM_REGISTER_NODE_TYPE(PyBuilderNode); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderInput") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderInput") .set_body_typed([](IRModule mod, Target target, Optional> params) -> BuilderInput { return BuilderInput(mod, target, params); }); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderResult") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderResult") .set_body_typed([](Optional artifact_path, Optional error_msg) -> BuilderResult { return BuilderResult(artifact_path, error_msg); }); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderBuild").set_body_method(&BuilderNode::Build); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderBuild").set_body_method(&BuilderNode::Build); -TVM_REGISTER_GLOBAL("meta_schedule.BuilderPyBuilder").set_body_typed(Builder::PyBuilder); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.BuilderPyBuilder").set_body_typed(Builder::PyBuilder); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/cost_model/cost_model.cc b/src/meta_schedule/cost_model/cost_model.cc index 1d28eb19d7cb..5c1c7a568580 100644 --- a/src/meta_schedule/cost_model/cost_model.cc +++ b/src/meta_schedule/cost_model/cost_model.cc @@ -71,10 +71,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(CostModelNode); TVM_REGISTER_NODE_TYPE(PyCostModelNode); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelUpdate").set_body_method(&CostModelNode::Update); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelLoad").set_body_method(&CostModelNode::Load); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelSave").set_body_method(&CostModelNode::Save); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelUpdate").set_body_method(&CostModelNode::Update); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelPredict") .set_body_typed([](CostModel model, // const TuneContext& context, // Array candidates, // @@ -82,7 +82,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.CostModelPredict") std::vector result = model->Predict(context, candidates); std::copy(result.begin(), result.end(), static_cast(p_addr)); }); -TVM_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel").set_body_typed(CostModel::PyCostModel); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.CostModelPyCostModel") + .set_body_typed(CostModel::PyCostModel); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index f859648464d4..034294eedcd3 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -25,14 +25,14 @@ namespace meta_schedule { /******** Workload ********/ Workload::Workload(IRModule mod) { - ObjectPtr n = runtime::make_object(); + ObjectPtr n = ffi::make_object(); n->mod = mod; n->shash = ModuleEquality::Create("structural")->Hash(mod); data_ = std::move(n); } Workload::Workload(IRModule mod, Workload::THashCode shash) { - ObjectPtr n = runtime::make_object(); + ObjectPtr n = ffi::make_object(); n->mod = mod; n->shash = shash; data_ = std::move(n); @@ -51,7 +51,7 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { IRModule mod{nullptr}; THashCode shash = 0; try { - const ArrayObj* json_array = json_obj.as(); + const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 2); // Load json[0] => shash String str_shash = json_array->at(0).cast(); @@ -134,7 +134,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w Optional target; Optional> args_info; try { - const ArrayObj* json_array = json_obj.as(); + const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 4); // Load json[1] => run_secs if (json_array->at(1) != nullptr) { @@ -146,7 +146,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w } // Load json[3] => args_info if (json_array->at(3) != nullptr) { - const ArrayObj* json_args_info = json_array->at(3).cast(); + const ffi::ArrayObj* json_args_info = json_array->at(3).cast(); Array info; info.reserve(json_args_info->size()); for (Any json_arg_info : *json_args_info) { @@ -177,11 +177,11 @@ DatabaseNode::~DatabaseNode() = default; Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target, const String& workload_name) { if (!this->HasWorkload(mod)) { - return NullOpt; + return std::nullopt; } Array records = this->GetTopK(this->CommitWorkload(mod), 1); if (records.empty()) { - return NullOpt; + return std::nullopt; } ICHECK_EQ(records.size(), 1); return records[0]; @@ -197,7 +197,7 @@ Optional DatabaseNode::QuerySchedule(const IRModule& mod, const T record->trace->ApplyToSchedule(sch, false); return sch; } else { - return NullOpt; + return std::nullopt; } } @@ -206,7 +206,7 @@ Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target if (Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { return opt_sch.value()->mod(); } else { - return NullOpt; + return std::nullopt; } } @@ -245,7 +245,7 @@ void Database::ExitWithScope() { ThreadLocalDatabases()->pop_back(); } Optional Database::Current() { std::vector* tls = ThreadLocalDatabases(); if (tls->empty()) { - return NullOpt; + return std::nullopt; } else { return tls->back(); } @@ -282,43 +282,46 @@ TVM_REGISTER_NODE_TYPE(WorkloadNode); TVM_REGISTER_NODE_TYPE(TuningRecordNode); TVM_REGISTER_OBJECT_TYPE(DatabaseNode); TVM_REGISTER_NODE_TYPE(PyDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) { +TVM_FFI_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) { return Workload(mod); }); -TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON").set_body_method(&WorkloadNode::AsJSON); -TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON").set_body_method(&WorkloadNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecord") .set_body_typed([](tir::Trace trace, Workload workload, Optional> run_secs, Optional target, Optional> args_info) { return TuningRecord(trace, workload, run_secs, target, args_info); }); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordAsMeasureCandidate") .set_body_method(&TuningRecordNode::AsMeasureCandidate); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON").set_body_method(&TuningRecordNode::AsJSON); -TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON") + .set_body_typed(TuningRecord::FromJSON); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseEnterWithScope") .set_body_method(&Database::EnterWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseExitWithScope") .set_body_method(&Database::ExitWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCurrent").set_body_typed(Database::Current); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseHasWorkload") .set_body_method(&DatabaseNode::HasWorkload); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") .set_body_method(&DatabaseNode::CommitWorkload); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") .set_body_method(&DatabaseNode::CommitTuningRecord); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseGetAllTuningRecords") .set_body_method(&DatabaseNode::GetAllTuningRecords); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQueryTuningRecord") .set_body_method(&DatabaseNode::QueryTuningRecord); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQuerySchedule") .set_body_method(&DatabaseNode::QuerySchedule); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseQueryIRModule") .set_body_method(&DatabaseNode::QueryIRModule); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseDumpPruned").set_body_method(&DatabaseNode::DumpPruned); -TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseDumpPruned") + .set_body_method(&DatabaseNode::DumpPruned); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index badba34bcab0..1f396882720b 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -29,7 +29,7 @@ namespace meta_schedule { void JSONDumps(Any json_obj, std::ostringstream& os) { if (json_obj == nullptr) { os << "null"; - } else if (auto opt_int_imm = json_obj.as()) { + } else if (auto opt_int_imm = json_obj.try_cast()) { IntImm int_imm = *std::move(opt_int_imm); if (int_imm->dtype == DataType::Bool()) { if (int_imm->value) { @@ -40,10 +40,10 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { } else { os << int_imm->value; } - } else if (auto opt_float_imm = json_obj.as()) { + } else if (auto opt_float_imm = json_obj.try_cast()) { FloatImm float_imm = *std::move(opt_float_imm); os << std::setprecision(20) << float_imm->value; - } else if (const auto* str = json_obj.as()) { + } else if (const auto* str = json_obj.as()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; } else if (const auto* array = json_obj.as()) { os << "["; @@ -60,7 +60,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { std::vector> key_values; key_values.reserve(n); for (const auto& kv : *dict) { - if (auto key = kv.first.as()) { + if (auto key = kv.first.try_cast()) { key_values.emplace_back(key.value(), kv.second); } else { LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: " diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index facfc9b8809d..2a6b93f8cb3b 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -189,7 +189,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, auto json_obj = json_objs[task_id].cast(); Workload workload{nullptr}; try { - const ArrayObj* arr = json_obj.as(); + const ffi::ArrayObj* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); int64_t workload_index = arr->at(0).cast()->value; ICHECK(workload_index >= 0 && static_cast(workload_index) < workloads.size()); @@ -214,7 +214,8 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, } TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase") + .set_body_typed(Database::JSONDatabase); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index 3d418206b031..cbc811752cad 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -97,7 +97,7 @@ Database Database::MemoryDatabase(String mod_eq_name) { } TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") .set_body_typed(Database::MemoryDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc index 3aaee2112c0c..87f5c03a71eb 100644 --- a/src/meta_schedule/database/ordered_union_database.cc +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -38,7 +38,7 @@ class OrderedUnionDatabaseNode : public DatabaseNode { return record; } } - return NullOpt; + return std::nullopt; } bool HasWorkload(const IRModule& mod) final { @@ -79,7 +79,7 @@ Database Database::OrderedUnionDatabase(Array databases) { } TVM_REGISTER_NODE_TYPE(OrderedUnionDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") .set_body_typed(Database::OrderedUnionDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index a1a1351812e2..c66ec5f4f0c1 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -40,11 +40,11 @@ class ScheduleFnDatabaseNode : public DatabaseNode { if (Optional sch = this->QuerySchedule(mod, target, workload_name)) { return TuningRecord(sch.value()->trace().value(), /*workload=*/Workload(mod, 0), // - /*run_secs=*/NullOpt, // + /*run_secs=*/std::nullopt, // /*target=*/target, // - /*arg_info=*/NullOpt); + /*arg_info=*/std::nullopt); } - return NullOpt; + return std::nullopt; } Optional QuerySchedule(const IRModule& mod, const Target& target, @@ -55,7 +55,7 @@ class ScheduleFnDatabaseNode : public DatabaseNode { /*debug_mode=*/0, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); if (!schedule_fn(sch)) { - return NullOpt; + return std::nullopt; } return sch; } @@ -99,7 +99,7 @@ Database Database::ScheduleFnDatabase(ffi::TypedFunction sc } TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") .set_body_typed(Database::ScheduleFnDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc index 6d19a38c6d9e..2bc82b459cad 100644 --- a/src/meta_schedule/database/union_database.cc +++ b/src/meta_schedule/database/union_database.cc @@ -41,7 +41,7 @@ class UnionDatabaseNode : public DatabaseNode { } } std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); - return results.empty() ? Optional(NullOpt) : results[0]; + return results.empty() ? Optional(std::nullopt) : results[0]; } bool HasWorkload(const IRModule& mod) final { @@ -82,7 +82,8 @@ Database Database::UnionDatabase(Array databases) { } TVM_REGISTER_NODE_TYPE(UnionDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase").set_body_typed(Database::UnionDatabase); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase") + .set_body_typed(Database::UnionDatabase); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index ec04361f51ec..fb26e6eb693c 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -39,7 +39,7 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, } TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); -TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ExtractedTask") .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, int weight) -> ExtractedTask { return ExtractedTask(task_name, mod, target, dispatched, weight); diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc index 093558d2284e..9a3cecf4ce26 100644 --- a/src/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -48,9 +48,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); -TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom") .set_body_method(&FeatureExtractorNode::ExtractFrom); -TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor") .set_body_typed(FeatureExtractor::PyFeatureExtractor); } // namespace meta_schedule diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 955e39df4fc3..2fc8878546d8 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -279,6 +279,9 @@ Pass SimplifyForFeatureExtraction() { } Stmt VisitStmt_(const ForNode* loop) final { + if (is_zero(loop->extent)) { + return Evaluate(0); + } if (is_zero(loop->min) && is_one(loop->extent) && loop->kind == ForKind::kSerial && loop->annotations.empty()) { unit_vars_.insert(loop->loop_var); @@ -1439,7 +1442,7 @@ FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, } TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode); -TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") .set_body_typed(FeatureExtractor::PerStoreFeature); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/add_to_database.cc b/src/meta_schedule/measure_callback/add_to_database.cc index 68a4b93ea96f..becd9d2110df 100644 --- a/src/meta_schedule/measure_callback/add_to_database.cc +++ b/src/meta_schedule/measure_callback/add_to_database.cc @@ -65,7 +65,7 @@ MeasureCallback MeasureCallback::AddToDatabase() { } TVM_REGISTER_NODE_TYPE(AddToDatabaseNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackAddToDatabase") .set_body_typed(MeasureCallback::AddToDatabase); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 8f94e298463a..0ee49f2ab4f9 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -59,11 +59,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") .set_body_method(&MeasureCallbackNode::Apply); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") .set_body_typed(MeasureCallback::PyMeasureCallback); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackDefault") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackDefault") .set_body_typed(MeasureCallback::Default); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/remove_build_artifact.cc b/src/meta_schedule/measure_callback/remove_build_artifact.cc index 9242e79912df..da74e85cac07 100644 --- a/src/meta_schedule/measure_callback/remove_build_artifact.cc +++ b/src/meta_schedule/measure_callback/remove_build_artifact.cc @@ -46,7 +46,7 @@ MeasureCallback MeasureCallback::RemoveBuildArtifact() { } TVM_REGISTER_NODE_TYPE(RemoveBuildArtifactNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackRemoveBuildArtifact") .set_body_typed(MeasureCallback::RemoveBuildArtifact); } // namespace meta_schedule diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 63c32b189eee..1969d7fc83a9 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -63,7 +63,7 @@ MeasureCallback MeasureCallback::UpdateCostModel() { } TVM_REGISTER_NODE_TYPE(UpdateCostModelNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCallbackUpdateCostModel") .set_body_typed(MeasureCallback::UpdateCostModel); } // namespace meta_schedule diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 98926c340518..8f8c077aa815 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -115,7 +115,7 @@ std::vector MutateComputeLocationNode::Fin Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector candidates = FindCandidates(trace, rand_state); if (candidates.empty()) { - return NullOpt; + return std::nullopt; } const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())]; @@ -127,7 +127,7 @@ Mutator Mutator::MutateComputeLocation() { } TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") .set_body_typed(Mutator::MutateComputeLocation); } // namespace meta_schedule diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index e50e7410f6a0..a6a34e47a9d9 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -250,7 +250,7 @@ Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_s // Step 1. Find a parallel decision. Candidate candidate; if (!FindParallelDecision(trace, rand_state, &candidate)) { - return NullOpt; + return std::nullopt; } // Step 2. Replay the instructions to recover loop extents tir::Schedule sch = tir::Schedule::Traced( // @@ -283,7 +283,7 @@ Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_s // Step 5. Pick a new plan int n_plans = plan2limit.size(); if (n_plans == 0) { - return NullOpt; + return std::nullopt; } it = plan2limit.begin(); for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) { @@ -312,7 +312,8 @@ Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { } TVM_REGISTER_NODE_TYPE(MutateParallelNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel") + .set_body_typed(Mutator::MutateParallel); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 89671c50ccc0..269b05240443 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -150,7 +150,7 @@ std::vector MutateThreadBindingNode::FindCan Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) { std::vector candidates = FindCandidates(trace, rand_state); if (candidates.empty()) { - return NullOpt; + return std::nullopt; } Candidate candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; // Remove the current decision @@ -165,7 +165,7 @@ Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* r Mutator Mutator::MutateThreadBinding() { return Mutator(make_object()); } TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding") .set_body_typed(Mutator::MutateThreadBinding); } // namespace meta_schedule diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index f06736abfb90..e8a728d05033 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -216,7 +216,7 @@ Optional MutateSampleTileSize(const Trace& trace, Instruction inst, } if (max_factor_index == 0) { if (n_splits <= 2) { - return NullOpt; + return std::nullopt; } // Failed on this dst_idx, try next one. continue; @@ -253,7 +253,7 @@ Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s int size_a = sample_perfect_tile_insts.size(); int size_b = sample_vectorize_insts.size(); if (size_a == 0 && size_b == 0) { - return NullOpt; + return std::nullopt; } int n = tir::SampleInt(rand_state, 0, size_a + size_b); if (n < size_a) { @@ -269,7 +269,8 @@ Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize") + .set_body_typed(Mutator::MutateTileSize); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index a3560943db04..28fcf3668f27 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -122,10 +122,10 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { Candidate candidate; if (!FindUnrollDecision(trace, rand_state, &candidate)) { - return NullOpt; + return std::nullopt; } if (candidate.probs.size() == 0) { - return NullOpt; + return std::nullopt; } candidate.probs.erase(candidate.probs.begin() + candidate.decision); int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); @@ -138,7 +138,7 @@ Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_sta Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } TVM_REGISTER_NODE_TYPE(MutateUnrollNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index e1831d213d1e..e415b3909f10 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -88,20 +88,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(MutatorNode); TVM_REGISTER_NODE_TYPE(PyMutatorNode); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") .set_body_method(&MutatorNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorApply") .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional { TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); return self->Apply(trace, &seed_); }); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultLLVM").set_body_typed(Mutator::DefaultLLVM); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::DefaultCUDA); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method(&MutatorNode::Clone); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultLLVM").set_body_typed(Mutator::DefaultLLVM); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::DefaultCUDA); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore") .set_body_typed(Mutator::DefaultCUDATensorCore); -TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon").set_body_typed(Mutator::DefaultHexagon); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon") + .set_body_typed(Mutator::DefaultHexagon); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index a6a71202ae15..01a75a5bfb36 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -150,8 +150,8 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::VectorizeLoop(true)); pass_list.push_back(tir::transform::StorageRewrite()); - tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", - runtime::String(g_var->name_hint)); + tir::PrimFunc f = + WithAttr(GetRef(prim_func), "global_symbol", String(g_var->name_hint)); IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); } catch (const dmlc::Error& e) { @@ -184,7 +184,7 @@ Postproc Postproc::DisallowAsyncStridedMemCopy() { } TVM_REGISTER_NODE_TYPE(DisallowAsyncStridedMemCopyNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy") .set_body_typed(Postproc::DisallowAsyncStridedMemCopy); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc index 8362da552ea5..fd099ac5dd38 100644 --- a/src/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -83,7 +83,7 @@ Postproc Postproc::DisallowDynamicLoop() { } TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") .set_body_typed(Postproc::DisallowDynamicLoop); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index b5d62634c23f..e29f9dd54c5a 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -112,16 +112,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(PostprocNode); TVM_REGISTER_NODE_TYPE(PyPostprocNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") .set_body_method(&PostprocNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultLLVM").set_body_typed(Postproc::DefaultLLVM); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDA").set_body_typed(Postproc::DefaultCUDA); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDATensorCore") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocClone").set_body_method(&PostprocNode::Clone); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultLLVM").set_body_typed(Postproc::DefaultLLVM); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDA").set_body_typed(Postproc::DefaultCUDA); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultCUDATensorCore") .set_body_typed(Postproc::DefaultCUDATensorCore); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocDefaultHexagon") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocDefaultHexagon") .set_body_typed(Postproc::DefaultHexagon); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index 353b90c36423..d23e07795cad 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -26,18 +26,18 @@ namespace tir { * \param sch The schedule * \param inst The instruction to be parsed * \param axis The axis name expected - * \return NullOpt if parsing fails; Otherwise, the extent of thread axis + * \return std::nullopt if parsing fails; Otherwise, the extent of thread axis */ Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) { static InstructionKind inst_kind_bind = InstructionKind::Get("Bind"); if (!inst->kind.same_as(inst_kind_bind)) { - return NullOpt; + return std::nullopt; } ICHECK_EQ(inst->inputs.size(), 1); ICHECK_EQ(inst->attrs.size(), 1); String thread_axis = Downcast(inst->attrs[0]); if (thread_axis != axis) { - return NullOpt; + return std::nullopt; } return Downcast(sch->Get(Downcast(inst->inputs[0]))->extent); } @@ -47,19 +47,19 @@ Optional ParseThreadBinding(const Schedule& sch, const Instruction& ins * \param sch The schedule * \param inst The instruction to be parsed * \param vector_lane The number of vector lane in vectorized cooperative fetching - * \return NullOpt if parsing fails; Otherwise, the annotated block + * \return std::nullopt if parsing fails; Otherwise, the annotated block */ Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, int64_t* vector_lane) { static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); if (!inst->kind.same_as(inst_kind_annotate)) { - return NullOpt; + return std::nullopt; } ICHECK_EQ(inst->inputs.size(), 2); ICHECK_EQ(inst->attrs.size(), 1); String ann_key = Downcast(inst->attrs[0]); if (ann_key != attr::meta_schedule_cooperative_fetch) { - return NullOpt; + return std::nullopt; } *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; return Downcast(inst->inputs[0]); @@ -186,7 +186,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { } if (thread_extent_y != -1) { if (vector_lane > 1) { - Array split = sch->Split(fused, {NullOpt, // + Array split = sch->Split(fused, {std::nullopt, // Integer(thread_extent_y), // Integer(thread_extent_x), // Integer(vector_lane)}); @@ -194,7 +194,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); } else { - Array split = sch->Split(fused, {NullOpt, // + Array split = sch->Split(fused, {std::nullopt, // Integer(thread_extent_y), // Integer(thread_extent_x)}); sch->Bind(split[2], "threadIdx.x"); @@ -202,13 +202,13 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { } } else { if (vector_lane > 1) { - Array split = sch->Split(fused, {NullOpt, // + Array split = sch->Split(fused, {std::nullopt, // Integer(thread_extent_x), // Integer(vector_lane)}); sch->Vectorize(split[2]); sch->Bind(split[1], "threadIdx.x"); } else { - Array split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)}); + Array split = sch->Split(fused, {std::nullopt, Integer(thread_extent_x)}); sch->Bind(split[1], "threadIdx.x"); } } @@ -227,7 +227,7 @@ Postproc Postproc::RewriteCooperativeFetch() { } TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") .set_body_typed(Postproc::RewriteCooperativeFetch); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 6c8079fdda07..84dc33ec98c8 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -206,7 +206,7 @@ bool RewriteLayout(const Schedule& sch) { auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name); add_layout_rewrite_block(anchor_block_rv, buffer_index); sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map, - NullOpt); + std::nullopt); } else { // When the layout-free buffer is consumed by cache_read, we need to find the index map // for a cache-read buffer that is directly consumed by an anchor op. The last buffer @@ -219,7 +219,7 @@ bool RewriteLayout(const Schedule& sch) { auto [anchor_block, buffer_index, index_map] = *tup_opt; // Transform the layout of the last cache-read buffer. sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index, - BufferIndexType::kRead, index_map, NullOpt); + BufferIndexType::kRead, index_map, std::nullopt); // Propagate the layout transformation over cache_read_chain, starting from // the next-to-last cache-read buffer. @@ -231,7 +231,8 @@ bool RewriteLayout(const Schedule& sch) { // transformed by TransformLayout below. add_layout_rewrite_block(cache_read_block_rv, 0); } - sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt); + sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, + std::nullopt); } } } @@ -272,7 +273,8 @@ Postproc Postproc::RewriteLayout() { } TVM_REGISTER_NODE_TYPE(RewriteLayoutNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout").set_body_typed(Postproc::RewriteLayout); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteLayout") + .set_body_typed(Postproc::RewriteLayout); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 965cc8baefd6..3f665cd8d82a 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -106,22 +106,22 @@ bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { for (const auto& ann : block->annotations) { if (ann.first == attr::meta_schedule_parallel) { found = true; - if (auto opt_int_imm = ann.second.as()) { + if (auto opt_int_imm = ann.second.try_cast()) { parsed->max_parallel_extent = (*opt_int_imm)->value; } } else if (ann.first == attr::meta_schedule_vectorize) { found = true; - if (auto opt_int_imm = ann.second.as()) { + if (auto opt_int_imm = ann.second.try_cast()) { parsed->max_vectorize_extent = (*opt_int_imm)->value; } } else if (ann.first == attr::meta_schedule_unroll_explicit) { found = true; - if (auto opt_int_imm = ann.second.as()) { + if (auto opt_int_imm = ann.second.try_cast()) { parsed->unroll_explicit = (*opt_int_imm)->value; } } else if (ann.first == attr::meta_schedule_unroll_implicit) { found = true; - if (auto opt_int_imm = ann.second.as()) { + if (auto opt_int_imm = ann.second.try_cast()) { parsed->unroll_implicit = (*opt_int_imm)->value; } } @@ -358,7 +358,7 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array* loop_rvs, int vec_len) { size_t n_loops = loop_rvs->size(); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); - Array split = sch->Split(fused, {NullOpt, Integer(vec_len)}); + Array split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); ICHECK_EQ(split.size(), 2); const LoopRV& outer = split[0]; const LoopRV& inner = split[1]; @@ -464,7 +464,7 @@ Postproc Postproc::RewriteParallelVectorizeUnroll() { } TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index 05a7640f047c..3ffe0f9234d2 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -149,7 +149,7 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize_init); // The annotation of tensorization of the init statement should be moved to the init block // after 'DecomposeReduction'. - // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is NullOpt. + // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is std::nullopt. sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize, tensorize_init.value_or("")); if (tensorize_init.defined()) { @@ -172,7 +172,7 @@ Postproc Postproc::RewriteReductionBlock() { } TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") .set_body_typed(Postproc::RewriteReductionBlock); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 4f8e0fb213f8..0f98484dd44e 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -107,7 +107,7 @@ Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { } TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") .set_body_typed(Postproc::RewriteTensorize); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index 27ce34a8cb27..a2c9d1364ab6 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -146,7 +146,7 @@ Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { } TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") .set_body_typed(Postproc::RewriteUnboundBlock); } // namespace meta_schedule diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index b0d2c65edc85..8ffc424e4451 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -179,8 +179,8 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::LowerIntrin()); // Convert Function to IRModule transform::PassContext pass_ctx = transform::PassContext::Current(); - tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", - runtime::String(g_var->name_hint)); + tir::PrimFunc f = + WithAttr(GetRef(prim_func), "global_symbol", String(g_var->name_hint)); f = WithAttr(f, tvm::attr::kTarget, this->target_); // Required for LowerIntrin bool noalias = pass_ctx->GetConfig("tir.noalias", true).value(); if (noalias) { @@ -215,7 +215,8 @@ Postproc Postproc::VerifyGPUCode() { } TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode") + .set_body_typed(Postproc::VerifyGPUCode); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index 4de975089653..7da2f8546b9e 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -69,7 +69,7 @@ Postproc Postproc::VerifyVTCMLimit() { } TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode); -TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit") .set_body_typed(Postproc::VerifyVTCMLimit); } // namespace meta_schedule diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index b7f261519fd2..2a034a7be297 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -114,24 +114,24 @@ void Profiler::ExitWithScope() { Optional Profiler::Current() { std::vector* profilers = ThreadLocalProfilers(); if (profilers->empty()) { - return NullOpt; + return std::nullopt; } else { return profilers->back(); } } TVM_REGISTER_NODE_TYPE(ProfilerNode); -TVM_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler { +TVM_FFI_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler { return Profiler(); }); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerEnterWithScope") .set_body_method(&Profiler::EnterWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerExitWithScope") .set_body_method(&Profiler::ExitWithScope); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent").set_body_typed(Profiler::Current); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method(&ProfilerNode::Get); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTable").set_body_method(&ProfilerNode::Table); -TVM_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope").set_body_typed(ProfilerTimedScope); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerCurrent").set_body_typed(Profiler::Current); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method(&ProfilerNode::Get); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerTable").set_body_method(&ProfilerNode::Table); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ProfilerTimedScope").set_body_typed(ProfilerTimedScope); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 0d5edf299e43..38d4225f0fbd 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -56,24 +56,25 @@ TVM_REGISTER_NODE_TYPE(RunnerResultNode); TVM_REGISTER_NODE_TYPE(RunnerFutureNode); TVM_REGISTER_OBJECT_TYPE(RunnerNode); TVM_REGISTER_NODE_TYPE(PyRunnerNode); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerInput") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerInput") .set_body_typed([](String artifact_path, String device_type, Array args_info) -> RunnerInput { return RunnerInput(artifact_path, device_type, args_info); }); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerResult") .set_body_typed([](Optional> run_secs, Optional error_msg) -> RunnerResult { return RunnerResult(run_secs, error_msg); }); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerFuture") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFuture") .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { return RunnerFuture(f_done, f_result); }); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone").set_body_method(&RunnerFutureNode::Done); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult").set_body_method(&RunnerFutureNode::Result); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); -TVM_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone").set_body_method(&RunnerFutureNode::Done); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult") + .set_body_method(&RunnerFutureNode::Result); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule/cpu/winograd.cc b/src/meta_schedule/schedule/cpu/winograd.cc index 16e53b56923a..4e09fa729b3c 100644 --- a/src/meta_schedule/schedule/cpu/winograd.cc +++ b/src/meta_schedule/schedule/cpu/winograd.cc @@ -59,7 +59,7 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); @@ -71,14 +71,14 @@ TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nhwc_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV block) -> Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {2, 3}, {0, 1, 4, 5}); return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); @@ -90,7 +90,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cpu.conv2d_nchw_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV block) -> Array { GetWinogradProducerAndInlineConst(sch, block); ScheduleDataPack(sch, block, {0, 1}, {2, 3, 4, 5}); diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index ee70ea8c717b..287f764a4640 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -67,13 +67,13 @@ Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblock get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); } ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); - Array splits = sch->Split(loop, {NullOpt, factor}); + Array splits = sch->Split(loop, {std::nullopt, factor}); ICHECK_EQ(splits.size(), 2); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); return {splits[0], splits[1]}; } else { - Array splits = sch->Split(loop, {NullOpt, + Array splits = sch->Split(loop, {std::nullopt, Integer(max_threadblocks), // Integer(max_threads_per_block)}); ICHECK_EQ(splits.size(), 3); diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index 59ed7bdc009a..c80141f5288d 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -63,7 +63,7 @@ static Array ScheduleDataPack(tir::Schedule sch, tir::BlockRV block return {t0[0], t1[0], t0[1], t1[1]}; } -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetWinogradProducerAndInlineConst(sch, data_pack); BlockRV data_pad = GetWinogradProducerAndInlineConst(sch, input_tile); @@ -88,7 +88,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { GetWinogradProducerAndInlineConst(sch, inverse); ScheduleDataPack(sch, inverse, /*tiled=*/{2, 3}, /*unrolled=*/{0, 1, 4, 5}); @@ -101,7 +101,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nhwc_winograd_inverse") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; @@ -132,7 +132,7 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") return {sch}; }); -TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse") .set_body_typed([](Schedule sch, BlockRV inverse) -> Array { GetWinogradProducerAndInlineConst(sch, inverse); // loops on top of the inverse block: [CO, P, tile_size, tile_size, alpha, alpha] @@ -142,8 +142,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_inverse") BlockRV output = sch->GetConsumers(inverse)[0]; Array nchw = sch->GetLoops(output); ICHECK_EQ(nchw.size(), 4); - Array hs = sch->Split(nchw[2], {NullOpt, Integer(tile_size)}); - Array ws = sch->Split(nchw[3], {NullOpt, Integer(tile_size)}); + Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); + Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); outer = ws[0]; } diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 2fc1352677cb..48149ed871e4 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -120,7 +120,7 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: } TVM_REGISTER_NODE_TYPE(AddRFactorNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") .set_body_typed(ScheduleRule::AddRFactor); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 15962baa927a..92de19163af5 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -69,7 +69,7 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { } public: - Optional target_ = NullOpt; + Optional target_ = std::nullopt; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target_", &target_); } @@ -87,7 +87,7 @@ bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) { } TVM_REGISTER_NODE_TYPE(ApplyCustomRuleNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule") .set_body_typed(ScheduleRule::ApplyCustomRule); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index fa47d1edb860..892a79ea926d 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -82,7 +82,8 @@ ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_ } TVM_REGISTER_NODE_TYPE(AutoBindNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind").set_body_typed(ScheduleRule::AutoBind); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind") + .set_body_typed(ScheduleRule::AutoBind); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index d9e033eff810..948632e580e6 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -191,7 +191,7 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // } TVM_REGISTER_NODE_TYPE(AutoInlineNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") .set_body_typed(ScheduleRule::AutoInline); /*! \brief Inline blocks that produce a constant scalar. */ @@ -232,7 +232,7 @@ ScheduleRule ScheduleRule::InlineConstantScalars() { } TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") .set_body_typed(ScheduleRule::InlineConstantScalars); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index aa34b6467ab4..e06817e37c4c 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -86,7 +86,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor. if (!InThreadScope(tmp_sch, target_block)) { const Array& split_res = - tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent}); + tmp_sch->Split(tgt_block_innermost_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); if (tgt_block_innermost_loop.same_as(target_loop)) { target_loop = split_res[0]; @@ -107,7 +107,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops); // Step 5. Split the fused reduction loop and bind the inner one to threadIdx. const Array& split_res = - tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent}); + tmp_sch->Split(fused_reduce_loop, {std::nullopt, thread_extent}); tmp_sch->Bind(split_res[1], "threadIdx.x"); return {tmp_sch, sch}; @@ -291,7 +291,7 @@ ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { } TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") .set_body_typed(ScheduleRule::CrossThreadReduction); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 79cff3bad738..f020c8efd08a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -405,7 +405,7 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional TransformWithTensorIntrin(TensorCoreStateNode* state, const String& intrin_name) const; @@ -502,7 +502,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa return result; }); sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map, - /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); + /*pad_value=*/std::nullopt, /*assume_injective_transform=*/true); return {state}; } @@ -759,7 +759,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( const tir::Block block_before_reindex = GetRef(block); if (block->reads.size() != 2 || block->writes.size() != 1) { // only matmul-like computation is allowed - return NullOpt; + return std::nullopt; } state->tensor_core_reindex_store = state->sch->ReIndex(state->block_rv, 0, tir::BufferIndexType::kWrite); @@ -840,7 +840,7 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, - /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); + /*pad_value=*/std::nullopt, /*assume_injective_transform=*/true); }; for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { @@ -923,7 +923,7 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( } TVM_REGISTER_NODE_TYPE(MultiLevelTilingTensorCoreNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingTensorCore") .set_body_typed(ScheduleRule::MultiLevelTilingTensorCore); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index deceaa6f2c93..0da8ee35cf76 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -92,7 +92,7 @@ std::pair, Array> MultiLevelTilingWideVectorNode const int64_t* extent_int = tir::GetLoopIntExtent(loop); if (extent_int && *extent_int > vec_len) { Array inner_splits = sch->Split(/*loop=*/loop_rv, - /*factors=*/{NullOpt, PrimExpr(vec_len)}); + /*factors=*/{std::nullopt, PrimExpr(vec_len)}); Array outer_factors = sch->SamplePerfectTile( /*loop=*/inner_splits[0], /*n=*/n_tiles - 1, @@ -118,13 +118,13 @@ ScheduleRule ScheduleRule::MultiLevelTilingWideVector(String structure, Optional> reuse_read, Optional> reuse_write) { auto node = MultiLevelTilingInitCommon( - structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write); + structure, std::nullopt, max_innermost_factor, std::nullopt, reuse_read, reuse_write); node->vector_length_in_bits = vector_length_in_bits->value; return ScheduleRule(node); } TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector") .set_body_typed(ScheduleRule::MultiLevelTilingWideVector); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index a8563750c834..731860e8d6f0 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -33,7 +33,7 @@ Optional TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name) { Optional tiled_loop_rv = TileWithTensorIntrin(sch, block, intrin_name); if (!tiled_loop_rv) { - return NullOpt; + return std::nullopt; } ICHECK(tiled_loop_rv.defined()); tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); @@ -106,7 +106,7 @@ ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(String intrin_name, String } TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin") .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 54091cfd73e1..905f8d8ce65f 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -134,7 +134,7 @@ ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, } TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll") .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll); } // namespace meta_schedule diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 7796eddd44d3..ed71baade06a 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -122,7 +122,7 @@ ScheduleRule ScheduleRule::RandomComputeLocation() { } TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") .set_body_typed(ScheduleRule::RandomComputeLocation); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index ec3c6ac480d0..3640694b4e5f 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -68,10 +68,10 @@ Array ScheduleRule::DefaultLLVM() { /*max_innermost_factor=*/Integer(64)), ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(64), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -105,20 +105,20 @@ Array ScheduleRule::DefaultX86(const String& type) { ScheduleRule::MultiLevelTilingWithIntrin( /*intrin_name=*/intrins[type], /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(64), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, {"scope", String("global")}}), ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(64), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -289,7 +289,7 @@ Array ScheduleRule::DefaultHexagon() { /*structure=*/"SRSRS", /*vector_length_in_bits=*/1024, /*max_innermost_factor=*/Integer(128), - /*reuse_read=*/NullOpt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -307,10 +307,10 @@ Array GetARMNeonSpecificRules() { ScheduleRule::MultiLevelTilingWithIntrin( /*intrin_name=*/String("dot_4x4_i8i8s32_neon"), /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -323,10 +323,10 @@ Array GetARMDotprodSpecificRules() { ScheduleRule::MultiLevelTilingWithIntrin( /*intrin_name=*/String("dot_4x4_i8i8s32_sdot"), /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -334,10 +334,10 @@ Array GetARMDotprodSpecificRules() { ScheduleRule::MultiLevelTilingWithIntrin( /*intrin_name=*/String("dot_4x4_u8u8u32_udot"), /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -345,10 +345,10 @@ Array GetARMDotprodSpecificRules() { ScheduleRule::MultiLevelTilingWithIntrin( /*intrin_name=*/String("dot_4x4_u8u8i32_hdot"), /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -374,10 +374,10 @@ Array ScheduleRule::DefaultARM(const String& type) { "dotprod" == type ? GetARMDotprodSpecificRules() : Array{}, ScheduleRule::MultiLevelTiling( /*structure=*/"SSRSRS", - /*tile_binds=*/NullOpt, + /*tile_binds=*/std::nullopt, /*max_innermost_factor=*/Integer(32), - /*vector_load_lens=*/NullOpt, - /*reuse_read=*/NullOpt, + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, /*reuse_write=*/ Map{{"req", String("may")}, {"levels", Array{1, 2}}, @@ -402,21 +402,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply").set_body_method(&ScheduleRuleNode::Apply); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone").set_body_method(&ScheduleRuleNode::Clone); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") + .set_body_method(&ScheduleRuleNode::Apply); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleClone") + .set_body_method(&ScheduleRuleNode::Clone); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") .set_body_typed(ScheduleRule::PyScheduleRule); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultLLVM") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultLLVM") .set_body_typed(ScheduleRule::DefaultLLVM); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDA") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDA") .set_body_typed(ScheduleRule::DefaultCUDA); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDATensorCore") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultCUDATensorCore") .set_body_typed(ScheduleRule::DefaultCUDATensorCore); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon") .set_body_typed(ScheduleRule::DefaultHexagon); -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM") .set_body_typed(ScheduleRule::DefaultARM); } // namespace meta_schedule diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index eea28fb0822b..4872f3aa5f6e 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -139,7 +139,7 @@ struct PerThreadData { TRandState* rand_state) { std::vector> mutators; std::vector masses; - mutators.push_back(NullOpt); + mutators.push_back(std::nullopt); masses.push_back(1.0 - genetic_mutate_prob); double total_mass_mutator = 0.0; if (genetic_mutate_prob > 0) { @@ -592,6 +592,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( // Loop until success for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { sampled_trace_id = trace_sampler(); + sampled_trace_id = sampled_trace_id % self->population_size; tir::Trace trace = population.at(sampled_trace_id)->trace().value(); if (Optional opt_mutator = mutator_sampler()) { // Decision: mutate @@ -698,7 +699,7 @@ std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { - return NullOpt; + return std::nullopt; } int sample_num = num_trials_per_iter; if (ed > max_trials) { @@ -718,7 +719,7 @@ Optional> EvolutionarySearchNode::State::GenerateMeasure if (static_cast(unmeasured.size()) < self->init_min_unmeasured) { TVM_PY_LOG(WARNING, self->ctx_->logger) << "Cannot sample enough initial population, evolutionary search failed."; - return NullOpt; + return std::nullopt; } TVM_PY_LOG(INFO, self->ctx_->logger) << "Sampled " << unmeasured.size() << " candidate(s)"; inits.insert(inits.end(), measured.begin(), measured.end()); @@ -732,7 +733,7 @@ Optional> EvolutionarySearchNode::State::GenerateMeasure if (picks.empty()) { ++this->num_empty_iters; if (this->num_empty_iters >= self->num_empty_iters_before_early_stop) { - return NullOpt; + return std::nullopt; } } return AssembleCandidates(picks); @@ -801,11 +802,11 @@ Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, } TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") .set_body_typed(SearchStrategy::EvolutionarySearch); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation") .set_body_typed(EvolutionarySearchSampleInitPopulation); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel") .set_body_typed(EvolutionarySearchEvolveWithCostModel); } // namespace meta_schedule diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 7bb4a02ab299..51cc40839195 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -54,9 +54,9 @@ class ReplayFuncNode : public SearchStrategyNode { /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The IRModule to be scheduled from TuneContext. */ - Optional mod_ = NullOpt; + Optional mod_ = std::nullopt; /*! \brief The space generator from TuneContext. */ - Optional space_generator_ = NullOpt; + Optional space_generator_ = std::nullopt; /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; @@ -108,8 +108,8 @@ class ReplayFuncNode : public SearchStrategyNode { SearchStrategy Clone() const final { ObjectPtr n = make_object(); n->rand_state_ = -1; - n->mod_ = NullOpt; - n->space_generator_ = NullOpt; + n->mod_ = std::nullopt; + n->space_generator_ = std::nullopt; n->state_ = nullptr; return SearchStrategy(n); } @@ -117,7 +117,7 @@ class ReplayFuncNode : public SearchStrategyNode { inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { - return NullOpt; + return std::nullopt; } ed = std::min(ed, max_trials); Array result; @@ -157,7 +157,7 @@ SearchStrategy SearchStrategy::ReplayFunc() { } TVM_REGISTER_NODE_TYPE(ReplayFuncNode); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") .set_body_typed(SearchStrategy::ReplayFunc); } // namespace meta_schedule diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index dc9063ee1083..c9a7459fdf61 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -67,7 +67,7 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; /*! \brief The IRModule to be scheduled from TuneContext. */ - Optional mod_ = NullOpt; + Optional mod_ = std::nullopt; /*! \brief The number of threads to be used. */ int num_threads_ = -1; /*! \brief The postprocessors. */ @@ -145,7 +145,7 @@ class ReplayTraceNode : public SearchStrategyNode { inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { if (st >= max_trials) { - return NullOpt; + return std::nullopt; } ed = std::min(ed, max_trials); ICHECK_LT(st, ed); @@ -191,7 +191,7 @@ SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { } TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") .set_body_typed(SearchStrategy::ReplayTrace); } // namespace meta_schedule diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 1bc71502ad36..8fc6538b59f5 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -86,23 +86,23 @@ TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); -TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") .set_body_typed([](tir::Schedule sch, Optional> args_info) -> MeasureCandidate { return MeasureCandidate(sch, args_info.value_or({})); }); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") .set_body_typed(SearchStrategy::PySearchStrategy); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") .set_body_method(&SearchStrategyNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") .set_body_method(&SearchStrategyNode::PreTuning); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") .set_body_method(&SearchStrategyNode::PostTuning); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") .set_body_method(&SearchStrategyNode::NotifyRunnerResults); -TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyClone") .set_body_method(&SearchStrategyNode::Clone); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index da2178c736a1..91d5ba53d551 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -116,7 +116,7 @@ SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, } TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") .set_body_typed(SpaceGenerator::PostOrderApply); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index 89a02876f3d9..f7f2a3ba19de 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -97,7 +97,7 @@ SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, } TVM_REGISTER_NODE_TYPE(ScheduleFnNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") .set_body_typed(SpaceGenerator::ScheduleFn); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 8712f5ad4892..7306fffcb1af 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -190,13 +190,13 @@ SpaceGenerator SpaceGenerator::PySpaceGenerator( TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode); TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext") .set_body_method(&SpaceGeneratorNode::InitializeWithTuneContext); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") .set_body_method(&SpaceGeneratorNode::GenerateDesignSpace); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") .set_body_typed(SpaceGenerator::PySpaceGenerator); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone") .set_body_method(&SpaceGeneratorNode::Clone); } // namespace meta_schedule diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 819a4ee5f795..12bf75349430 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -82,7 +82,7 @@ SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_g } TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode); -TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") .set_body_typed(SpaceGenerator::SpaceGeneratorUnion); } // namespace meta_schedule diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index c750067ace9f..23d23e624394 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -145,7 +145,7 @@ TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, i } TVM_REGISTER_NODE_TYPE(GradientBasedNode); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") .set_body_typed(TaskScheduler::GradientBased); } // namespace meta_schedule diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index d7c6f37e121d..9792fa7e7c25 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -63,7 +63,7 @@ TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { } TVM_REGISTER_NODE_TYPE(RoundRobinNode); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin") .set_body_typed(TaskScheduler::RoundRobin); } // namespace meta_schedule diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index ca5c6e4988a3..85a406365377 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -22,7 +22,7 @@ namespace tvm { namespace meta_schedule { TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { - ObjectPtr n = runtime::make_object(); + ObjectPtr n = ffi::make_object(); n->ctx = ctx; n->task_weight = task_weight; n->flop = 1.0; @@ -85,7 +85,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { /*f_done=*/[]() -> bool { return true; }, /*f_result=*/ [msg = builder_result->error_msg]() -> RunnerResult { - return RunnerResult(NullOpt, msg); + return RunnerResult(std::nullopt, msg); })); } else { results.push_back(futures[j++]); @@ -104,7 +104,7 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r const BuilderResult& builder_result = self->builder_results.value()[i]; const MeasureCandidate& candidate = self->measure_candidates.value()[i]; const RunnerResult& runner_result = results[i]; - Optional error_msg = NullOpt; + Optional error_msg = std::nullopt; int trials = self->latency_ms.size() + 1; double run_ms = 1e9; if ((error_msg = builder_result->error_msg)) { @@ -135,9 +135,9 @@ void TaskCleanUp(TaskRecordNode* self, int task_id, const Array& r << ". Best GFLOPs: " << (self->flop / best_ms / 1e6); } } - self->measure_candidates = NullOpt; - self->builder_results = NullOpt; - self->runner_futures = NullOpt; + self->measure_candidates = std::nullopt; + self->builder_results = std::nullopt; + self->runner_futures = std::nullopt; } void TaskSchedulerNode::Tune(Array ctxs, Array task_weights, @@ -364,18 +364,19 @@ void PyTaskSchedulerNode::Tune(Array tasks, Array task_we TVM_REGISTER_NODE_TYPE(TaskRecordNode); TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode); TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") .set_body_typed(TaskScheduler::PyTaskScheduler); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune").set_body_method(&TaskSchedulerNode::Tune); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") + .set_body_method(&TaskSchedulerNode::Tune); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") .set_body_method(&TaskSchedulerNode::JoinRunningTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") .set_body_method(&TaskSchedulerNode::NextTaskId); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTerminateTask") .set_body_method(&TaskSchedulerNode::TerminateTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTouchTask") .set_body_method(&TaskSchedulerNode::TouchTask); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPrintTuningStatistics") .set_body_method(&TaskSchedulerNode::PrintTuningStatistics); } // namespace meta_schedule diff --git a/src/meta_schedule/trace_apply.cc b/src/meta_schedule/trace_apply.cc index 5ba3f3123cbb..9d22554d912f 100644 --- a/src/meta_schedule/trace_apply.cc +++ b/src/meta_schedule/trace_apply.cc @@ -254,7 +254,7 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm } } -TVM_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleUsingAnchorTrace") .set_body_typed(ScheduleUsingAnchorTrace); } // namespace meta_schedule diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 275f8d124cd1..31120ce45d4a 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -63,7 +63,7 @@ void TuneContextNode::Initialize() { } TVM_REGISTER_NODE_TYPE(TuneContextNode); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContext") .set_body_typed([](Optional mod, Optional target, Optional space_generator, Optional search_strategy, Optional task_name, @@ -72,10 +72,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads, rand_state, logger); }); -TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") +TVM_FFI_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize") .set_body_method(&TuneContextNode::Initialize); -TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClone").set_body_method(&TuneContextNode::Clone); +TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContextClone").set_body_method(&TuneContextNode::Clone); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index de777e305919..21483d3b98a4 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -37,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -319,7 +319,7 @@ struct ThreadedTraceApply { * \param mod The IRModule to be applied * \param trace The trace to apply to the IRModule * \param rand_state The random seed - * \return The schedule created, or NullOpt if any postprocessor fails + * \return The schedule created, or std::nullopt if any postprocessor fails */ Optional Apply(const IRModule& mod, const tir::Trace& trace, TRandState* rand_state) { @@ -336,7 +336,7 @@ struct ThreadedTraceApply { Item& item = items_[i]; if (!item.postproc->Apply(sch)) { item.fail_counter++; - return NullOpt; + return std::nullopt; } } return sch; @@ -418,15 +418,15 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) { * \return The array of floating point numbers */ inline Array AsFloatArray(const ObjectRef& obj) { - const ArrayObj* arr = obj.as(); + const ffi::ArrayObj* arr = obj.as(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); Array results; results.reserve(arr->size()); for (Any val : *arr) { auto float_value = [&]() -> FloatImm { - if (auto opt_int_imm = val.as()) { + if (auto opt_int_imm = val.try_cast()) { return FloatImm(DataType::Float(32), (*opt_int_imm)->value); - } else if (auto opt_float_imm = val.as()) { + } else if (auto opt_float_imm = val.try_cast()) { return *std::move(opt_float_imm); } else { LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << val.GetTypeKey(); @@ -445,13 +445,13 @@ inline Array AsFloatArray(const ObjectRef& obj) { * \return The array of integers */ inline Array AsIntArray(const ObjectRef& obj) { - const ArrayObj* arr = obj.as(); + const ffi::ArrayObj* arr = obj.as(); ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); Array results; results.reserve(arr->size()); for (Any val : *arr) { auto int_value = [&]() -> int64_t { - if (auto opt_int_imm = val.as()) { + if (auto opt_int_imm = val.try_cast()) { return (*opt_int_imm)->value; } else { LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << val.GetTypeKey(); diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index 9ec39e9f6aae..334c15b3be97 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -24,8 +24,8 @@ #ifndef TVM_NODE_ATTR_REGISTRY_H_ #define TVM_NODE_ATTR_REGISTRY_H_ +#include #include -#include #include #include diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc index 261ae4825a8d..7441db783296 100644 --- a/src/node/container_printing.cc +++ b/src/node/container_printing.cc @@ -21,16 +21,16 @@ * Printer implementation for containers * \file node/container_printint.cc */ +#include #include #include -#include namespace tvm { // Container printer TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); p->stream << '['; for (size_t i = 0; i < op->size(); ++i) { if (i != 0) { @@ -42,8 +42,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); p->stream << '{'; for (auto it = op->begin(); it != op->end(); ++it) { if (it != op->begin()) { @@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << Downcast(node); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + p->stream << ffi::Downcast(node); }); } // namespace tvm diff --git a/src/node/object_path.cc b/src/node/object_path.cc index fda5d3d9d160..a99835ea17ad 100644 --- a/src/node/object_path.cc +++ b/src/node/object_path.cc @@ -17,10 +17,10 @@ * under the License. */ +#include +#include #include #include -#include -#include #include #include @@ -40,13 +40,13 @@ Optional ObjectPathNode::GetParent() const { return Downcast>(parent_); } -TVM_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_method(&ObjectPathNode::GetParent); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathGetParent").set_body_method(&ObjectPathNode::GetParent); // --- Length --- int32_t ObjectPathNode::Length() const { return length_; } -TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method(&ObjectPathNode::Length); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method(&ObjectPathNode::Length); // --- GetPrefix --- @@ -63,7 +63,7 @@ ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { return GetRef(node); } -TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix").set_body_method(&ObjectPathNode::GetPrefix); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathGetPrefix").set_body_method(&ObjectPathNode::GetPrefix); // --- IsPrefixOf --- @@ -75,7 +75,7 @@ bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { return this->PathsEqual(other->GetPrefix(this_len)); } -TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf").set_body_method(&ObjectPathNode::IsPrefixOf); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf").set_body_method(&ObjectPathNode::IsPrefixOf); // --- Attr --- @@ -95,7 +95,7 @@ ObjectPath ObjectPathNode::Attr(Optional attr_key) const { } } -TVM_REGISTER_GLOBAL("node.ObjectPathAttr") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathAttr") .set_body_typed([](const ObjectPath& object_path, Optional attr_key) { return object_path->Attr(attr_key); }); @@ -106,7 +106,7 @@ ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { return ObjectPath(make_object(this, index)); } -TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex").set_body_method(&ObjectPathNode::ArrayIndex); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathArrayIndex").set_body_method(&ObjectPathNode::ArrayIndex); // --- MissingArrayElement --- @@ -114,7 +114,7 @@ ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { return ObjectPath(make_object(this, index)); } -TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement") .set_body_method(&ObjectPathNode::MissingArrayElement); // --- MapValue --- @@ -123,7 +123,7 @@ ObjectPath ObjectPathNode::MapValue(Any key) const { return ObjectPath(make_object(this, std::move(key))); } -TVM_REGISTER_GLOBAL("node.ObjectPathMapValue").set_body_method(&ObjectPathNode::MapValue); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMapValue").set_body_method(&ObjectPathNode::MapValue); // --- MissingMapEntry --- @@ -131,7 +131,7 @@ ObjectPath ObjectPathNode::MissingMapEntry() const { return ObjectPath(make_object(this)); } -TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry") .set_body_method(&ObjectPathNode::MissingMapEntry); // --- PathsEqual ---- @@ -158,7 +158,7 @@ bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { return lhs == nullptr && rhs == nullptr; } -TVM_REGISTER_GLOBAL("node.ObjectPathEqual").set_body_method(&ObjectPathNode::PathsEqual); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathEqual").set_body_method(&ObjectPathNode::PathsEqual); // --- Repr --- @@ -191,7 +191,7 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const { return ObjectPath(make_object(name)); } -TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root); // ============== Individual path classes ============== diff --git a/src/node/reflection.cc b/src/node/reflection.cc index cf9f0dd3bd6e..2290403d3730 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -21,10 +21,10 @@ * Reflection utilities. * \file node/reflection.cc */ +#include #include #include #include -#include namespace tvm { @@ -253,7 +253,7 @@ ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, std::vector packed_args(kwargs.size() * 2); int index = 0; - for (const auto& kv : *static_cast(kwargs.get())) { + for (const auto& kv : *static_cast(kwargs.get())) { packed_args[index] = kv.first.cast().c_str(); packed_args[index + 1] = kv.second; index += 2; @@ -292,11 +292,11 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) { *rv = ReflectionVTable::Global()->CreateObject(type_key, args.Slice(1)); } -TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body_packed(NodeGetAttr); +TVM_FFI_REGISTER_GLOBAL("node.NodeGetAttr").set_body_packed(NodeGetAttr); -TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body_packed(NodeListAttrNames); +TVM_FFI_REGISTER_GLOBAL("node.NodeListAttrNames").set_body_packed(NodeListAttrNames); -TVM_REGISTER_GLOBAL("node.MakeNode").set_body_packed(MakeNode); +TVM_FFI_REGISTER_GLOBAL("node.MakeNode").set_body_packed(MakeNode); namespace { // Attribute visitor class for finding the attribute key by its address @@ -336,7 +336,7 @@ Optional GetAttrKeyByAddress(const Object* object, const void* attr_addr ReflectionVTable::Global()->VisitAttrs(const_cast(object), &visitor); const char* key = visitor.GetKey(); if (key == nullptr) { - return NullOpt; + return std::nullopt; } else { return String(key); } diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index 3e80751e6604..aa999655c03d 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -21,9 +21,9 @@ * Printer utilities * \file node/repr_printer.cc */ +#include #include #include -#include namespace tvm { @@ -133,12 +133,12 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } -TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) { +TVM_FFI_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) { std::ostringstream os; os << obj; return os.str(); }); -TVM_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr); +TVM_FFI_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr); } // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index ab1be755685f..ee7880f4485a 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -16,10 +16,10 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include -#include #include @@ -72,13 +72,13 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->module_alias = Downcast(v.value()); } if (auto v = config_dict.Get("buffer_dtype")) { - n->buffer_dtype = DataType(runtime::StringToDLDataType(Downcast(v.value()))); + n->buffer_dtype = DataType(StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("int_dtype")) { - n->int_dtype = DataType(runtime::StringToDLDataType(Downcast(v.value()))); + n->int_dtype = DataType(StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("float_dtype")) { - n->float_dtype = DataType(runtime::StringToDLDataType(Downcast(v.value()))); + n->float_dtype = DataType(StringToDLDataType(Downcast(v.value()))); } if (auto v = config_dict.Get("verbose_expr")) { n->verbose_expr = v.value().cast(); @@ -135,9 +135,9 @@ Array PrinterConfigNode::GetBuiltinKeywords() { } TVM_REGISTER_NODE_TYPE(PrinterConfigNode); -TVM_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { +TVM_FFI_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { return PrinterConfig(config_dict); }); -TVM_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script); +TVM_FFI_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script); } // namespace tvm diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 96c6696b336e..986a2d044524 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -23,25 +23,23 @@ */ #include #include +#include #include #include #include #include -#include -#include #include #include #include -#include "../runtime/object_internal.h" #include "../support/base64.h" namespace tvm { inline std::string Type2String(const DataType& t) { return runtime::DLDataTypeToString(t); } -inline DataType String2Type(std::string s) { return DataType(runtime::StringToDLDataType(s)); } +inline DataType String2Type(std::string s) { return DataType(ffi::StringToDLDataType(s)); } inline std::string Base64Decode(std::string s) { dmlc::MemoryStringStream mstrm(&s); @@ -111,15 +109,15 @@ class NodeIndexer : public AttrVisitor { return; } MakeNodeIndex(node); - if (auto opt_array = node.as()) { - const ArrayObj* n = opt_array.value(); + if (auto opt_array = node.as()) { + const ffi::ArrayObj* n = opt_array.value(); for (auto elem : *n) { MakeIndex(elem); } - } else if (auto opt_map = node.as()) { - const MapObj* n = opt_map.value(); + } else if (auto opt_map = node.as()) { + const ffi::MapObj* n = opt_map.value(); bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { - return v.first.template as().has_value(); + return v.first.template as(); }); if (is_str_map) { for (const auto& kv : *n) { @@ -272,15 +270,15 @@ class JSONAttrGetter : public AttrVisitor { node_->attrs.clear(); node_->data.clear(); - if (auto opt_array = node.as()) { - const ArrayObj* n = opt_array.value(); + if (auto opt_array = node.as()) { + const ffi::ArrayObj* n = opt_array.value(); for (size_t i = 0; i < n->size(); ++i) { node_->data.push_back(node_index_->at(n->at(i))); } - } else if (auto opt_map = node.as()) { - const MapObj* n = opt_map.value(); + } else if (auto opt_map = node.as()) { + const ffi::MapObj* n = opt_map.value(); bool is_str_map = std::all_of(n->begin(), n->end(), [](const auto& v) { - return v.first.template as().has_value(); + return v.first.template as(); }); if (is_str_map) { for (const auto& kv : *n) { @@ -381,7 +379,7 @@ class FieldDependencyFinder : public AttrVisitor { return; } // Skip containers - if (jnode->type_key == ArrayObj::_type_key || jnode->type_key == MapObj::_type_key) { + if (jnode->type_key == ffi::ArrayObj::_type_key || jnode->type_key == ffi::MapObj::_type_key) { return; } jnode_ = jnode; @@ -518,13 +516,13 @@ class JSONAttrSetter : public AttrVisitor { void SetAttrs(Any* node, JSONNode* jnode) { jnode_ = jnode; // handling Array - if (jnode->type_key == ArrayObj::_type_key) { + if (jnode->type_key == ffi::ArrayObj::_type_key) { Array result; for (auto index : jnode->data) { result.push_back(node_list_->at(index)); } *node = result; - } else if (jnode->type_key == MapObj::_type_key) { + } else if (jnode->type_key == ffi::MapObj::_type_key) { Map result; if (jnode->keys.empty()) { ICHECK_EQ(jnode->data.size() % 2, 0U); @@ -701,7 +699,7 @@ Any LoadJSON(std::string json_str) { return nodes.at(jgraph.root); } -TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); +TVM_FFI_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); -TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); +TVM_FFI_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 1a20810e790f..6b19fb5355bb 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,13 +19,13 @@ /*! * \file src/node/structural_equal.cc */ +#include #include #include #include #include #include #include -#include #include #include @@ -36,12 +36,12 @@ namespace tvm { TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode); -TVM_REGISTER_GLOBAL("node.ObjectPathPairLhsPath") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathPairLhsPath") .set_body_typed([](const ObjectPathPair& object_path_pair) { return object_path_pair->lhs_path; }); -TVM_REGISTER_GLOBAL("node.ObjectPathPairRhsPath") +TVM_FFI_REGISTER_GLOBAL("node.ObjectPathPairRhsPath") .set_body_typed([](const ObjectPathPair& object_path_pair) { return object_path_pair->rhs_path; }); @@ -88,7 +88,7 @@ struct SEqualReducer::PathTracingData { bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { if (tracing_data_ == nullptr) { // Fast path: no tracing - return handler_->SEqualReduce(lhs, rhs, map_free_vars_, NullOpt); + return handler_->SEqualReduce(lhs, rhs, map_free_vars_, std::nullopt); } return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr); } @@ -96,7 +96,7 @@ bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { if (tracing_data_ == nullptr) { // Fast path: no tracing - return handler_->SEqualReduce(lhs, rhs, true, NullOpt); + return handler_->SEqualReduce(lhs, rhs, true, std::nullopt); } return ObjectAttrsEqual(lhs, rhs, true, nullptr); } @@ -239,7 +239,7 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair* paths) const { if (tracing_data_ == nullptr) { // Fast path: no tracing - return handler_->SEqualReduce(lhs, rhs, map_free_vars, NullOpt); + return handler_->SEqualReduce(lhs, rhs, map_free_vars, std::nullopt); } // Slow path: tracing object paths for better error reporting @@ -595,7 +595,7 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje return impl->DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); } -TVM_REGISTER_GLOBAL("node.StructuralEqual") +TVM_FFI_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const Any& lhs, const Any& rhs, bool assert_mode, bool map_free_vars) { // If we are asserting on failure, then the `defer_fails` option // should be enabled, to provide better error messages. For @@ -608,7 +608,7 @@ TVM_REGISTER_GLOBAL("node.StructuralEqual") .Equal(lhs, rhs, map_free_vars); }); -TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") +TVM_FFI_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") .set_body_typed([](const Any& lhs, const Any& rhs, bool map_free_vars) { Optional first_mismatch; bool equal = diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 94e676820369..efaa7037b013 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -20,13 +20,13 @@ * \file src/node/structural_hash.cc */ #include +#include #include #include #include #include #include #include -#include #include #include @@ -291,7 +291,7 @@ void SHashHandlerDefault::DispatchSHash(const ObjectRef& key, bool map_free_vars impl->DispatchSHash(key, map_free_vars); } -TVM_REGISTER_GLOBAL("node.StructuralHash") +TVM_FFI_REGISTER_GLOBAL("node.StructuralHash") .set_body_typed([](const Any& object, bool map_free_vars) -> int64_t { uint64_t hashed_value = SHashHandlerDefault().Hash(object, map_free_vars); return static_cast(hashed_value); @@ -315,11 +315,11 @@ void SHashHandlerIgnoreNDArray::DispatchSHash(const ObjectRef& object, bool map_ struct StringObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { + static void SHashReduce(const ffi::StringObj* key, SHashReducer hash_reduce) { hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size)); } - static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, + static bool SEqualReduce(const ffi::StringObj* lhs, const ffi::StringObj* rhs, SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->size != rhs->size) return false; @@ -350,26 +350,20 @@ struct RefToObjectPtr : public ObjectRef { } }; -TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) - .set_creator([](const std::string& bytes) { - return RefToObjectPtr::Get(runtime::String(bytes)); - }) +TVM_REGISTER_REFLECTION_VTABLE(ffi::StringObj, StringObjTrait) + .set_creator([](const std::string& bytes) { return RefToObjectPtr::Get(String(bytes)); }) .set_repr_bytes([](const Object* n) -> std::string { - return GetRef(static_cast(n)) - . - operator std::string(); + return GetRef(static_cast(n)).operator std::string(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; }); TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj, BytesObjTrait) - .set_creator([](const std::string& bytes) { - return RefToObjectPtr::Get(runtime::String(bytes)); - }) + .set_creator([](const std::string& bytes) { return RefToObjectPtr::Get(String(bytes)); }) .set_repr_bytes([](const Object* n) -> std::string { return GetRef(static_cast(n)).operator std::string(); }); @@ -439,14 +433,15 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrai struct ArrayObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ArrayObj* key, SHashReducer hash_reduce) { + static void SHashReduce(const ffi::ArrayObj* key, SHashReducer hash_reduce) { hash_reduce(static_cast(key->size())); for (uint32_t i = 0; i < key->size(); ++i) { hash_reduce(key->at(i)); } } - static bool SEqualReduce(const ArrayObj* lhs, const ArrayObj* rhs, SEqualReducer equal) { + static bool SEqualReduce(const ffi::ArrayObj* lhs, const ffi::ArrayObj* rhs, + SEqualReducer equal) { if (equal.IsPathTracingEnabled()) { return SEqualReduceTraced(lhs, rhs, equal); } @@ -459,7 +454,7 @@ struct ArrayObjTrait { } private: - static bool SEqualReduceTraced(const ArrayObj* lhs, const ArrayObj* rhs, + static bool SEqualReduceTraced(const ffi::ArrayObj* lhs, const ffi::ArrayObj* rhs, const SEqualReducer& equal) { uint32_t min_size = std::min(lhs->size(), rhs->size()); const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths(); @@ -517,22 +512,22 @@ struct ArrayObjTrait { return false; } }; -TVM_REGISTER_REFLECTION_VTABLE(ArrayObj, ArrayObjTrait) +TVM_REGISTER_REFLECTION_VTABLE(ffi::ArrayObj, ArrayObjTrait) .set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); + return ffi::make_object(); }); -struct ShapeTupleObjTrait { +struct ShapeObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) { + static void SHashReduce(const ffi::ShapeObj* self, SHashReducer hash_reduce) { hash_reduce(static_cast(self->size)); for (uint32_t i = 0; i < self->size; ++i) { hash_reduce(self->data[i]); } } - static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs, + static bool SEqualReduce(const ffi::ShapeObj* lhs, const ffi::ShapeObj* rhs, SEqualReducer equal) { if (lhs->size != rhs->size) return false; for (uint32_t i = 0; i < lhs->size; ++i) { @@ -542,7 +537,7 @@ struct ShapeTupleObjTrait { } }; -TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait) +TVM_REGISTER_REFLECTION_VTABLE(ffi::ShapeObj, ShapeObjTrait) .set_creator([](const std::string& blob) { // Store shape tuple in blob to avoid large integer overflow in JSON. dmlc::MemoryStringStream mstrm(const_cast(&blob)); @@ -552,14 +547,14 @@ TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait) b64strm.Read(&size); std::vector data(size); b64strm.ReadArray(data.data(), size); - ShapeTuple shape(data); + ffi::Shape shape(data); return RefToObjectPtr::Get(shape); }) .set_repr_bytes([](const Object* n) -> std::string { std::string blob; dmlc::MemoryStringStream mstrm(&blob); support::Base64OutStream b64strm(&mstrm); - const auto* shape = static_cast(n); + const auto* shape = static_cast(n); b64strm.Write(shape->size); b64strm.WriteArray(shape->data, shape->size); b64strm.Finish(); @@ -569,7 +564,7 @@ TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait) struct MapObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduceForOMap(const MapObj* key, SHashReducer hash_reduce) { + static void SHashReduceForOMap(const ffi::MapObj* key, SHashReducer hash_reduce) { // SHash's var handling depends on the determinism of traversal. // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store @@ -605,7 +600,7 @@ struct MapObjTrait { } } - static void SHashReduceForSMap(const MapObj* key, SHashReducer hash_reduce) { + static void SHashReduceForSMap(const ffi::MapObj* key, SHashReducer hash_reduce) { // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store // Map where Var is defined in the function @@ -628,7 +623,7 @@ struct MapObjTrait { } } - static void SHashReduce(const MapObj* key, SHashReducer hash_reduce) { + static void SHashReduce(const ffi::MapObj* key, SHashReducer hash_reduce) { bool is_str_map = std::all_of(key->begin(), key->end(), [](const auto& v) { return v.first.template as(); }); @@ -639,7 +634,8 @@ struct MapObjTrait { } } - static bool SEqualReduceTraced(const MapObj* lhs, const MapObj* rhs, const SEqualReducer& equal) { + static bool SEqualReduceTraced(const ffi::MapObj* lhs, const ffi::MapObj* rhs, + const SEqualReducer& equal) { const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths(); // First, check that every key from `lhs` is also in `rhs`, // and their values are mapped to each other. @@ -678,7 +674,7 @@ struct MapObjTrait { TVM_FFI_UNREACHABLE(); } - static bool SEqualReduce(const MapObj* lhs, const MapObj* rhs, SEqualReducer equal) { + static bool SEqualReduce(const ffi::MapObj* lhs, const ffi::MapObj* rhs, SEqualReducer equal) { if (equal.IsPathTracingEnabled()) { return SEqualReduceTraced(lhs, rhs, equal); } @@ -699,8 +695,8 @@ struct MapObjTrait { return true; } }; -TVM_REGISTER_REFLECTION_VTABLE(MapObj, MapObjTrait) - .set_creator([](const std::string&) -> ObjectPtr { return MapObj::Empty(); }); +TVM_REGISTER_REFLECTION_VTABLE(ffi::MapObj, MapObjTrait) + .set_creator([](const std::string&) -> ObjectPtr { return ffi::MapObj::Empty(); }); struct ReportNodeTrait { static void VisitAttrs(runtime::profiling::ReportNode* report, AttrVisitor* attrs) { diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index a0ddb613d052..98122d1e1ec8 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -178,7 +178,7 @@ Optional FindImpureCall(const Expr& expr, const Optional& own_name) private: const Optional& own_name_; - Optional impure_expr_ = NullOpt; + Optional impure_expr_ = std::nullopt; }; if (own_name) { @@ -197,15 +197,15 @@ bool ContainsImpureCall(const Expr& expr, const Optional& own_name) { return FindImpureCall(expr, own_name).defined(); } -TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); -TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); -TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); -TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); -TVM_REGISTER_GLOBAL("relax.analysis.contains_impure_call").set_body_typed(ContainsImpureCall); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.contains_impure_call").set_body_typed(ContainsImpureCall); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 37bbf3a9775e..ba163b51d6c9 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -92,7 +92,7 @@ Array ComputableAtCompileTime(const Function& func) { return CompileTimeCollector::Collect(func); } -TVM_REGISTER_GLOBAL("relax.analysis.computable_at_compile_time") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.computable_at_compile_time") .set_body_typed(ComputableAtCompileTime); } // namespace relax diff --git a/src/relax/analysis/detect_recursion.cc b/src/relax/analysis/detect_recursion.cc index 9c150fed8bfd..48ec7880b172 100644 --- a/src/relax/analysis/detect_recursion.cc +++ b/src/relax/analysis/detect_recursion.cc @@ -392,7 +392,7 @@ tvm::Array> DetectRecursion(const IRModule& m) { return ret; } -TVM_REGISTER_GLOBAL("relax.analysis.detect_recursion").set_body_typed(DetectRecursion); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.detect_recursion").set_body_typed(DetectRecursion); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index f0658dabb398..ab32abab5bea 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -614,7 +614,7 @@ Map> SuggestLayoutTransforms( return analyzer.GetSuggestedTransforms(); } -TVM_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms")) +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms")) .set_body_typed([](PrimFunc fn, Array write_buffer_transformations) { return SuggestLayoutTransforms(fn, write_buffer_transformations); }); diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index d44252e86fd2..e09f061001f9 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -72,7 +72,7 @@ class StaticTypeDeriver : public StructInfoFunctor { Type GetStaticType(const StructInfo& info) { return StaticTypeDeriver()(info); } -TVM_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { +TVM_FFI_REGISTER_GLOBAL("relax.analysis.GetStaticType").set_body_typed([](const StructInfo& info) { return GetStaticType(info); }); @@ -270,7 +270,7 @@ StructInfo EraseToWellDefined(const StructInfo& info, Map sh f_shape_var_map = [&](const tir::Var& var) -> Optional { auto it = shape_var_map.find(var); if (it != shape_var_map.end()) return (*it).second; - return NullOpt; + return std::nullopt; }; } @@ -278,14 +278,14 @@ StructInfo EraseToWellDefined(const StructInfo& info, Map sh f_var_map = [&](const Var& var) -> Optional { auto it = var_map.find(var); if (it != var_map.end()) return (*it).second; - return NullOpt; + return std::nullopt; }; } return EraseToWellDefined(info, f_shape_var_map, f_var_map, ana); } -TVM_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.EraseToWellDefined") .set_body_typed([](const StructInfo& info, Map shape_var_map, Map var_map) { return EraseToWellDefined(info, shape_var_map, var_map); @@ -595,7 +595,7 @@ BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& de } } -TVM_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.StructInfoBaseCheck") .set_body_typed([](const StructInfo& base, const StructInfo& derived) -> int { return static_cast(StructInfoBaseCheck(base, derived)); }); @@ -604,7 +604,7 @@ bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; } -TVM_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") +TVM_FFI_REGISTER_GLOBAL("relax.StructInfoIsBaseOf") .set_body_typed([](const StructInfo& base, const StructInfo& derived) { return IsBaseOf(base, derived); }); @@ -955,7 +955,7 @@ StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call } } -TVM_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.DeriveCallRetStructInfo") .set_body_typed([](const FuncStructInfo& finfo, const Call& call, const BlockBuilder& ctx) { return DeriveCallRetStructInfo(finfo, call, ctx); }); @@ -1143,7 +1143,7 @@ class StructInfoLCAFinder Optional> UnifyArray(const Array& lhs, const Array& rhs) { if (lhs.same_as(rhs)) return lhs; - if (lhs.size() != rhs.size()) return NullOpt; + if (lhs.size() != rhs.size()) return std::nullopt; size_t index = 0; return lhs.Map([&](const StructInfo& a) { return this->VisitStructInfo(a, rhs[index++]); }); } @@ -1158,7 +1158,7 @@ StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::An } } -TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.StructInfoLCA") .set_body_typed([](const StructInfo& lhs, const StructInfo& rhs) { return StructInfoLCA(lhs, rhs); }); @@ -1241,9 +1241,9 @@ Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { return detector.GetTIRVars(); } -TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVarsInStructInfo); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVarsInStructInfo); -TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo") .set_body_typed(DefinableTIRVarsInStructInfo); class NonNegativeExpressionCollector : relax::StructInfoVisitor { @@ -1288,7 +1288,7 @@ Array CollectNonNegativeExpressions(const StructInfo& sinfo) { return NonNegativeExpressionCollector::Collect(sinfo); } -TVM_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions") .set_body_typed(CollectNonNegativeExpressions); class SymbolicVarCollector : public relax::ExprVisitor, @@ -1436,9 +1436,9 @@ Array DefinedSymbolicVars(const Expr& expr) { } Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } -TVM_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars); -TVM_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.FreeSymbolicVars").set_body_typed(FreeSymbolicVars); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index fe7b7bbeb547..0845ec092fe2 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -76,7 +76,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { // Step 1. Clear loads and store loads_.clear(); - store_ = NullOpt; + store_ = std::nullopt; // Step 2. Visit block body. StmtVisitor::VisitStmt(op->body); @@ -537,7 +537,7 @@ bool HasReshapePattern(const PrimFunc& func) { return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body); } -TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index b6b6be4b83e0..f62254b6959d 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -108,18 +108,17 @@ class UDChain : relax::ExprVisitor { } }; -std::pair>, runtime::Array> FunctionUseDef( - const Expr& fn) { +std::pair>, Array> FunctionUseDef(const Expr& fn) { auto usage = UDChain::Collect(fn); return {usage.downstream_usage, usage.outputs}; } -runtime::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { +Map> DataflowBlockUseDef(const DataflowBlock& dfb) { auto usage = UDChain::Collect(SeqExpr({dfb}, Tuple(Array()))); return usage.downstream_usage; } -TVM_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.udchain").set_body_typed(DataflowBlockUseDef); VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); } diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index be50e9bdcef2..a367d33ca4ff 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -25,7 +25,7 @@ namespace tvm { namespace relax { class Var2ValAnalysis : public relax::ExprVisitor { public: - tvm::runtime::Map var2value_; + Map var2value_; void VisitBinding_(const VarBindingNode* binding) override { var2value_.Set(binding->var, binding->value); // Recursively visit the value to handle local functions. @@ -33,19 +33,19 @@ class Var2ValAnalysis : public relax::ExprVisitor { } }; -tvm::runtime::Map AnalyzeVar2Value(const Expr& expr) { +Map AnalyzeVar2Value(const Expr& expr) { Var2ValAnalysis var2val_analysis; var2val_analysis.VisitExpr(expr); return std::move(var2val_analysis.var2value_); } -tvm::runtime::Map AnalyzeVar2Value(const DataflowBlock& dfb) { +Map AnalyzeVar2Value(const DataflowBlock& dfb) { Var2ValAnalysis var2val_analysis; var2val_analysis.VisitBindingBlock_(dfb.get()); return std::move(var2val_analysis.var2value_); } -tvm::runtime::Map AnalyzeVar2Value(const IRModule& m) { +Map AnalyzeVar2Value(const IRModule& m) { Var2ValAnalysis var2val_analysis; for (const auto& it : m->functions) { @@ -58,13 +58,13 @@ tvm::runtime::Map AnalyzeVar2Value(const IRModule& m) { return std::move(var2val_analysis.var2value_); } -TVM_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.get_var2val")).set_body_typed([](const Function& f) { return AnalyzeVar2Value(f); }); class Name2BindingAnalysis : public relax::ExprVisitor { public: - // runtime::Map is not suitable for doing in-place update. + // Map is not suitable for doing in-place update. // so we use standard container for internal usage. std::map> name2bindings_; void VisitBinding_(const VarBindingNode* binding) override { @@ -85,7 +85,7 @@ Map> NameToBinding(const Function& fn) { std::make_move_iterator(analysis.name2bindings_.end())); } -TVM_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.name_to_binding")).set_body_typed(NameToBinding); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 1ea15b38343d..243033e9454b 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -332,7 +332,7 @@ class WellFormedChecker : public relax::ExprVisitor, if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) { auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); Call before_normalize = GetRef(call); - Optional after_normalize = NullOpt; + Optional after_normalize = std::nullopt; try { after_normalize = func_normalize(dummy_builder, before_normalize); } catch (std::exception& err) { @@ -369,7 +369,7 @@ class WellFormedChecker : public relax::ExprVisitor, // an expression that does not yet have `StructInfo`. auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); Call copied(call->op, call->args, call->attrs, call->sinfo_args); - Optional normalized = NullOpt; + Optional normalized = std::nullopt; try { normalized = dummy_builder->Normalize(copied); } catch (std::exception& err) { @@ -646,7 +646,7 @@ bool WellFormed(Variant obj, bool check_struct_info) { return WellFormedChecker::Check(obj, check_struct_info); } -TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed(WellFormed); +TVM_FFI_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed(WellFormed); } // namespace relax } // namespace tvm diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 3c87079f99d5..ff64504f6111 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -322,7 +322,7 @@ Array OpenCLMLCompiler(Array functions, Map -#include +#include +#include #include #include @@ -130,14 +130,14 @@ class CodegenCBase { * return 0; * } * - * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); + * TVM_FFI_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); * * int foo_init_wrapper_(Array arr) { * foo_consts = arr; * return 0; * } * - * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); + * TVM_FFI_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); * * \endcode */ @@ -218,19 +218,19 @@ class CodegenCBase { if (!const_arr_name.empty()) { // If there are constants, insert the __init_ and the wrapper // This segment would be generated in C++ because of the usage - // of tvm::runtime::Array. This is not ideal, but this to demonstrate + // of tvm::Array. This is not ideal, but this to demonstrate // constant copying process used packed imports in other external // codegen. Moreover, in microTVM we dont expect this part to be generated. code_stream_ << "#ifdef __cplusplus\n"; code_stream_ << "int " << func_name - << "_init_wrapper_(tvm::runtime::Array arr) {\n"; + << "_init_wrapper_(tvm::Array arr) {\n"; EnterScope(); PrintIndents(); code_stream_ << func_name << "_consts = arr;\n"; code_stream_ << "return 0;\n"; ExitScope(); code_stream_ << "}\n\n"; - code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name + code_stream_ << "TVM_FFI_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name << "_init_wrapper_);\n\n"; code_stream_ << "#endif\n"; } @@ -393,7 +393,7 @@ class CodegenCBase { * \return The created declaration */ std::string CreateNDArrayPool(const std::string& symbol) const { - return "tvm::runtime::Array " + symbol + "_consts;"; + return "tvm::Array " + symbol + "_consts;"; } /*! diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index ce8ad2f8a7aa..f7df28bf716f 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -111,14 +111,14 @@ class OpAttrExtractor : public AttrVisitor { } void Visit(const char* key, runtime::ObjectRef* value) final { - if (const auto* an = (*value).as()) { + if (const auto* an = (*value).as()) { std::vector attr; for (size_t i = 0; i < an->size(); ++i) { if (const auto* im = (*an)[i].as()) { attr.push_back(std::to_string(im->value)); } else if (const auto* fm = (*an)[i].as()) { attr.push_back(Fp2String(fm->value)); - } else if (const auto* str = (*an)[i].as()) { + } else if (const auto* str = (*an)[i].as()) { String s = GetRef(str); attr.push_back(s); } else { @@ -132,7 +132,7 @@ class OpAttrExtractor : public AttrVisitor { SetNodeAttr(key, std::vector{std::to_string(im->value)}); } else if (const auto* fm = (*value).as()) { SetNodeAttr(key, std::vector{Fp2String(fm->value)}); - } else if (const auto* str = (*value).as()) { + } else if (const auto* str = (*value).as()) { String s = GetRef(str); SetNodeAttr(key, std::vector{s}); } else { diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 431265d1a760..085535b87f83 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -125,7 +125,7 @@ Array CublasCompiler(Array functions, Map cuDNNCompiler(Array functions, Map #include #include #include #include +#include #include #include @@ -59,7 +59,7 @@ runtime::Module Finalize(const std::string& code, const Array& func_name << "Should only create CUTLASS CSourceModule if there is at least one CUTLASS partition"; std::ostringstream default_headers; - default_headers << "#include \n"; + default_headers << "#include \n"; default_headers << "#include \n"; default_headers << "#include \n"; default_headers << "#include \n"; @@ -101,7 +101,7 @@ class CodegenResult : public ObjectRef { TVM_REGISTER_NODE_TYPE(CodegenResultNode); -TVM_REGISTER_GLOBAL("contrib.cutlass.CodegenResult") +TVM_FFI_REGISTER_GLOBAL("contrib.cutlass.CodegenResult") .set_body_typed([](String code, Array headers) { return CodegenResult(code, headers); }); @@ -204,7 +204,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, const auto* fn_var = call->op.as(); ICHECK(fn_var); const auto func = Downcast(bindings_[GetRef(fn_var)]); - const auto pattern_name_opt = func->GetAttr(attr::kComposite); + const auto pattern_name_opt = func->GetAttr(attr::kComposite); ICHECK(pattern_name_opt) << "Only composite function is supported for CUTLASS."; auto ret = GenerateBody(call, pattern_name_opt.value(), func->attrs->dict); ext_func_body_.push_back(ret.decl); @@ -385,7 +385,7 @@ Array CUTLASSCompiler(Array functions, Map DNNLCompiler(Array functions, Map HipblasCompiler(Array functions, Map()); } return compiled_functions; } -TVM_REGISTER_GLOBAL("relax.ext.hipblas").set_body_typed(HipblasCompiler); +TVM_FFI_REGISTER_GLOBAL("relax.ext.hipblas").set_body_typed(HipblasCompiler); } // namespace contrib } // namespace relax diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index f0517fa359e2..d9f26c417257 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -17,11 +17,11 @@ * under the License. */ +#include #include #include #include #include -#include #include #include @@ -264,7 +264,7 @@ Array NNAPICompiler(Array functions, Map TensorRTCompiler(Array functions, Map GetTensorRTVersion() { #endif // TVM_GRAPH_EXECUTOR_TENSORRT } -TVM_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled").set_body_typed(IsTensorRTRuntimeEnabled); -TVM_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); +TVM_FFI_REGISTER_GLOBAL("relax.is_tensorrt_runtime_enabled") + .set_body_typed(IsTensorRTRuntimeEnabled); +TVM_FFI_REGISTER_GLOBAL("relax.get_tensorrt_version").set_body_typed(GetTensorRTVersion); } // namespace contrib } // namespace relax diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index 8e214809dd51..6574ccc37a15 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -75,7 +75,7 @@ bool EndsWithPattern(const std::string& str, const std::string& pattern) { return str.compare(str.length() - pattern.length(), pattern.length(), pattern) == 0; } -TVM_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx); +TVM_FFI_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx); } // namespace backend } // namespace relax diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc index 34ebb4d6ddbf..840b44c12838 100644 --- a/src/relax/backend/pattern_registry.cc +++ b/src/relax/backend/pattern_registry.cc @@ -64,13 +64,14 @@ Optional GetPattern(const String& pattern_name) { return *it; } } - return NullOpt; + return std::nullopt; } -TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); -TVM_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns); -TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix); -TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern); +TVM_FFI_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns); +TVM_FFI_REGISTER_GLOBAL("relax.backend.RemovePatterns").set_body_typed(RemovePatterns); +TVM_FFI_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix") + .set_body_typed(GetPatternsWithPrefix); +TVM_FFI_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern); } // namespace backend } // namespace relax diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h index 72eea1238d38..2c1f385a2dda 100644 --- a/src/relax/backend/pattern_registry.h +++ b/src/relax/backend/pattern_registry.h @@ -26,10 +26,10 @@ #ifndef TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ #define TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_ +#include #include #include #include -#include #include namespace tvm { @@ -62,7 +62,7 @@ Array GetPatternsWithPrefix(const String& prefix); /*! * \brief Find the pattern with a particular name. * \param name The pattern name. - * \return The matched pattern. NullOpt if not found. + * \return The matched pattern. std::nullopt if not found. */ Optional GetPattern(const String& name); diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index af6e83cf5f9a..686d24de62b2 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -139,7 +139,7 @@ class TaskExtractor : public ExprVisitor { std::optional normalize_mod_func_; }; -TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") +TVM_FFI_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask") .set_body_typed([](IRModule mod, Target target, String mod_eq_name) { return TaskExtractor::ExtractTask(std::move(mod), std::move(target), std::move(mod_eq_name)); }); diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index b09602894577..f61579e25e96 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -225,7 +225,7 @@ class CodeGenVM : public ExprFunctor { LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; } } - return builder_->ConvertConstant(ShapeTuple(shape)); + return builder_->ConvertConstant(ffi::Shape(shape)); } Instruction::Arg VisitExpr_(const PrimValueNode* op) final { @@ -425,7 +425,7 @@ IRModule VMCodeGen(ExecBuilder exec_builder, IRModule mod) { return CodeGenVM::Run(exec_builder, mod); } -TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen); +TVM_FFI_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen); /*! * \brief Link the modules together, possibly create a constant module. @@ -490,7 +490,7 @@ Module VMLink(ExecBuilder builder, Target target, Optional lib, Array(const Expr&)> { if (dst_reg >= 0) { return RegListGet(dst_reg); } else { - return NullOpt; + return std::nullopt; } } @@ -291,7 +291,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; } } - return ConstListGet(builder_->ConvertConstant(ShapeTuple(shape)).value()); + return ConstListGet(builder_->ConvertConstant(ffi::Shape(shape)).value()); } Optional VisitExpr_(const PrimValueNode* op) final { return op->value; } @@ -356,13 +356,13 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { *kind = VMFuncInfo::FuncKind::kPackedFunc; return gvar->name_hint; } else { - return NullOpt; + return std::nullopt; } } // Lookup PrimFunc in the same module // We can do direct PrimFunc call in such cases Optional LookupPrimFunc(const String& name) { - if (!ctx_mod_->ContainGlobalVar(name)) return NullOpt; + if (!ctx_mod_->ContainGlobalVar(name)) return std::nullopt; GlobalVar gvar = ctx_mod_->GetGlobalVar(name); auto it = ctx_mod_->functions.find(gvar); @@ -372,7 +372,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { return GetRef(prim_func); } } - return NullOpt; + return std::nullopt; } Optional VisitExpr_(const GlobalVarNode* op) final { @@ -530,7 +530,7 @@ IRModule VMTIRCodeGen(ExecBuilder exec_builder, IRModule mod) { return CodeGenVMTIR::Run(exec_builder, mod); } -TVM_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen); +TVM_FFI_REGISTER_GLOBAL("relax.VMTIRCodeGen").set_body_typed(VMTIRCodeGen); } // namespace relax_vm } // namespace relax diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index 7bc5cafc0a9d..56b035212e3f 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -327,9 +327,9 @@ void ExecBuilderNode::Formalize() { } } -TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); -TVM_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ExecBuilder builder = args[0].cast(); ffi::Any rt; @@ -337,20 +337,21 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderConvertConstant") *ret = builder->ConvertConstant(rt).data(); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitFunction") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitFunction") .set_body_typed([](ExecBuilder builder, String func, int64_t num_inputs, Optional> param_names) { builder->EmitFunction(func, num_inputs, param_names); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEndFunction").set_body_method(&ExecBuilderNode::EndFunction); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEndFunction") + .set_body_method(&ExecBuilderNode::EndFunction); -TVM_REGISTER_GLOBAL("relax.ExecBuilderDeclareFunction") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderDeclareFunction") .set_body_typed([](ExecBuilder builder, String name, int32_t kind) { builder->DeclareFunction(name, static_cast(kind)); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { std::vector args_; for (size_t i = 0; i < args.size(); ++i) { @@ -360,35 +361,38 @@ TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") builder->EmitCall(name, args_, dst_.value()); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") .set_body_typed([](ExecBuilder builder, int64_t data) { builder->EmitRet(Instruction::Arg::FromData(data)); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto").set_body_method(&ExecBuilderNode::EmitGoto); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitGoto").set_body_method(&ExecBuilderNode::EmitGoto); -TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderEmitIf") .set_body_typed([](ExecBuilder builder, int64_t data, vm::Index false_offset) { builder->EmitIf(Instruction::Arg::FromData(data), false_offset); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderR").set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::Register(value).data(); -}); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderR") + .set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Register(value).data(); + }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderImm").set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::Immediate(value).data(); -}); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderImm") + .set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::Immediate(value).data(); + }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, int64_t value) { - return Instruction::Arg::ConstIdx(value).data(); -}); +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderC") + .set_body_typed([](ExecBuilder builder, int64_t value) { + return Instruction::Arg::ConstIdx(value).data(); + }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, String value) { +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder, String value) { return builder->GetFunction(value).data(); }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { +TVM_FFI_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { ObjectPtr p_exec = builder->Get(); return runtime::Module(p_exec); }); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index f746e4a5afd2..7757195bcb1d 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -231,7 +231,7 @@ Pass LowerRuntimeBuiltin() { return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {}); } -TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin); } // namespace transform } // namespace relax diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index e930ea46cf09..0b60553034fe 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -288,7 +288,7 @@ class VMShapeLowerMutator auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body)); - current_gvar_ = NullOpt; + current_gvar_ = std::nullopt; // create a new function return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs); @@ -778,7 +778,7 @@ class VMShapeLowerMutator std::vector> slot_vec_; /*! \brief Expr => slot. */ PrimExprSlotMap slot_map_; - Optional current_gvar_ = NullOpt; + Optional current_gvar_ = std::nullopt; /*! * \brief List of vars that are being defined but * have not go through outstanding shape compute check. @@ -813,7 +813,7 @@ Pass VMShapeLower(bool emit_err_ctx) { return CreateModulePass(pass_func, 0, "VMShapeLower", {}); } -TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { +TVM_FFI_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool emit_err_ctx) { return VMShapeLower(emit_err_ctx); }); diff --git a/src/relax/distributed/global_info.cc b/src/relax/distributed/global_info.cc index d9f0cfdfd71a..e1cc32fc82e3 100644 --- a/src/relax/distributed/global_info.cc +++ b/src/relax/distributed/global_info.cc @@ -23,7 +23,7 @@ namespace tvm { namespace relax { namespace distributed { -DeviceMesh::DeviceMesh(ShapeTuple shape, Array device_ids) { +DeviceMesh::DeviceMesh(ffi::Shape shape, Array device_ids) { int prod = 1; for (int i = 0; i < static_cast(shape.size()); i++) { prod *= shape[i]; @@ -36,7 +36,7 @@ DeviceMesh::DeviceMesh(ShapeTuple shape, Array device_ids) { data_ = std::move(n); } -DeviceMesh::DeviceMesh(ShapeTuple shape, Range device_range) { +DeviceMesh::DeviceMesh(ffi::Shape shape, Range device_range) { ObjectPtr n = make_object(); Array device_ids; int range_start = device_range->min.as()->value; @@ -57,8 +57,8 @@ DeviceMesh::DeviceMesh(ShapeTuple shape, Range device_range) { } TVM_REGISTER_NODE_TYPE(DeviceMeshNode); -TVM_REGISTER_GLOBAL("relax.distributed.DeviceMesh") - .set_body_typed([](ShapeTuple shape, Array device_ids, Optional device_range) { +TVM_FFI_REGISTER_GLOBAL("relax.distributed.DeviceMesh") + .set_body_typed([](ffi::Shape shape, Array device_ids, Optional device_range) { if (device_range.defined()) return DeviceMesh(shape, device_range.value()); else diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 3569b1538551..0ff9d4d6fa09 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -43,11 +43,11 @@ PlacementSpec PlacementSpec::Replica() { TVM_REGISTER_NODE_TYPE(PlacementSpecNode); -TVM_REGISTER_GLOBAL("relax.distributed.Sharding").set_body_typed([](int axis) { +TVM_FFI_REGISTER_GLOBAL("relax.distributed.Sharding").set_body_typed([](int axis) { return PlacementSpec::Sharding(axis); }); -TVM_REGISTER_GLOBAL("relax.distributed.Replica").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("relax.distributed.Replica").set_body_typed([]() { return PlacementSpec::Replica(); }); @@ -106,8 +106,8 @@ Placement Placement::FromText(String text_repr) { } TVM_REGISTER_NODE_TYPE(PlacementNode); -TVM_REGISTER_GLOBAL("relax.distributed.PlacementFromText").set_body_typed(Placement::FromText); -TVM_REGISTER_GLOBAL("relax.distributed.Placement") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.PlacementFromText").set_body_typed(Placement::FromText); +TVM_FFI_REGISTER_GLOBAL("relax.distributed.Placement") .set_body_typed([](Array dim_specs) { return Placement(dim_specs); }); // DTensor @@ -130,7 +130,7 @@ DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh d TVM_REGISTER_NODE_TYPE(DTensorStructInfoNode); -TVM_REGISTER_GLOBAL("relax.distributed.DTensorStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.DTensorStructInfo") .set_body_typed([](TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span) { return DTensorStructInfo(tensor_sinfo, device_mesh, placement, span); diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 5ed947858775..1df1d2110ba9 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -115,7 +115,7 @@ Pass LegalizeRedistribute() { }; return CreateModulePass(pass_func, 1, "LegalizeRedistribute", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.LegalizeRedistribute") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LegalizeRedistribute") .set_body_typed(LegalizeRedistribute); } // namespace transform diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 7c729e443837..e4f811b83d42 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -61,7 +61,7 @@ class DistIRSharder : public ExprMutator { } ShapeExpr ShardShape(ShapeExpr orig_shape, DeviceMesh device_mesh, Placement placement) { - ShapeTuple device_mesh_shape = device_mesh->shape; + ffi::Shape device_mesh_shape = device_mesh->shape; Array new_tensor_shape_value = orig_shape->values; for (int i = 0; i < static_cast(device_mesh_shape.size()); i++) { if (placement->dim_specs[i]->kind == PlacementSpecKind::kSharding) { @@ -178,7 +178,7 @@ class DistIRSharder : public ExprMutator { func_ = func; new_params_ = new_params; auto new_body = VisitWithNewScope(func->body, new_params); - Function new_func(new_params, new_body, NullOpt, func->is_pure, func->attrs); + Function new_func(new_params, new_body, std::nullopt, func->is_pure, func->attrs); return new_func; } @@ -262,7 +262,7 @@ Pass LowerDistIR() { auto pass_func = [=](IRModule m, PassContext pc) { return DistIRSharder::LowerDistIR(m); }; return CreateModulePass(pass_func, 1, "LowerDistIR", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.LowerDistIR").set_body_typed(LowerDistIR); +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LowerDistIR").set_body_typed(LowerDistIR); } // namespace transform } // namespace distributed diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 514a98ef44f3..c8abe2b1d1b5 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -432,7 +432,7 @@ Pass LowerGlobalViewToLocalView() { auto pass_func = [=](IRModule m, PassContext pc) { return LowerTIRToLocalView(m).Lower(); }; return CreateModulePass(pass_func, 1, "LowerGlobalViewToLocalView", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.LowerGlobalViewToLocalView") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.LowerGlobalViewToLocalView") .set_body_typed(LowerGlobalViewToLocalView); } // namespace transform diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index dfcdde5c9eb4..f5f276c2b873 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -429,7 +429,7 @@ class DistributedIRBuilder : public ExprMutator { } } auto new_body = VisitWithNewScope(func->body, new_params); - Function new_func(new_params, new_body, NullOpt, func->is_pure, func->attrs); + Function new_func(new_params, new_body, std::nullopt, func->is_pure, func->attrs); return new_func; } @@ -615,7 +615,7 @@ Pass PropagateSharding() { }; return CreateModulePass(pass_func, 1, "PropagateSharding", {}); } -TVM_REGISTER_GLOBAL("relax.distributed.transform.PropagateSharding") +TVM_FFI_REGISTER_GLOBAL("relax.distributed.transform.PropagateSharding") .set_body_typed(PropagateSharding); } // namespace transform diff --git a/src/relax/distributed/transform/utils.h b/src/relax/distributed/transform/utils.h index 26ce6530116f..2680c892695c 100644 --- a/src/relax/distributed/transform/utils.h +++ b/src/relax/distributed/transform/utils.h @@ -40,7 +40,7 @@ inline Optional MatchPrimFunc(const IRModule& mod_, const Expr& o if (auto* pfunc = base_func.as()) { return GetRef(pfunc); } - return NullOpt; + return std::nullopt; } /*! * \brief Check whether the given struct infos can appear in DistIR diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index 1aee7bef780a..f35b443b5b39 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -51,7 +51,7 @@ DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowBlockRewrite") +TVM_FFI_REGISTER_GLOBAL("relax.DataflowBlockRewrite") .set_body_typed([](DataflowBlock dfb, Function root_fn) { return DataflowBlockRewrite(dfb, root_fn); }); @@ -110,7 +110,7 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { } } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_replace_all_uses") .set_body_typed([](DataflowBlockRewrite rwt, Var old_var, Var new_var) { rwt->ReplaceAllUses(old_var, new_var); }); @@ -178,10 +178,10 @@ void DataflowBlockRewriteNode::Add(Binding binding) { } } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_add_binding") .set_body_typed([](DataflowBlockRewrite rwt, Binding vb) { rwt->Add(vb); }); -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_add") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_add") .set_body_typed([](DataflowBlockRewrite rwt, Expr expr, Optional name, bool is_dfvar) { if (name.has_value()) { rwt->Add(name.value(), expr, is_dfvar); @@ -235,7 +235,7 @@ std::set GetUnusedVars(Map> users_map, Array fn_output class RemoveUnusedVars : public ExprMutator { public: std::set unused_vars; - Optional caught_rewrite = NullOpt; + Optional caught_rewrite = std::nullopt; RemoveUnusedVars(std::set unused_vars) : unused_vars(std::move(unused_vars)) {} @@ -292,7 +292,7 @@ void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { to_users_.erase(unused); // update use-def chain. } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_remove_unused") .set_body_typed([](DataflowBlockRewrite rwt, Var unused, bool allow_undef) { rwt->RemoveUnused(unused, allow_undef); }); @@ -314,7 +314,7 @@ void DataflowBlockRewriteNode::RemoveAllUnused() { for (const auto& unused : remover.unused_vars) to_users_.erase(unused); } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused") .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); }); Expr RemoveAllUnused(Expr expr) { @@ -333,7 +333,7 @@ Expr RemoveAllUnused(Expr expr) { return remover.VisitExpr(std::move(expr)); } -TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); +TVM_FFI_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused); IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { BlockBuilder builder = BlockBuilder::Create(irmod); @@ -348,7 +348,7 @@ IRModule DataflowBlockRewriteNode::MutateIRModule(IRModule irmod) { return builder->GetContextIRModule(); } -TVM_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") +TVM_FFI_REGISTER_GLOBAL("relax.dfb_rewrite_mutate_irmodule") .set_body_typed([](DataflowBlockRewrite rwt, IRModule irmod) { return rwt->MutateIRModule(irmod); }); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 5f56e7cf453d..63288201e741 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -21,6 +21,7 @@ * \file src/relax/block_builder.cc */ #include +#include #include #include #include @@ -29,7 +30,6 @@ #include #include #include -#include #include #include @@ -162,7 +162,7 @@ class BlockBuilderImpl : public BlockBuilderNode { //------------------------------- Optional LookupBinding(const Var& var) final { auto it = binding_table_.find(var->vid); - if (it == binding_table_.end()) return NullOpt; + if (it == binding_table_.end()) return std::nullopt; return it->second; } @@ -418,8 +418,8 @@ class BlockBuilderImpl : public BlockBuilderNode { name_hint = is_dataflow ? "lv" : "gv"; } Id vid = Id(GetUniqueName(name_hint)); - return is_dataflow ? DataflowVar(vid, /*struct_info_annotation=*/NullOpt) - : Var(vid, /*struct_info_annotation=*/NullOpt); + return is_dataflow ? DataflowVar(vid, /*struct_info_annotation=*/std::nullopt) + : Var(vid, /*struct_info_annotation=*/std::nullopt); } private: @@ -866,12 +866,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor Optional { auto it = curr_scope->shape_var_map.find(var); if (it != curr_scope->shape_var_map.end()) return (*it).second; - return NullOpt; + return std::nullopt; }; return EraseToWellDefined(info, f_shape_var_map); } - Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + Expr VisitWithNewScope(const Expr& expr, Optional> params = std::nullopt) { if (params.defined()) { this->BeginScope(params.value()); } else { @@ -1054,65 +1054,67 @@ BlockBuilder BlockBuilder::Create(Optional mod, //--------------------------------------- TVM_REGISTER_OBJECT_TYPE(BlockBuilderNode); -TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed([](Optional mod) { return BlockBuilder::Create(mod); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginDataflowBlock") .set_body_method(&BlockBuilderNode::BeginDataflowBlock); -TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginBindingBlock") .set_body_method(&BlockBuilderNode::BeginBindingBlock); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock").set_body_method(&BlockBuilderNode::EndBlock); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEndBlock").set_body_method(&BlockBuilderNode::EndBlock); -TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize").set_body_method(&BlockBuilderNode::Normalize); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderNormalize") + .set_body_method(&BlockBuilderNode::Normalize); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmit") .set_body_typed([](BlockBuilder builder, Expr expr, String name_hint) { return builder->Emit(expr, name_hint); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) { return builder->EmitMatchCast(value, struct_info, name_hint); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") .set_body_typed([](BlockBuilder builder, const Expr& output, String name_hint) { return builder->EmitOutput(output, name_hint); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEmitNormalized") .set_body_typed([](BlockBuilder builder, Binding binding) { return builder->EmitNormalized(binding); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName") .set_body_typed([](BlockBuilder builder, String name_hint) { return builder->name_supply()->FreshName(name_hint, /*add_prefix*/ false, /*add_underscore*/ false); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderAddFunction") .set_body_method(&BlockBuilderNode::AddFunction); -TVM_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderUpdateFunction") .set_body_method(&BlockBuilderNode::UpdateFunction); -TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule") .set_body_method(&BlockBuilderNode::GetContextIRModule); -TVM_REGISTER_GLOBAL("relax.BlockBuilderFinalize").set_body_method(&BlockBuilderNode::Finalize); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderFinalize").set_body_method(&BlockBuilderNode::Finalize); -TVM_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderCurrentBlockIsDataFlow") .set_body_method(&BlockBuilderNode::CurrentBlockIsDataFlow); -TVM_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderLookupBinding") .set_body_method(&BlockBuilderNode::LookupBinding); -TVM_REGISTER_GLOBAL("relax.BlockBuilderBeginScope").set_body_method(&BlockBuilderNode::BeginScope); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderBeginScope") + .set_body_method(&BlockBuilderNode::BeginScope); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEndScope").set_body_method(&BlockBuilderNode::EndScope); +TVM_FFI_REGISTER_GLOBAL("relax.BlockBuilderEndScope").set_body_method(&BlockBuilderNode::EndScope); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 661c43842db4..172f4d7bcb27 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -201,7 +201,7 @@ static std::optional TryValidate( if (auto ptr = current_match.matched(p_node)) { return GetRef(ptr); } else { - return NullOpt; + return std::nullopt; } }; @@ -340,14 +340,14 @@ Optional> MatchGraph(const PatternContext& ctx, } if (roots.empty()) { - return NullOpt; + return std::nullopt; } arith::Analyzer analyzer; auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, ctx->validation_constraints, ud_analysis, &analyzer); if (!match) { - return NullOpt; + return std::nullopt; } Map ret; @@ -362,7 +362,7 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); } -TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.match_dfb") .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { return MatchGraph(ctx, dfb); }); @@ -397,7 +397,7 @@ class PatternContextRewriterNode : public PatternMatchingRewriterNode { } } - return NullOpt; + return std::nullopt; } }; @@ -447,7 +447,7 @@ Function RewriteBindings( return Downcast(PatternContextRewriter(ctx, rewriter)(func)); } -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index da1614f50b47..c398305d938c 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -193,17 +193,16 @@ void RewriteSpec::Append(RewriteSpec other) { TVM_REGISTER_NODE_TYPE(PatternMatchingRewriterNode); -TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") .set_body_typed([](DFPattern pattern, ffi::TypedFunction(Expr, Map)> func) { return PatternMatchingRewriter::FromPattern(pattern, func); }); -TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule").set_body_typed([](IRModule mod) { - return PatternMatchingRewriter::FromModule(mod); -}); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule") + .set_body_typed([](IRModule mod) { return PatternMatchingRewriter::FromModule(mod); }); -TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") .set_body_typed([](PatternMatchingRewriter rewriter, Variant obj) -> Variant { if (auto expr = obj.as()) { @@ -256,10 +255,10 @@ Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, return rewritten_expr.value(); } } - return NullOpt; + return std::nullopt; } -TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternRewriter") .set_body_typed([](DFPattern pattern, ffi::TypedFunction(Expr, Map)> func) { return ExprPatternRewriter(pattern, func); @@ -308,7 +307,7 @@ RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) cons return lhs_match; } -TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.OrRewriter") .set_body_typed([](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { return OrRewriter(lhs, rhs); }); @@ -603,7 +602,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( return rewrites; } -TVM_REGISTER_GLOBAL("relax.dpl.TupleRewriter") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TupleRewriter") .set_body_typed([](Array patterns, ffi::TypedFunction(Expr, Map)> func) { return TupleRewriter(patterns, func); @@ -780,7 +779,8 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { return SeqExpr(new_blocks, func_replacement->body->body); }; - return PatternMatchingRewriter::FromPattern(top_pattern, rewriter_func, NullOpt, new_subroutines); + return PatternMatchingRewriter::FromPattern(top_pattern, rewriter_func, std::nullopt, + new_subroutines); } Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, @@ -789,19 +789,19 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, DFPatternMatcher matcher(bindings); if (!matcher.Match(pattern, expr)) { - return NullOpt; + return std::nullopt; } return matcher.GetMemo(); } -TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); } -TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); /*! * \brief Apply pattern matching to each expression, replacing @@ -857,7 +857,7 @@ class PatternMatchingMutator : public ExprMutator { // If the SeqExpr's output is not a variable, treat it as if it // were the last variable binding of the last block. This // simplifies the special handling of the SeqExpr's body. - Optional dummy_output_var = NullOpt; + Optional dummy_output_var = std::nullopt; if (!seq->body->IsInstance()) { dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body)); VarBinding dummy_binding(dummy_output_var.value(), seq->body); @@ -991,7 +991,7 @@ class PatternMatchingMutator : public ExprMutator { auto new_blocks = old_blocks.Map(visit_block); if (old_blocks.same_as(new_blocks)) { - return NullOpt; + return std::nullopt; } // Restore the body of the SeqExpr, if needed. @@ -1073,7 +1073,7 @@ Function RewriteCall(const DFPattern& pat, return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func)); } -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 1176f1eaee7e..8a4ca3f7ba0a 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -80,7 +80,7 @@ Expr DFPatternMatcher::UnwrapBindings(Expr expr, const Map& var2val) } } - return NullOpt; + return std::nullopt; }; while (auto unwrapped = unwrap(expr)) { @@ -489,7 +489,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( } else if (auto shape_expr = sinfo.as()) { return shape_expr->values; } else { - return NullOpt; + return std::nullopt; } }(); @@ -524,8 +524,8 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( } else { // Missing an argument, so the constraint will either return - // NullOpt or false at this point. However, delay the return of - // NullOpt until the end of the function, because we'd rather + // std::nullopt or false at this point. However, delay the return of + // std::nullopt until the end of the function, because we'd rather // return "false" if it possible to do so. all_shapes_defined = false; } diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index c5d58db5b9d0..76f48383c47c 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -38,7 +38,7 @@ namespace relax { class DFPatternMatcher : public DFPatternFunctor { public: - using var2val_t = runtime::Map; + using var2val_t = Map; explicit DFPatternMatcher() {} explicit DFPatternMatcher(var2val_t var2val) : var2val_(std::move(var2val)) {} diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index f7414faaabb6..db242b773be6 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -44,7 +44,7 @@ ExternFuncPattern::ExternFuncPattern(String global_symbol) { n->global_symbol_ = std::move(global_symbol); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ExternFuncPattern").set_body_typed([](String global_symbol) { return ExternFuncPattern(global_symbol); }); RELAX_PATTERN_PRINTER_DEF(ExternFuncPatternNode, [](auto p, auto node) { @@ -57,7 +57,7 @@ VarPattern::VarPattern(String name_hint) { n->name = std::move(name_hint); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.VarPattern").set_body_typed([](String name_hint) { return VarPattern(name_hint); }); RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { @@ -65,7 +65,7 @@ RELAX_PATTERN_PRINTER_DEF(VarPatternNode, [](auto p, auto node) { }); TVM_REGISTER_NODE_TYPE(DataflowVarPatternNode); -TVM_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.DataflowVarPattern").set_body_typed([](String name_hint) { return DataflowVarPattern(name_hint); }); DataflowVarPattern::DataflowVarPattern(String name_hint) { @@ -83,7 +83,7 @@ GlobalVarPattern::GlobalVarPattern(String name_hint) { n->name = std::move(name_hint); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.GlobalVarPattern").set_body_typed([](String name_hint) { return GlobalVarPattern(name_hint); }); RELAX_PATTERN_PRINTER_DEF(GlobalVarPatternNode, [](auto p, auto node) { @@ -96,11 +96,13 @@ ExprPattern::ExprPattern(Expr expr) { n->expr = std::move(expr); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { return ExprPattern(e); }); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ExprPattern").set_body_typed([](Expr e) { + return ExprPattern(e); +}); RELAX_PATTERN_PRINTER_DEF(ExprPatternNode, [](auto p, auto node) { p->Print(node->expr); }); TVM_REGISTER_NODE_TYPE(ConstantPatternNode); -TVM_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ConstantPattern").set_body_typed([]() { auto c = ConstantPattern(make_object()); return c; }); @@ -115,7 +117,7 @@ CallPattern::CallPattern(DFPattern op, Array args, bool varg_default_ n->varg_default_wildcard = varg_default_wildcard; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.CallPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.CallPattern") .set_body_typed([](DFPattern op, Array args, bool varg_default_wildcard) { return CallPattern(op, args, varg_default_wildcard); }); @@ -138,7 +140,7 @@ PrimArrPattern::PrimArrPattern(Array arr) { n->fields = std::move(arr); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PrimArrPattern").set_body_typed([](Array arr) { return PrimArrPattern(std::move(arr)); }); RELAX_PATTERN_PRINTER_DEF(PrimArrPatternNode, [](auto p, auto node) { @@ -152,7 +154,7 @@ FunctionPattern::FunctionPattern(Array params, DFPattern body) { n->body = std::move(body); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.FunctionPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.FunctionPattern") .set_body_typed([](Array params, DFPattern body) { return FunctionPattern(params, body); }); @@ -166,7 +168,7 @@ TuplePattern::TuplePattern(tvm::Array fields) { n->fields = std::move(fields); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TuplePattern").set_body_typed([](tvm::Array fields) { return TuplePattern(fields); }); RELAX_PATTERN_PRINTER_DEF(TuplePatternNode, [](auto p, auto node) { @@ -179,7 +181,7 @@ UnorderedTuplePattern::UnorderedTuplePattern(tvm::Array fields) { n->fields = std::move(fields); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.UnorderedTuplePattern") .set_body_typed([](tvm::Array fields) { return UnorderedTuplePattern(fields); }); RELAX_PATTERN_PRINTER_DEF(UnorderedTuplePatternNode, [](auto p, auto node) { p->stream << "UnorderedTuplePattern(" << node->fields << ")"; @@ -192,9 +194,8 @@ TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { n->index = index; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern").set_body_typed([](DFPattern tuple, int index) { - return TupleGetItemPattern(tuple, index); -}); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TupleGetItemPattern") + .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); RELAX_PATTERN_PRINTER_DEF(TupleGetItemPatternNode, [](auto p, auto node) { p->stream << "TupleGetItemPattern(" << node->tuple << ", " << node->index << ")"; }); @@ -206,7 +207,7 @@ AndPattern::AndPattern(DFPattern left, DFPattern right) { n->right = std::move(right); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.AndPattern").set_body_typed([](DFPattern left, DFPattern right) { return AndPattern(left, right); }); RELAX_PATTERN_PRINTER_DEF(AndPatternNode, [](auto p, auto node) { @@ -220,7 +221,7 @@ OrPattern::OrPattern(DFPattern left, DFPattern right) { n->right = std::move(right); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.OrPattern").set_body_typed([](DFPattern left, DFPattern right) { return OrPattern(left, right); }); RELAX_PATTERN_PRINTER_DEF(OrPatternNode, [](auto p, auto node) { @@ -233,7 +234,7 @@ NotPattern::NotPattern(DFPattern reject) { n->reject = std::move(reject); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.NotPattern").set_body_typed([](DFPattern reject) { return NotPattern(reject); }); RELAX_PATTERN_PRINTER_DEF(NotPatternNode, @@ -241,7 +242,9 @@ RELAX_PATTERN_PRINTER_DEF(NotPatternNode, TVM_REGISTER_NODE_TYPE(WildcardPatternNode); WildcardPattern::WildcardPattern() { data_ = make_object(); } -TVM_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { return WildcardPattern(); }); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.WildcardPattern").set_body_typed([]() { + return WildcardPattern(); +}); RELAX_PATTERN_PRINTER_DEF(WildcardPatternNode, [](auto p, auto node) { p->stream << "*"; }); TVM_REGISTER_NODE_TYPE(TypePatternNode); @@ -251,7 +254,7 @@ TypePattern::TypePattern(DFPattern pattern, Type type) { n->type = std::move(type); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.TypePattern").set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); }); RELAX_PATTERN_PRINTER_DEF(TypePatternNode, [](auto p, auto node) { @@ -265,7 +268,7 @@ StructInfoPattern::StructInfoPattern(DFPattern pattern, StructInfo struct_info) n->struct_info = std::move(struct_info); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.StructInfoPattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.StructInfoPattern") .set_body_typed([](DFPattern pattern, StructInfo struct_info) { return StructInfoPattern(pattern, struct_info); }); @@ -281,7 +284,7 @@ ShapePattern::ShapePattern(DFPattern pattern, Array shape) { n->shape = std::move(shape); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.ShapePattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.ShapePattern") .set_body_typed([](DFPattern pattern, Array shape) { return ShapePattern(pattern, shape); }); @@ -299,7 +302,7 @@ SameShapeConstraint::SameShapeConstraint(Array args) { ctx.value().add_constraint(*this); } } -TVM_REGISTER_GLOBAL("relax.dpl.SameShapeConstraint").set_body_typed([](Array args) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.SameShapeConstraint").set_body_typed([](Array args) { return SameShapeConstraint(args); }); RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { @@ -320,7 +323,7 @@ DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { n->dtype = std::move(dtype); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.DataTypePattern") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.DataTypePattern") .set_body_typed([](DFPattern pattern, DataType dtype) { return DataTypePattern(pattern, dtype); }); @@ -335,9 +338,8 @@ AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { n->attrs = std::move(attrs); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.dpl.AttrPattern").set_body_typed([](DFPattern pattern, DictAttrs attrs) { - return AttrPattern(pattern, attrs); -}); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.AttrPattern") + .set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); }); RELAX_PATTERN_PRINTER_DEF(AttrPatternNode, [](auto p, auto node) { p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; }); @@ -424,7 +426,7 @@ DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { return DataTypePattern(*this, dtype); } DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { - return HasDtype(DataType(runtime::StringToDLDataType(dtype))); + return HasDtype(DataType(ffi::StringToDLDataType(dtype))); } ShapePattern DFPattern::HasShape(const Array& shape) const { return ShapePattern(*this, shape); @@ -438,7 +440,7 @@ std::stack& pattern_ctx_stack() { } Optional PatternContext::Current() { - if (pattern_ctx_stack().empty()) return NullOpt; + if (pattern_ctx_stack().empty()) return std::nullopt; return pattern_ctx_stack().top(); } @@ -511,7 +513,7 @@ PatternSeq PatternSeq::dup() const { return ret; } -TVM_REGISTER_GLOBAL("relax.dpl.PatternSeq") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternSeq") .set_body_typed([](Array patterns, bool only_used_by) { return PatternSeq(std::move(patterns), only_used_by); }); @@ -525,12 +527,12 @@ RELAX_PATTERN_PRINTER_DEF(PatternSeqNode, [](auto p, auto node) { p->stream << "]"; }); -TVM_REGISTER_GLOBAL("relax.dpl.used_by") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.used_by") .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.UsedBy(rhs, index); }); -TVM_REGISTER_GLOBAL("relax.dpl.only_used_by") +TVM_FFI_REGISTER_GLOBAL("relax.dpl.only_used_by") .set_body_typed([](PatternSeq lhs, PatternSeq rhs, int index) { return lhs.OnlyUsedBy(rhs, index); }); @@ -643,25 +645,27 @@ DFPattern DFPattern::dup() const { return pattern; } -TVM_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.dup_pattern").set_body_typed([](DFPattern pattern) { return pattern.dup(); }); -TVM_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { return seq.dup(); }); +TVM_FFI_REGISTER_GLOBAL("relax.dpl.dup_seq").set_body_typed([](PatternSeq seq) { + return seq.dup(); +}); -TVM_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.PatternContext").set_body_typed([](bool incre) { return PatternContext(incre); }); -TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { return PatternContext::Current(); }); -TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const PatternContext& ctx) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const PatternContext& ctx) { ctx.EnterWithScope(); }); -TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const PatternContext& ctx) { +TVM_FFI_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const PatternContext& ctx) { ctx.ExitWithScope(); }); diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h index 4eec98373d0c..d2016adbf8e7 100644 --- a/src/relax/ir/dataflow_rewriter.h +++ b/src/relax/ir/dataflow_rewriter.h @@ -67,7 +67,7 @@ class PatternMatchingRewriter : public tvm::transform::Pass { public: static PatternMatchingRewriter FromPattern( DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = NullOpt, + Optional> additional_bindings = std::nullopt, Map new_subroutines = {}); static PatternMatchingRewriter FromModule(IRModule mod); @@ -103,7 +103,7 @@ class ExprPatternRewriter : public PatternMatchingRewriter { public: ExprPatternRewriter(DFPattern pattern, ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = NullOpt, + Optional> additional_bindings = std::nullopt, Map new_subroutines = {}); TVM_DEFINE_OBJECT_REF_METHODS(ExprPatternRewriter, PatternMatchingRewriter, @@ -170,7 +170,7 @@ class TupleRewriter : public PatternMatchingRewriter { public: TupleRewriter(Array patterns, ffi::TypedFunction(Expr, Map)> func, - Optional> additional_bindings = NullOpt, + Optional> additional_bindings = std::nullopt, Map new_subroutines = {}); TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, PatternMatchingRewriter, TupleRewriterNode); diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index bfb5896c9988..e75dc3c2d7ca 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -49,7 +49,7 @@ te::Tensor TETensor(Expr value, Map tir_var_map, std::string n->dtype = DataType(constant->data->dtype); int ndim = constant->data->ndim; - ShapeTuple shape_tuple = constant->data.Shape(); + ffi::Shape shape_tuple = constant->data.Shape(); Array shape; shape.reserve(ndim); for (int i = 0; i < ndim; ++i) { @@ -72,7 +72,7 @@ te::Tensor TETensor(Expr value, Map tir_var_map, std::string return te::PlaceholderOp(n).output(0); } -TVM_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); +TVM_FFI_REGISTER_GLOBAL("relax.TETensor").set_body_typed(TETensor); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 4f8a1b6c650d..238cece41f61 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -27,7 +27,6 @@ namespace tvm { namespace relax { using tvm::ReprPrinter; -using tvm::runtime::Optional; TVM_REGISTER_NODE_TYPE(IdNode); @@ -98,7 +97,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_GLOBAL("relax.Call") +TVM_FFI_REGISTER_GLOBAL("relax.Call") .set_body_typed([](Expr op, Array args, Attrs attrs, Array sinfo_args, Span span) { return Call(op, args, attrs, sinfo_args, span); }); @@ -133,7 +132,7 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc TVM_REGISTER_NODE_TYPE(IfNode); -TVM_REGISTER_GLOBAL("relax.If") +TVM_FFI_REGISTER_GLOBAL("relax.If") .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span span) { return If(cond, true_branch, false_branch, span); }); @@ -145,7 +144,7 @@ Tuple::Tuple(tvm::Array fields, Span span) { if (field->struct_info_.defined()) { field_sinfo.push_back(GetStructInfo(field)); } else { - return NullOpt; + return std::nullopt; } } return TupleStructInfo(field_sinfo); @@ -163,7 +162,7 @@ Tuple::Tuple(tvm::Array fields, Span span) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array fields, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); @@ -227,7 +226,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { return TupleGetItem(tuple, index, span); }); @@ -250,7 +249,7 @@ ShapeExpr::ShapeExpr(Array values, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array values, Span span) { return ShapeExpr(values, span); }); @@ -286,12 +285,12 @@ VarNode* Var::CopyOnWrite() { return static_cast(data_.get()); } -TVM_REGISTER_GLOBAL("relax.Var") +TVM_FFI_REGISTER_GLOBAL("relax.Var") .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { return Var(name_hint, struct_info_annotation, span); }); -TVM_REGISTER_GLOBAL("relax.VarFromId") +TVM_FFI_REGISTER_GLOBAL("relax.VarFromId") .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { return Var(vid, struct_info_annotation, span); }); @@ -310,12 +309,12 @@ DataflowVar::DataflowVar(Id vid, Optional struct_info_annotation, Sp data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowVar") +TVM_FFI_REGISTER_GLOBAL("relax.DataflowVar") .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { return DataflowVar(name_hint, struct_info_annotation, span); }); -TVM_REGISTER_GLOBAL("relax.DataflowVarFromId") +TVM_FFI_REGISTER_GLOBAL("relax.DataflowVarFromId") .set_body_typed([](Id vid, Optional struct_info_annotation, Span span) { return DataflowVar(vid, struct_info_annotation, span); }); @@ -345,8 +344,9 @@ Constant::Constant(runtime::NDArray data, Optional struct_info_annot TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relax.Constant") - .set_body_typed([](runtime::NDArray data, Optional struct_info_annotation = NullOpt, +TVM_FFI_REGISTER_GLOBAL("relax.Constant") + .set_body_typed([](runtime::NDArray data, + Optional struct_info_annotation = std::nullopt, Span span = Span()) { return Constant(data, struct_info_annotation, span); }); @@ -366,7 +366,7 @@ PrimValue PrimValue::Int64(int64_t value, Span span) { TVM_REGISTER_NODE_TYPE(PrimValueNode); -TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span span) { return PrimValue(value, span); }); @@ -383,7 +383,7 @@ StringImm::StringImm(String value, Span span) { TVM_REGISTER_NODE_TYPE(StringImmNode); -TVM_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.StringImm").set_body_typed([](String value, Span span) { return StringImm(value, span); }); @@ -400,7 +400,7 @@ DataTypeImm::DataTypeImm(DataType value, Span span) { TVM_REGISTER_NODE_TYPE(DataTypeImmNode); -TVM_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.DataTypeImm").set_body_typed([](DataType value, Span span) { return DataTypeImm(value, span); }); @@ -416,7 +416,7 @@ MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.MatchCast") +TVM_FFI_REGISTER_GLOBAL("relax.MatchCast") .set_body_typed([](Var var, Expr value, StructInfo struct_info, Span span) { return MatchCast(var, value, struct_info, span); }); @@ -458,7 +458,7 @@ VarBinding::VarBinding(Var var, Expr value, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, Span span) { return VarBinding(var, value, span); }); @@ -513,9 +513,10 @@ BindingBlockNode* BindingBlock::CopyOnWrite() { return static_cast(data_.get()); } -TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { - return BindingBlock(bindings, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.BindingBlock") + .set_body_typed([](Array bindings, Span span) { + return BindingBlock(bindings, span); + }); TVM_REGISTER_NODE_TYPE(DataflowBlockNode); @@ -526,9 +527,10 @@ DataflowBlock::DataflowBlock(Array bindings, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.DataflowBlock").set_body_typed([](Array bindings, Span span) { - return DataflowBlock(bindings, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.DataflowBlock") + .set_body_typed([](Array bindings, Span span) { + return DataflowBlock(bindings, span); + }); TVM_REGISTER_NODE_TYPE(SeqExprNode); @@ -548,7 +550,7 @@ SeqExpr::SeqExpr(Array blocks, Expr body, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.SeqExpr") +TVM_FFI_REGISTER_GLOBAL("relax.SeqExpr") .set_body_typed([](Array blocks, Expr body, Span span) { return SeqExpr(blocks, body, span); }); @@ -602,7 +604,7 @@ Function::Function(Array params, Expr body, Optional ret_struct if (lookup.count(var)) { return var; } else { - return NullOpt; + return std::nullopt; } }; }(); @@ -624,7 +626,7 @@ Function::Function(Array params, Expr body, Optional ret_struct data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.Function") +TVM_FFI_REGISTER_GLOBAL("relax.Function") .set_body_typed([](Array params, Expr body, Optional ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { return Function(params, body, ret_struct_info, is_pure, attrs, span); @@ -662,7 +664,7 @@ Function Function::CreateEmpty(Array params, StructInfo ret_struct_info, bo return Function(std::move(n)); } -TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty") +TVM_FFI_REGISTER_GLOBAL("relax.FunctionCreateEmpty") .set_body_typed([](Array params, StructInfo ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { return Function::CreateEmpty(params, ret_struct_info, is_pure, attrs, span); @@ -670,7 +672,7 @@ TVM_REGISTER_GLOBAL("relax.FunctionCreateEmpty") // Special opaque derivation function for ExternFunc // Take look at sinfo_args to figure out the return StructInfo. -TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args") +TVM_FFI_REGISTER_GLOBAL("tvm.relax.struct_info.infer_by_sinfo_args") .set_body_typed([](const Call& call, const BlockBuilder& ctx) -> StructInfo { ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined"; if (call->sinfo_args.empty()) { @@ -708,7 +710,7 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ExternFunc") +TVM_FFI_REGISTER_GLOBAL("relax.ExternFunc") .set_body_typed([](String global_symbol, Optional struct_info, Span span) { if (struct_info.defined()) { return ExternFunc(global_symbol, struct_info.value(), span); @@ -732,32 +734,32 @@ Expr GetShapeOf(const Expr& expr) { return call_shape_of; } -TVM_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { +TVM_FFI_REGISTER_GLOBAL("relax.GetShapeOf").set_body_typed([](const Expr& expr) { return GetShapeOf(expr); }); -TVM_REGISTER_GLOBAL("relax.FuncWithAttr") +TVM_FFI_REGISTER_GLOBAL("relax.FuncWithAttr") .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); } - return NullOpt; + return std::nullopt; }); -TVM_REGISTER_GLOBAL("relax.FuncWithAttrs") +TVM_FFI_REGISTER_GLOBAL("relax.FuncWithAttrs") .set_body_typed([](BaseFunc func, Map attr_map) -> Optional { if (func->IsInstance()) { return WithAttrs(Downcast(std::move(func)), attr_map); } - return NullOpt; + return std::nullopt; }); -TVM_REGISTER_GLOBAL("relax.FuncWithoutAttr") +TVM_FFI_REGISTER_GLOBAL("relax.FuncWithoutAttr") .set_body_typed([](BaseFunc func, String key) -> Optional { if (func->IsInstance()) { return WithoutAttr(Downcast(std::move(func)), key); } - return NullOpt; + return std::nullopt; }); } // namespace relax diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index a450919decff..5e04453a1227 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -326,7 +326,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit") +TVM_FFI_REGISTER_GLOBAL("relax.analysis.post_order_visit") .set_body_typed([](Expr expr, ffi::Function f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); }); }); @@ -601,7 +601,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { // example, if the previous return value was // `TensorStructInfo(shape=[16,16])`, but the new return value is // `TensorStructInfo(shape=[8,8])`. - return Function(params, body, NullOpt, op->is_pure, op->attrs); + return Function(params, body, std::nullopt, op->is_pure, op->attrs); } } diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index 4a36bf214884..dc355cef905f 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -540,30 +540,30 @@ class PyExprMutator : public ObjectRef { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); }; -TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); +TVM_FFI_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { visitor->VisitBinding(binding); }); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { visitor->VisitBindingBlock(block); }); -TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->ExprVisitor::VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { if (const auto* ptr = binding.as()) { visitor->ExprVisitor::VisitBinding_(ptr); @@ -574,7 +574,7 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") } }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { if (const auto* ptr = block.as()) { visitor->ExprVisitor::VisitBindingBlock_(ptr); @@ -585,7 +585,7 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") } }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") .set_body_typed([](PyExprVisitor visitor, const Var& var) { if (const auto* node = var.as()) { visitor->ExprVisitor::VisitVarDef_(node); @@ -596,39 +596,39 @@ TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") } }); -TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") +TVM_FFI_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") .set_body_typed([](PyExprVisitor visitor, const Span& span) { visitor->ExprVisitor::VisitSpan(span); }); -TVM_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); +TVM_FFI_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") .set_body_typed([](PyExprMutator mutator, const Binding& binding) { mutator->VisitBinding(binding); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { return mutator->VisitBindingBlock(block); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") .set_body_typed([](PyExprMutator mutator, const Var& var) { return mutator->VisitVarDef(var); }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->ExprMutator::VisitExpr(expr); }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") .set_body_typed([](PyExprMutator mutator, const Binding& binding) { if (const auto* ptr = binding.as()) { return mutator->ExprMutator::VisitBinding_(ptr); @@ -639,7 +639,7 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") } }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") .set_body_typed([](PyExprMutator mutator, const BindingBlock& block) { if (const auto* node = block.as()) { return mutator->ExprMutator::VisitBindingBlock_(node); @@ -650,7 +650,7 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") } }); -TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") +TVM_FFI_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") .set_body_typed([](PyExprMutator mutator, const Var& var) { if (const auto* node = var.as()) { return mutator->ExprMutator::VisitVarDef_(node); @@ -661,32 +661,32 @@ TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") } }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->VisitExprPostOrder(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") .set_body_typed([](PyExprMutator mutator, const Expr& expr) { return mutator->VisitWithNewScope(expr); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") .set_body_typed([](PyExprMutator mutator, const Var& var) { return mutator->LookupBinding(var); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorWithStructInfo") .set_body_typed([](PyExprMutator mutator, Var var, StructInfo sinfo) { return mutator->WithStructInfo(var, sinfo); }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") .set_body_typed([](PyExprMutator mutator, Id id, Var var) { return mutator->var_remap_[id] = var; }); -TVM_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") +TVM_FFI_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") .set_body_typed([](PyExprMutator mutator, Id id) { return mutator->var_remap_[id]; }); } // namespace relax diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 28049e941c39..feb1f910a42c 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -21,10 +21,10 @@ * \file src/relax/ir/struct_info.cc * \brief Relax struct info. */ +#include #include #include #include -#include namespace tvm { namespace relax { @@ -37,7 +37,7 @@ ObjectStructInfo::ObjectStructInfo(Span span) { TVM_REGISTER_NODE_TYPE(ObjectStructInfoNode); -TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) { return ObjectStructInfo(span); }); @@ -53,20 +53,18 @@ PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) { PrimStructInfo::PrimStructInfo(DataType dtype, Span span) { ObjectPtr n = make_object(); n->dtype = dtype; - n->value = NullOpt; + n->value = std::nullopt; n->span = span; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(PrimStructInfoNode); -TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype").set_body_typed([](DataType dtype, Span span) { - return PrimStructInfo(dtype, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype") + .set_body_typed([](DataType dtype, Span span) { return PrimStructInfo(dtype, span); }); -TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromValue").set_body_typed([](PrimExpr value, Span span) { - return PrimStructInfo(value, span); -}); +TVM_FFI_REGISTER_GLOBAL("relax.PrimStructInfoFromValue") + .set_body_typed([](PrimExpr value, Span span) { return PrimStructInfo(value, span); }); // Shape ShapeStructInfo::ShapeStructInfo(Array values, Span span) { @@ -94,7 +92,7 @@ ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { TVM_REGISTER_NODE_TYPE(ShapeStructInfoNode); -TVM_REGISTER_GLOBAL("relax.ShapeStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.ShapeStructInfo") .set_body_typed([](Optional> values, int ndim, Span span) { if (values.defined()) { CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; @@ -135,7 +133,7 @@ TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Optional v TVM_REGISTER_NODE_TYPE(TensorStructInfoNode); -TVM_REGISTER_GLOBAL("relax.TensorStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.TensorStructInfo") .set_body_typed([](Optional shape, Optional dtype, int ndim, VDevice vdevice, Span span) { if (shape.defined()) { @@ -156,7 +154,7 @@ TupleStructInfo::TupleStructInfo(Array fields, Span span) { TVM_REGISTER_NODE_TYPE(TupleStructInfoNode); -TVM_REGISTER_GLOBAL("relax.TupleStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.TupleStructInfo") .set_body_typed([](Array fields, Span span) { return TupleStructInfo(fields, span); }); @@ -191,12 +189,12 @@ FuncStructInfo FuncStructInfo::OpaqueFunc(StructInfo ret, bool purity, Span span TVM_REGISTER_NODE_TYPE(FuncStructInfoNode); -TVM_REGISTER_GLOBAL("relax.FuncStructInfo") +TVM_FFI_REGISTER_GLOBAL("relax.FuncStructInfo") .set_body_typed([](Array params, StructInfo ret, bool purity, Span span) { return FuncStructInfo(params, ret, purity, span); }); -TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") +TVM_FFI_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc") .set_body_typed([](Optional ret, Optional derive_func, bool purity, Span span) { if (derive_func.defined()) { @@ -220,11 +218,10 @@ void UpdateStructInfo(Expr expr, StructInfo struct_info) { expr->checked_type_ = GetStaticType(struct_info); } -TVM_REGISTER_GLOBAL("relax.UpdateStructInfo").set_body_typed([](Expr expr, StructInfo struct_info) { - UpdateStructInfo(expr, struct_info); -}); +TVM_FFI_REGISTER_GLOBAL("relax.UpdateStructInfo") + .set_body_typed([](Expr expr, StructInfo struct_info) { UpdateStructInfo(expr, struct_info); }); -TVM_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { +TVM_FFI_REGISTER_GLOBAL("ir.ExprStructInfo").set_body_typed([](Expr expr) { return GetStructInfo(expr); }); diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index d79d8b3fd50d..a44deba0fe94 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -22,13 +22,13 @@ * \brief Relax specific transformation passes. */ #include +#include #include #include #include #include #include #include -#include namespace tvm { namespace relax { @@ -163,7 +163,7 @@ Pass CreateFunctionPass(std::function TVM_REGISTER_NODE_TYPE(FunctionPassNode); -TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MakeFunctionPass") .set_body_typed( [](ffi::TypedFunction, IRModule, PassContext)> pass_func, PassInfo pass_info) { @@ -383,7 +383,7 @@ Pass CreateDataflowBlockPass( TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode); -TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass") .set_body_typed( [](ffi::TypedFunction, IRModule, PassContext)> pass_func, diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 82b95b556bc2..8b70bcf2c7a5 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -21,8 +21,8 @@ * \file src/relax/ir/type.cc * \brief Relax type system. */ +#include #include -#include namespace tvm { namespace relax { @@ -36,7 +36,7 @@ ShapeType::ShapeType(int ndim, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.ShapeType").set_body_typed([](int ndim, Span span) { return ShapeType(ndim, span); }); @@ -48,7 +48,9 @@ ObjectType::ObjectType(Span span) { TVM_REGISTER_NODE_TYPE(ObjectTypeNode); -TVM_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { return ObjectType(span); }); +TVM_FFI_REGISTER_GLOBAL("relax.ObjectType").set_body_typed([](Span span) { + return ObjectType(span); +}); TensorType::TensorType(int ndim, DataType dtype, Span span) { ObjectPtr n = make_object(); @@ -68,7 +70,7 @@ TensorType TensorType::CreateUnknownNDim(DataType dtype, Span span) { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_GLOBAL("relax.TensorType").set_body_typed([](int ndim, DataType dtype, Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.TensorType").set_body_typed([](int ndim, DataType dtype, Span span) { return TensorType(ndim, dtype, span); }); @@ -80,7 +82,7 @@ PackedFuncType::PackedFuncType(Span span) { TVM_REGISTER_NODE_TYPE(PackedFuncTypeNode); -TVM_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { +TVM_FFI_REGISTER_GLOBAL("relax.PackedFuncType").set_body_typed([](Span span) { return PackedFuncType(span); }); diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index c32cdc3aacb3..2f6314221ecc 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -36,7 +36,7 @@ Expr allreduce(Expr x, String op_type, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.allreduce").set_body_typed(allreduce); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.allreduce").set_body_typed(allreduce); StructInfo InferStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -63,7 +63,7 @@ Expr allgather(Expr x, int num_workers, bool in_group) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.allgather").set_body_typed(allgather); StructInfo InferStructInfoAllGather(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -94,7 +94,8 @@ Expr broadcast_from_worker0(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.broadcast_from_worker0").set_body_typed(broadcast_from_worker0); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.broadcast_from_worker0") + .set_body_typed(broadcast_from_worker0); StructInfo InferStructInfoBroadcastFromZero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -120,7 +121,7 @@ Expr scatter_from_worker0(Expr data, int num_workers, int axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0); +TVM_FFI_REGISTER_GLOBAL("relax.op.ccl.scatter_from_worker0").set_body_typed(scatter_from_worker0); StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index cdeb537c3d9f..84750c0c9c4c 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -48,7 +48,7 @@ Expr annotate_sharding(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.dist.annotate_sharding").set_body_typed(annotate_sharding); +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.annotate_sharding").set_body_typed(annotate_sharding); StructInfo InferStructInfoAnnotateSharding(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -73,7 +73,7 @@ Expr redistribute(Expr input, distributed::DeviceMesh device_mesh, return Call(op, {std::move(input)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.dist.redistribute").set_body_typed(redistribute); +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.redistribute").set_body_typed(redistribute); StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -139,7 +139,7 @@ Expr MakeCallTIRLocalView(Expr func, Tuple args, return call; } -TVM_REGISTER_GLOBAL("relax.op.dist.call_tir_local_view").set_body_typed(MakeCallTIRLocalView); +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.call_tir_local_view").set_body_typed(MakeCallTIRLocalView); StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -208,7 +208,7 @@ Expr redistribute_replica_to_shard(Expr input, int num_workers, int axis) { return Call(op, {std::move(input)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.dist.redistribute_replica_to_shard") +TVM_FFI_REGISTER_GLOBAL("relax.op.dist.redistribute_replica_to_shard") .set_body_typed(redistribute_replica_to_shard); TVM_REGISTER_OP("relax.dist.redistribute_replica_to_shard") diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 5b6550c72903..c45e24df2b13 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -50,7 +50,7 @@ Expr resize2d(Expr data, Expr size, Array roi, String layout, String m return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d); StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1 && call->args.size() != 2) { diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 21a72f6200b0..a7465db868fe 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -40,7 +40,7 @@ Expr view(Expr x, Optional shape, Optional dtype, Optional rel }); } -TVM_REGISTER_GLOBAL("relax.op.memory.view").set_body_typed(view); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.view").set_body_typed(view); StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 4) { @@ -136,7 +136,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return prim_sinfo->value.value(); } else { // An offset of unknown value is applied. - return NullOpt; + return std::nullopt; } } else { LOG(FATAL) << "TypeError: " @@ -149,7 +149,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { Optional> input_shape = data_sinfo->GetShape(); - Optional> output_shape = NullOpt; + Optional> output_shape = std::nullopt; int output_ndim = kUnknownNDim; if (view_shape_sinfo && view_shape_sinfo->values.defined()) { output_shape = view_shape_sinfo->values.value(); @@ -168,7 +168,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // number of bytes per scalar element. auto get_size_bytes = [](const DataType& dtype) -> Optional { if (dtype.is_void()) { - return NullOpt; + return std::nullopt; } else { auto size_bits = dtype.bits() * dtype.lanes(); return IntImm(DataType::Int(64), (size_bits + 7) / 8); @@ -179,7 +179,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // given the shape of that array. auto get_num_elements = [&ctx](const Optional>& shape) -> Optional { if (!shape.defined()) { - return NullOpt; + return std::nullopt; } PrimExpr num_elements = Integer(1); @@ -289,7 +289,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { } } -TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView); +TVM_FFI_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo") + .set_body_typed(InferStructInfoView); Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { Expr data = call->args[0]; @@ -360,7 +361,7 @@ Expr ensure_zero_offset(const Expr& x) { return Call(op, {x}); } -TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index ca3746ddad4e..a084747a5cf3 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -57,8 +57,8 @@ Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr s {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); -TVM_REGISTER_GLOBAL("relax.op.nn.attention_var_len").set_body_typed(attention_var_len); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.attention_var_len").set_body_typed(attention_var_len); StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -137,8 +137,8 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { } Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) { - return Downcast( - attention(call->args[0], call->args[1], call->args[2], NullOpt, NullOpt, NullOpt, NullOpt)); + return Downcast(attention(call->args[0], call->args[1], call->args[2], std::nullopt, + std::nullopt, std::nullopt, std::nullopt)); } TVM_REGISTER_OP("relax.nn.attention") diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index cca50689cb02..f335bb9e7c7b 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -51,7 +51,7 @@ Expr conv1d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv1d"); } -TVM_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv1d").set_body_typed(conv1d); StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -214,7 +214,7 @@ Expr conv2d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv2d"); } -TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d); StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -413,7 +413,7 @@ Expr conv3d(Expr data, Expr weight, Array strides, Array padding out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv3d"); } -TVM_REGISTER_GLOBAL("relax.op.nn.conv3d").set_body_typed(conv3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.conv3d").set_body_typed(conv3d); StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -593,7 +593,7 @@ Expr conv1d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -730,7 +730,7 @@ Expr conv2d_transpose(Expr data, Expr weight, Array strides, Array input_sinfo = GetInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 19ea095e07a2..6da83697ee15 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -50,7 +50,7 @@ Expr leakyrelu(Expr data, double alpha) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.leakyrelu").set_body_typed(leakyrelu); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.leakyrelu").set_body_typed(leakyrelu); TVM_REGISTER_OP("relax.nn.leakyrelu") .set_num_inputs(1) @@ -71,7 +71,7 @@ Expr softplus(Expr data, double beta, double threshold) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus); TVM_REGISTER_OP("relax.nn.softplus") .set_num_inputs(1) @@ -91,7 +91,7 @@ Expr prelu(Expr data, Expr alpha, int axis = 1) { return Call(op, {data, alpha}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu); TVM_REGISTER_OP("relax.nn.prelu") .set_num_inputs(2) @@ -112,7 +112,7 @@ Expr softmax(Expr data, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax); StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -170,7 +170,7 @@ Expr log_softmax(Expr data, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.log_softmax").set_body_typed(log_softmax); TVM_REGISTER_OP("relax.nn.log_softmax") .set_num_inputs(1) @@ -191,7 +191,7 @@ Expr pad(Expr data, Array pad_width, String pad_mode, double pad_value) return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad); StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -234,7 +234,7 @@ Expr pixel_shuffle(Expr data, int upscale_factor) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle); StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -383,7 +383,7 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ std::move(moving_var)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -460,7 +460,7 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double ep return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm); StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -528,7 +528,7 @@ Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_ax return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); @@ -639,7 +639,7 @@ Expr instance_norm(Expr data, Expr gamma, Expr beta, int channel_axis, Array(call->op); @@ -735,7 +735,7 @@ Expr rms_norm(Expr data, Expr weight, Array axes, double epsilon) { return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm); StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -794,7 +794,7 @@ Expr dropout(Expr data, double rate) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout); StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -862,7 +862,7 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels) { return Call(op, {std::move(predictions), std::move(labels)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.cross_entropy_with_logits") .set_body_typed(cross_entropy_with_logits); TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits") @@ -896,7 +896,7 @@ Expr nll_loss(Expr predictions, Expr targets, Optional weights, String red } } -TVM_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss); StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (call->args.size() < 2 || call->args.size() > 3) { diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 565e6a00c60d..0161a4d4195d 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -62,7 +62,7 @@ Expr max_pool1d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.max_pool1d").set_body_typed(max_pool1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool1d").set_body_typed(max_pool1d); StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -95,7 +95,7 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_w += attrs->strides[1] - 1; + numerator_w += attrs->strides[0] - 1; } out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1); @@ -175,7 +175,7 @@ Expr max_pool2d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d); StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -314,7 +314,7 @@ Expr max_pool3d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.max_pool3d").set_body_typed(max_pool3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.max_pool3d").set_body_typed(max_pool3d); StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -401,7 +401,7 @@ Expr avg_pool1d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool1d").set_body_typed(avg_pool1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool1d").set_body_typed(avg_pool1d); TVM_REGISTER_OP("relax.nn.avg_pool1d") .set_num_inputs(1) @@ -420,7 +420,7 @@ Expr avg_pool2d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d); TVM_REGISTER_OP("relax.nn.avg_pool2d") .set_num_inputs(1) @@ -439,7 +439,7 @@ Expr avg_pool3d(Expr data, Array pool_size, Array strides, Array count_include_pad, layout, out_layout); } -TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool3d").set_body_typed(avg_pool3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.avg_pool3d").set_body_typed(avg_pool3d); TVM_REGISTER_OP("relax.nn.avg_pool3d") .set_num_inputs(1) @@ -470,7 +470,7 @@ Expr adaptive_avg_pool1d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool1d").set_body_typed(adaptive_avg_pool1d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool1d").set_body_typed(adaptive_avg_pool1d); StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -553,7 +553,7 @@ Expr adaptive_avg_pool2d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d); StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -652,7 +652,7 @@ Expr adaptive_avg_pool3d(Expr data, Optional> output_size, String return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool3d").set_body_typed(adaptive_avg_pool3d); +TVM_FFI_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool3d").set_body_typed(adaptive_avg_pool3d); StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index bb3d645adea4..c581b20835c6 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -119,7 +119,7 @@ Expr MakeCallPurePacked(const Expr& callee, Array args, const Attrs& attrs return Call(op, call_args, attrs, sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_pure_packed").set_body_typed(MakeCallPurePacked); // call_inplace_packed @@ -238,7 +238,7 @@ Expr MakeCallInplacePacked(Expr func, Array args, Array inplace_i return Call(op, call_args, Attrs(attrs), sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInplacePacked); // call_tir @@ -258,11 +258,11 @@ TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla * If the arguments provided are not compatible with the PrimFunc's * signature, an error will be raised. If the arguments are * compatible with the PrimFunc's signature, but are not sufficient to - * determine the output's StructInfo, then `NullOpt` will be returned. + * determine the output's StructInfo, then `std::nullopt` will be returned. * * \param func_sinfo The StructInfo of the TIR callee. * \param arg_sinfo The StructInfo of the argument tuple. - * \param packed_ints_sinfo The StructInfo of the ShapeTuple argument, + * \param packed_ints_sinfo The StructInfo of the ffi::Shape argument, * if present. * \param opt_inplace_indices For `R.call_tir_inplace`, an array of * indices indicating which outputs are constructed from in-place @@ -270,7 +270,7 @@ TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla * `CallTIRInplaceAttrs::inplace_indices` for more details. * * \return The `arg_sinfo`, if it can be inferred from the arguments. - * Otherwise, NullOpt. + * Otherwise, std::nullopt. */ static Optional InferCallTIROutputStructInfoFromArguments( StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo, @@ -311,7 +311,7 @@ static Optional InferCallTIROutputStructInfoFromArguments( CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim()) << "TypeError: " << "The third argument to `R.call_tir`, if present, " - << "must be a ShapeTuple with known dimensionality. " + << "must be a ffi::Shape with known dimensionality. " << "However, the argument received was of type " << packed_sinfo; num_trailing_int_arguments = packed_tuple_sinfo->ndim; } else { @@ -338,7 +338,7 @@ static Optional InferCallTIROutputStructInfoFromArguments( } }; if (contains_dtensor(arg_sinfo)) { - return NullOpt; + return std::nullopt; } // At this point, the return types are known. However, the shapes @@ -413,7 +413,7 @@ static Optional InferCallTIROutputStructInfoFromArguments( auto derived_ret_sinfo = DeriveCallRetStructInfo( dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo), dummy_args), - BlockBuilder::Create(NullOpt)); + BlockBuilder::Create(std::nullopt)); return derived_ret_sinfo; } @@ -466,7 +466,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { Expr packed_ints = call->args[2]; CHECK(packed_ints->struct_info_.as()) << "Operation " << call->op << " expects the optional third argument, " - << "if present, to be a ShapeTuple. " + << "if present, to be a ffi::Shape. " << "However, the third argument " << packed_ints << " has struct info " << packed_ints->struct_info_; } @@ -481,7 +481,7 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { return bound_value.value(); } } - return NullOpt; + return std::nullopt; }; Tuple new_arg_tuple = [&]() { @@ -535,7 +535,7 @@ void ValidateCallTIR(Call call) { auto packed_int_sinfo = [&]() -> Optional { if (call->args.size() <= 2) { - return NullOpt; + return std::nullopt; } else { return GetStructInfo(call->args[2]); } @@ -545,7 +545,7 @@ void ValidateCallTIR(Call call) { if (const auto* attrs = call->attrs.as()) { return attrs->inplace_indices; } else { - return NullOpt; + return std::nullopt; } }(); @@ -600,7 +600,7 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, return call; } -TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); // call_tir_with_grad @@ -652,7 +652,7 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinf return call; } -TVM_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWithGrad); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWithGrad); // call_tir_inplace @@ -793,7 +793,7 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, Array inplace_indices, return call; } -TVM_REGISTER_GLOBAL("relax.op.call_tir_inplace").set_body_typed(MakeCallTIRInplace); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_tir_inplace").set_body_typed(MakeCallTIRInplace); // call_dps_packed @@ -834,7 +834,7 @@ Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_ return Call(op, {func, args}, {}, {out_sinfo}); } -TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked); // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { @@ -860,7 +860,7 @@ Expr MakeCallBuiltinWithCtx(Expr func, Tuple args, Array sinfo_args) return Call(op, {func, args}, Attrs(), sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); +TVM_FFI_REGISTER_GLOBAL("relax.op.call_builtin_with_ctx").set_body_typed(MakeCallBuiltinWithCtx); TVM_REGISTER_OP("relax.null_value") .set_num_inputs(0) @@ -872,7 +872,7 @@ Expr MakeCallNullValue() { return Call(op, {}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); +TVM_FFI_REGISTER_GLOBAL("relax.op.null_value").set_body_typed(MakeCallNullValue); // print @@ -895,7 +895,7 @@ Expr MakePrint(Array vals, StringImm format) { return Call(op, params); } -TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); +TVM_FFI_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); // assert_op @@ -938,7 +938,7 @@ Expr MakeAssertOp(Expr condition, Array vals, StringImm format) { return Call(op, args); } -TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); +TVM_FFI_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); // make_closure @@ -954,7 +954,7 @@ Expr MakeClosure(Expr func, Tuple args) { return Call(op, {func, args}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); +TVM_FFI_REGISTER_GLOBAL("relax.op.make_closure").set_body_typed(MakeClosure); // invoke_closure @@ -981,7 +981,7 @@ Expr InvokeClosure(Expr closure, Tuple args, Array sinfo_args) { return Call(op, {closure, args}, {}, sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); +TVM_FFI_REGISTER_GLOBAL("relax.op.invoke_closure").set_body_typed(InvokeClosure); // invoke_pure_closure @@ -997,7 +997,7 @@ Expr InvokePureClosure(Expr closure, Tuple args, Array sinfo_args) { return Call(op, {closure, args}, {}, sinfo_args); } -TVM_REGISTER_GLOBAL("relax.op.invoke_pure_closure").set_body_typed(InvokePureClosure); +TVM_FFI_REGISTER_GLOBAL("relax.op.invoke_pure_closure").set_body_typed(InvokePureClosure); // shape_of @@ -1012,7 +1012,7 @@ Expr MakeShapeOf(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); +TVM_FFI_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); // tensor_to_shape @@ -1046,7 +1046,7 @@ Expr MakeTensorToShape(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape); +TVM_FFI_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape); // shape_to_tensor StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& ctx) { @@ -1070,7 +1070,7 @@ Expr MakeShapeToTensor(Expr expr) { return Call(op, {expr}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor); // alloc_tensor @@ -1107,7 +1107,7 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind return Call(op, {shape, dtype, runtime_device_index, storage_scope}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); // memory planning alloc_storage @@ -1132,7 +1132,7 @@ Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm stora return Call(op, {size, virtual_device_index, storage_scope, dtype}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); // memory planning alloc_tensor @@ -1163,7 +1163,7 @@ Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); // memory planning kill_storage @@ -1179,7 +1179,7 @@ Expr MakeMemKillStorage(Expr storage) { return Call(op, {storage}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); // memory planning kill_tensor @@ -1195,7 +1195,7 @@ Expr MakeMemKillTensor(Expr tensor) { return Call(op, {tensor}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); // vm alloc_storage @@ -1219,7 +1219,7 @@ Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm d return Call(op, {size, runtime_device_index, dtype, storage_scope}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.alloc_storage").set_body_typed(MakeVMAllocStorage); // vm alloc_tensor @@ -1257,7 +1257,7 @@ Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm d return Call(op, {storage, offset, shape, dtype}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor); // vm kill_object @@ -1273,7 +1273,7 @@ Expr MakeVMKillObject(Expr obj) { return Call(op, {std::move(obj)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.kill_object").set_body_typed(MakeVMKillObject); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.kill_object").set_body_typed(MakeVMKillObject); // vm call_tir_dyn @@ -1291,7 +1291,7 @@ Expr MakeCallTIRDyn(Expr func, Tuple args) { return Call(op, {func, args}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); +TVM_FFI_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); // builtin stop_lift_params StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { @@ -1309,7 +1309,7 @@ Expr MakeStopLiftParams(Expr x) { return Call(op, {x}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); +TVM_FFI_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); // to_vdevice TVM_REGISTER_NODE_TYPE(ToVDeviceAttrs); @@ -1340,7 +1340,7 @@ Expr MakeToVDevice(Expr data, VDevice dst_vdev) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.to_vdevice").set_body_typed(MakeToVDevice); +TVM_FFI_REGISTER_GLOBAL("relax.op.to_vdevice").set_body_typed(MakeToVDevice); // hint_on_device TVM_REGISTER_NODE_TYPE(HintOnDeviceAttrs); @@ -1367,7 +1367,7 @@ Expr MakeHintOnDevice(Expr data, Device device) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.hint_on_device").set_body_typed(MakeHintOnDevice); +TVM_FFI_REGISTER_GLOBAL("relax.op.hint_on_device").set_body_typed(MakeHintOnDevice); } // namespace relax } // namespace tvm diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index f9c1ece38c18..f439a345eb19 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -136,7 +136,7 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc << " is " << dim1 << ", which are not broadcastable."); } else { // Use simple fallback when shape mismatch. - return NullOpt; + return std::nullopt; } } auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape; diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 85a3d885dce8..d7d50f8fa714 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -181,7 +181,7 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c static const Op& op = Op::Get("relax." OpRegName); \ return Call(op, {std::move(x)}, Attrs(), {}); \ } \ - TVM_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) + TVM_FFI_REGISTER_GLOBAL("relax.op." OpRegName).set_body_typed(OpName) /************ Utilities ************/ @@ -344,7 +344,7 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl if (const auto* tensor = sinfo.as()) { return tensor->vdevice; } else { - return NullOpt; + return std::nullopt; } }; @@ -374,8 +374,8 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl * \param ctx The error reporting context. * \param x1_shape The shape of the first operand. * \param x2_shape The shape of the second operand. - * \return The inferred output shape after broadcasting. Or `NullOpt` if the output shape cannot be - * determined due to symbolic broadcast. + * \return The inferred output shape after broadcasting. Or `std::nullopt` if the output shape + * cannot be determined due to symbolic broadcast. */ Optional> InferBinaryBroadcastShape(const Call& call, const BlockBuilder& ctx, const Array& x1_shape, @@ -536,7 +536,7 @@ inline std::pair CheckTensorLayout(const Call * \param ctx The error reporting context. * \param sinfo The input tensor struct info to be checked. * \param layout The layout that the given tensor is expected to have. - * \return The shape of the input tensor in ShapeExpr, or `NullOpt` if the shape is unknown. + * \return The shape of the input tensor in ShapeExpr, or `std::nullopt` if the shape is unknown. */ inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx, const TensorStructInfo& sinfo, @@ -550,7 +550,7 @@ inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const if (const auto* shape_expr = sinfo->shape.as()) { return GetRef(shape_expr); } - return NullOpt; + return std::nullopt; } Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype, diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index e7fab8f166e1..74ae8e9cbc5c 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -92,7 +92,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, } else if (const auto* tensor = sinfo.as()) { return tensor->GetShape(); } else { - return NullOpt; + return std::nullopt; } }; @@ -113,7 +113,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, if (const auto* tensor = sinfo.as()) { return tensor->shape; } else { - return NullOpt; + return std::nullopt; } }; diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 6b106f760d5f..ae36d45b3683 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -42,7 +42,7 @@ namespace relax { static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {x1, x2}, Attrs(), {}); \ } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(2) \ .add_argument("x1", "Tensor", "The first input tensor.") \ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 361f90c7b043..b2355b1af7f0 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -40,7 +40,7 @@ Expr full(Variant> shape, Expr fill_value, Optional()) { shape_in_expr = GetRef(expr); - } else if (const auto* _array = shape.as()) { + } else if (const auto* _array = shape.as()) { shape_in_expr = ShapeExpr(GetRef>(_array)); } else { LOG(FATAL) @@ -54,7 +54,7 @@ Expr full(Variant> shape, Expr fill_value, Optionalargs.size() != 2) { @@ -96,7 +96,7 @@ Expr full_like(Expr x, Expr fill_value, Optional dtype) { return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -174,8 +174,8 @@ Expr ones_like(Expr x, Optional dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); -TVM_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); +TVM_FFI_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); TVM_REGISTER_OP("relax.ones") .set_attrs_type() @@ -209,8 +209,8 @@ Expr zeros_like(Expr x, Optional dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); -TVM_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); +TVM_FFI_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); TVM_REGISTER_OP("relax.zeros") .set_attrs_type() @@ -242,8 +242,8 @@ Expr eye_like(Expr x, PrimValue k, Optional dtype) { return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye); -TVM_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye); +TVM_FFI_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like); StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { @@ -319,7 +319,7 @@ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { return Call(op, {std::move(start), std::move(stop), std::move(step)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange); +TVM_FFI_REGISTER_GLOBAL("relax.op.arange").set_body_typed(arange); StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 3) { @@ -380,8 +380,8 @@ Expr triu(Expr x, Expr k) { Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); } -TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast(tril)); -TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast(triu)); +TVM_FFI_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast(tril)); +TVM_FFI_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast(triu)); StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { auto [data_sinfo, offset] = GetArgStructInfo(call, ctx); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 5f770b2a6ed9..0bf15bbd57e7 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAX_OP_TENSOR_CREATE_H_ #define TVM_RELAX_OP_TENSOR_CREATE_H_ +#include #include -#include #include "../op_common.h" diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index bc24285cf9c7..d1d5bbccbcc7 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -40,7 +40,7 @@ Expr astype(Expr x, DataType dtype) { return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); +TVM_FFI_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype); StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -70,7 +70,7 @@ Expr MakeWrapParam(Expr data, DataType dtype) { return Call(op, {std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam); +TVM_FFI_REGISTER_GLOBAL("relax.op.wrap_param").set_body_typed(MakeWrapParam); StructInfo InferStructInfoWrapParam(const Call& call, const BlockBuilder& ctx) { TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc index d8aecb3461d4..c25d587052f1 100644 --- a/src/relax/op/tensor/grad.cc +++ b/src/relax/op/tensor/grad.cc @@ -35,7 +35,7 @@ Expr no_grad(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.no_grad").set_body_typed(no_grad); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.no_grad").set_body_typed(no_grad); StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[0]); @@ -53,7 +53,7 @@ Expr start_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.start_checkpoint").set_body_typed(start_checkpoint); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.start_checkpoint").set_body_typed(start_checkpoint); StructInfo InferStructInfoStartCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -75,7 +75,7 @@ Expr end_checkpoint(Expr input) { return Call(op, {std::move(input)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.end_checkpoint").set_body_typed(end_checkpoint); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.end_checkpoint").set_body_typed(end_checkpoint); StructInfo InferStructInfoEndCheckpoint(const Call& call, const BlockBuilder& ctx) { if (!call->args[0].as()) { @@ -111,7 +111,7 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optiona } } -TVM_REGISTER_GLOBAL("relax.op.grad.nll_loss_backward").set_body_typed(nll_loss_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.nll_loss_backward").set_body_typed(nll_loss_backward); StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -145,7 +145,7 @@ Expr max_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.max_pool2d_backward").set_body_typed(max_pool2d_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.max_pool2d_backward").set_body_typed(max_pool2d_backward); StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -177,7 +177,7 @@ Expr avg_pool2d_backward(Expr output_grad, Expr data, Array pool_size, return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.avg_pool2d_backward").set_body_typed(avg_pool2d_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.avg_pool2d_backward").set_body_typed(avg_pool2d_backward); StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); @@ -202,7 +202,7 @@ Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional axi return Call(op, {std::move(output_grad), std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.grad.take_backward").set_body_typed(take_backward); +TVM_FFI_REGISTER_GLOBAL("relax.op.grad.take_backward").set_body_typed(take_backward); StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx) { return GetStructInfo(call->args[1]); diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 4e62a9ecd7d7..26978f2fad74 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -47,7 +47,7 @@ Expr take(Expr x, Expr indices, Optional axis) { return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); +TVM_FFI_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); @@ -169,7 +169,7 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strid return call; } -TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); +TVM_FFI_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); /* \brief Helper function to unpack a relax::Tuple * @@ -190,17 +190,17 @@ TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); * \param sinfo The StructInfo to inspect * * \returns An array of the `PrimType`, if it can be extracted. - * Otherwise, `NullOpt`. + * Otherwise, `std::nullopt`. */ template >> Optional> UnpackTupleOfPrimValue(Optional sinfo) { - if (!sinfo) return NullOpt; + if (!sinfo) return std::nullopt; // An ObjectStructInfo may contain a tuple of the desired type, but // it isn't yet known whether it does. Return early, as we cannot // provide a known `Array` to the caller. - if (sinfo.as()) return NullOpt; + if (sinfo.as()) return std::nullopt; auto tuple = sinfo.as(); CHECK(tuple) << "TypeError: " @@ -211,7 +211,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { for (size_t i = 0; i < tuple->fields.size(); i++) { auto field = tuple->fields[i]; - if (field.as()) return NullOpt; + if (field.as()) return std::nullopt; auto prim_sinfo = field.as(); CHECK(prim_sinfo) << "TypeError: " @@ -220,7 +220,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { << PrimType::ContainerType::_type_key << ", because element " << i << " has struct info " << field; - if (!prim_sinfo->value.defined()) return NullOpt; + if (!prim_sinfo->value.defined()) return std::nullopt; Optional element = prim_sinfo->value.as(); if (!element) return std::nullopt; @@ -249,7 +249,7 @@ Optional> UnpackTupleOfPrimValue(Optional sinfo) { * \param expr The `relax::Expr` to inspect * * \returns An array of the `PrimType`, if it can be extracted. - * Otherwise, `NullOpt`. + * Otherwise, `std::nullopt`. */ template >> @@ -257,7 +257,7 @@ Optional> UnpackTupleOfPrimValue(Optional expr) { if (expr) { return UnpackTupleOfPrimValue(GetStructInfo(expr.value())); } else { - return NullOpt; + return std::nullopt; } } @@ -276,7 +276,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx if (n_args > 4) { return call->args[4]; } else { - return NullOpt; + return std::nullopt; } }(); @@ -287,7 +287,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx if (n_args > 4) { return GetStructInfo(call->args[4]); } else { - return NullOpt; + return std::nullopt; } }(); @@ -329,7 +329,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx const auto* data_sinfo = data->struct_info_.as(); DataType dtype = DataType::Void(); - Optional vdevice = NullOpt; + Optional vdevice = std::nullopt; int ndim = kUnknownNDim; if (data_sinfo) { dtype = data_sinfo->dtype; @@ -338,15 +338,15 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx } Optional shape = [&]() -> Optional { - if (!data_sinfo) return NullOpt; - if (!data_sinfo->shape) return NullOpt; + if (!data_sinfo) return std::nullopt; + if (!data_sinfo->shape) return std::nullopt; auto opt_axes_tuple = UnpackTupleOfPrimValue(axes); - if (!opt_axes_tuple) return NullOpt; + if (!opt_axes_tuple) return std::nullopt; auto axes_tuple = opt_axes_tuple.value(); auto opt_begin_tuple = UnpackTupleOfPrimValue(begin); - if (!opt_begin_tuple) return NullOpt; + if (!opt_begin_tuple) return std::nullopt; auto begin_tuple = opt_begin_tuple.value(); CHECK_EQ(axes_tuple.size(), begin_tuple.size()) @@ -356,7 +356,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx << ") and " << begin_tuple.size() << " 'begin' indices specified (" << begin_tuple << ")"; auto opt_end_tuple = UnpackTupleOfPrimValue(end); - if (!opt_end_tuple) return NullOpt; + if (!opt_end_tuple) return std::nullopt; auto end_tuple = opt_end_tuple.value(); CHECK_EQ(axes_tuple.size(), end_tuple.size()) @@ -368,7 +368,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx Array strides_tuple; if (strides.defined()) { auto opt_strides_tuple = UnpackTupleOfPrimValue(strides); - if (!opt_strides_tuple) return NullOpt; + if (!opt_strides_tuple) return std::nullopt; strides_tuple = opt_strides_tuple.value(); } else { @@ -387,7 +387,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx if (axes_tuple.empty() && !opt_data_shape.defined()) { return data_sinfo->shape.value(); } else if (!opt_data_shape.defined()) { - return NullOpt; + return std::nullopt; } std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, axes_tuple); @@ -477,7 +477,7 @@ Expr dynamic_strided_slice(Expr x, // return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); } -TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice); +TVM_FFI_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice); StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h index 46b5fd501b95..63a12e28f622 100644 --- a/src/relax/op/tensor/index.h +++ b/src/relax/op/tensor/index.h @@ -37,7 +37,7 @@ namespace relax { * \param indices The indices of the values to extract. * It is required to be a one-dimensional tensor which has integer dtype. * \param axis The axis over which to select values. - * If it is `NullOpt`, the input tensor is required to be one-dimensional. + * If it is `std::nullopt`, the input tensor is required to be one-dimensional. * \return The taken result. */ Expr take(Expr x, Expr indices, Optional axis); @@ -50,11 +50,11 @@ Expr take(Expr x, Expr indices, Optional axis); * \param end The indices indicating end of the slice, exclusive. * \param strides Specifies the stride values, it can be negative in that case, * the input tensor will be reversed in that particular axis. - * If it is `NullOpt`, it by default is an list of ones of the same length as `axes`. + * If it is `std::nullopt`, it by default is an list of ones of the same length as `axes`. * \param assume_inbound Whether to assume the indices are in bound. * \return The sliced result */ -Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strides = NullOpt, +Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strides = std::nullopt, bool assume_inbound = false); } // namespace relax diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index 0fdbee1c6aac..28af375f00a6 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -44,7 +44,7 @@ Expr matmul(Expr x1, Expr x2, Optional out_dtype) { return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); +TVM_FFI_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -176,7 +176,7 @@ Expr einsum(Expr operands, String subscripts) { return Call(op, {std::move(operands)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.einsum").set_body_typed(einsum); +TVM_FFI_REGISTER_GLOBAL("relax.op.einsum").set_body_typed(einsum); StructInfo InferStructInfoEinsum(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -251,5 +251,43 @@ TVM_REGISTER_OP("relax.einsum") .set_attr("FInferStructInfo", InferStructInfoEinsum) .set_attr("FPurity", Bool(true)); +/* relax.outer */ + +Expr outer(Expr x1, Expr x2) { + static const Op& op = Op::Get("relax.outer"); + return Call(op, {std::move(x1), std::move(x2)}, {}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer); + +StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) { + auto input_sinfo = GetInputTensorStructInfo(call, ctx); + auto x1_sinfo = input_sinfo[0]; + auto x2_sinfo = input_sinfo[1]; + + // Ensure both inputs are 1D tensors + if (x1_sinfo->ndim != 1 || x2_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "torch.outer requires both inputs to be 1D tensors."); + } + + // Determine output shape + auto x1_shape = x1_sinfo->shape.as(); + auto x2_shape = x2_sinfo->shape.as(); + if (!x1_shape || !x2_shape) { + return TensorStructInfo(x1_sinfo->dtype, 2); + } + Array output_shape = {x1_shape->values[0], x2_shape->values[0]}; + return TensorStructInfo(ShapeExpr(output_shape), x1_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.outer") + .set_num_inputs(2) + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") + .set_attr("FInferStructInfo", InferStructInfoOuter) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h index 638e5af8f87e..eb003fed1c76 100644 --- a/src/relax/op/tensor/linear_algebra.h +++ b/src/relax/op/tensor/linear_algebra.h @@ -51,6 +51,14 @@ Expr matmul(Expr x1, Expr x2, Optional out_dtype); */ Expr einsum(Expr operands, String subscripts); +/*! + * \brief Compute the outer product of two input expressions. + * \param x1 The first input expression. + * \param x2 The second input expression. + * \return The resulting expression representing the outer product. + */ +Expr outer(Expr x1, Expr x2); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index d3da5a5e1199..f834bed2538e 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -41,7 +41,7 @@ Expr broadcast_to(Expr x, Expr shape) { return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); +TVM_FFI_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -124,7 +124,7 @@ Expr concat(Expr tensors, Optional axis) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat); +TVM_FFI_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat); Optional> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, @@ -171,7 +171,7 @@ Optional> CheckConcatOutputShape(const Call& call, const BlockBu } if (shape_unknown) { - return NullOpt; + return std::nullopt; } Array output_shape = shape_values[0]; output_shape.Set(axis, concat_sum); @@ -192,7 +192,7 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int output_ndim = attrs->axis.has_value() ? kUnknownNDim : 1; DataType output_dtype = DataType::Void(); - Optional vdev = NullOpt; + Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; bool vdevice_unknown = false; @@ -260,7 +260,7 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { output_dtype = DataType::Void(); } if (vdevice_unknown) { - vdev = NullOpt; + vdev = std::nullopt; } if (output_ndim == kUnknownNDim) { @@ -340,7 +340,7 @@ Expr expand_dims(Expr x, Array axis) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); +TVM_FFI_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -446,7 +446,7 @@ Expr flatten(Expr x) { return Call(op, {std::move(x)}, {}, {}); } -TVM_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten); +TVM_FFI_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten); StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -481,7 +481,7 @@ Expr index_tensor(Expr first, Expr tensors) { return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); +TVM_FFI_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor); StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -635,7 +635,7 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_v return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform); +TVM_FFI_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform); StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -702,7 +702,7 @@ Expr permute_dims(Expr x, Optional> axes) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims); +TVM_FFI_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims); bool IsIdentityPermutation(const std::vector& permutation) { for (int i = 0; i < static_cast(permutation.size()); ++i) { @@ -812,16 +812,16 @@ TVM_REGISTER_OP("relax.permute_dims") /* relax.reshape */ Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { - const ArrayObj* array; + const ffi::ArrayObj* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { - array = e->values.as(); + array = e->values.as(); // Other non-shape expressions are used directly. } else if (const auto* e = shape.as()) { return GetRef(e); // Process special values in constants and produce an expression. } else { - array = shape.as(); + array = shape.as(); } CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " "Array of PrimExprs. However, the given new shape is " @@ -910,7 +910,7 @@ Expr reshape(Expr x, Variant> shape) { return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); +TVM_FFI_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -973,7 +973,7 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) ObjectPtr attrs = make_object(); ObjectRef indices_or_sections_obj; - if (const auto* indices = indices_or_sections.as()) { + if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { const auto* idx = indices->at(i).as(); CHECK(idx != nullptr) << "Split op only accepts an array of integers as the indices. " @@ -997,7 +997,7 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) return Call(op, {std::move(x)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); +TVM_FFI_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -1150,7 +1150,7 @@ Expr squeeze(Expr x, Optional> axis) { return Call(op, {std::move(x)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze); +TVM_FFI_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze); StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -1350,7 +1350,7 @@ Expr stack(Expr tensors, Optional axis) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack); +TVM_FFI_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack); Optional> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, @@ -1373,7 +1373,7 @@ Optional> CheckStackOutputShape(const Call& call, const BlockBui } if (shape_unknown) { - return NullOpt; + return std::nullopt; } // Insert new dimension at axis position @@ -1406,7 +1406,7 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { // Default axis is 0 if not specified int output_ndim = tensor_sinfo[0]->ndim + 1; // Stack adds one dimension DataType output_dtype = DataType::Void(); - Optional vdev = NullOpt; + Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; bool vdevice_unknown = false; @@ -1459,7 +1459,7 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { } if (is_void_dtype) output_dtype = DataType::Void(); - if (vdevice_unknown) vdev = NullOpt; + if (vdevice_unknown) vdev = std::nullopt; // Normalize axis (default to 0 if not specified) int axis = @@ -1554,7 +1554,7 @@ Expr collapse_sum_like(Expr data, Expr collapse_target) { return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like); +TVM_FFI_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like); StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -1600,7 +1600,7 @@ Expr collapse_sum_to(Expr data, Expr shape) { return Call(op, {std::move(data), std::move(shape)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to); +TVM_FFI_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to); StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 2) { @@ -1655,7 +1655,7 @@ Expr repeat(Expr data, int repeats, Optional axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat); +TVM_FFI_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat); StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1720,7 +1720,7 @@ Expr tile(Expr data, Array repeats) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile); +TVM_FFI_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile); StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -1783,7 +1783,7 @@ Expr flip(Expr data, Integer axis) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.flip").set_body_typed(flip); +TVM_FFI_REGISTER_GLOBAL("relax.op.flip").set_body_typed(flip); StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -1820,7 +1820,7 @@ Expr gather_elements(Expr data, Expr indices, int axis) { return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements); +TVM_FFI_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements); StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -1889,7 +1889,7 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims) { return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd); +TVM_FFI_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd); StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -1983,7 +1983,7 @@ Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) { return Call(op, {data, indices, values}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put); +TVM_FFI_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put); StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) { const auto* data_sinfo = GetStructInfoAs(call->args[0]); @@ -2106,7 +2106,7 @@ Expr meshgrid(Expr tensors, Optional indexing) { return Call(op, {std::move(tensors)}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid); +TVM_FFI_REGISTER_GLOBAL("relax.op.meshgrid").set_body_typed(meshgrid); StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { @@ -2124,7 +2124,7 @@ StructInfo InferStructInfoMeshgrid(const Call& call, const BlockBuilder& ctx) { std::vector lengths; DataType common_dtype = DataType::Void(); bool shape_unknown = false; - Optional vdev = NullOpt; + Optional vdev = std::nullopt; bool vdevice_unknown = false; for (int i = 0; i < n_inputs; ++i) { @@ -2210,7 +2210,7 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.scatter_elements").set_body_typed(scatter_elements); +TVM_FFI_REGISTER_GLOBAL("relax.op.scatter_elements").set_body_typed(scatter_elements); StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); @@ -2324,7 +2324,7 @@ Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { return Call(op, {data, indices, updates}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); +TVM_FFI_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { // `call->args` contains: [data, indices, updates] @@ -2448,6 +2448,161 @@ TVM_REGISTER_OP("relax.scatter_nd") .set_attr("FInferStructInfo", InferStructInfoScatterND) .set_attr("FPurity", Bool(true)); +/* relax.scatter_nd */ +TVM_REGISTER_NODE_TYPE(SliceScatterAttrs); + +Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("relax.slice_scatter"); + return Call(op, {input, src, start, end, step}, Attrs(attrs), {}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.op.slice_scatter").set_body_typed(slice_scatter); + +StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* src_sinfo = GetStructInfoAs(call->args[1]); + auto* attrs = call->attrs.as(); + + auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& arg_expr, + String name) { + if (sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter requires the input " << name + << " to be a Tensor. However, the given one is " + << arg_expr->struct_info_->GetTypeKey()); + } + }; + + diag_tensor_check(data_sinfo, call->args[0], "data"); + diag_tensor_check(src_sinfo, call->args[1], "src"); + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice); + } + + int ndim = data_sinfo->ndim; + int raw_axis = attrs->axis; + if (raw_axis < -ndim || raw_axis >= ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "SliceScatter requires the input axis to be in the range " + << "[" << -ndim << ", " << ndim - 1 << "]. However, the input axis is " + << raw_axis << ", while ndim is " << ndim); + } + + if (!data_sinfo->IsUnknownNdim() && !src_sinfo->IsUnknownNdim()) { + if (data_sinfo->ndim != src_sinfo->ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "SliceScatter op requires the data tensor to have the same rank as the " + "src tensor. However, the given dimensions are " + << "src: " << src_sinfo->ndim << ", data: " << data_sinfo->ndim); + } + } + + if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) { + auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, String name) { + if (sinfo->IsUnknownDtype()) { + LOG(WARNING) << "SliceScatter: Data type of " << name + << " has not been specified for call node " << call + << ". Assuming it is compatible."; + } + }; + diag_dtype_warn(data_sinfo, "data"); + diag_dtype_warn(src_sinfo, "src"); + } else { + if (data_sinfo->dtype != src_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "SliceScatter op requires the input data to have the same type as " + "src. However, the given types are " + << "data: " << data_sinfo->dtype << ", src: " << src_sinfo->dtype); + } + } + + auto get_prim_expr_from_arg = [&ctx, &call](const Expr& arg_expr, std::string key) -> PrimExpr { + const auto* prim_value_node = arg_expr.as(); + if (prim_value_node == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "SliceScatter expects the `" << key << "` argument (" << arg_expr + << ") to be a PrimValue, but got " << arg_expr->GetTypeKey()); + } + const PrimExpr& prim_expr = prim_value_node->value; + if (!prim_expr.dtype().is_int() && !prim_expr.dtype().is_uint()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "SliceScatter expects `" << key << "` (" << prim_expr + << ") to be an integer PrimValue, but got dtype " << prim_expr.dtype()); + } + return prim_expr; + }; + + PrimExpr start_val = get_prim_expr_from_arg(call->args[2], "start"); + PrimExpr stop_val = get_prim_expr_from_arg(call->args[3], "end"); + PrimExpr step_val = get_prim_expr_from_arg(call->args[4], "step"); + + if (analyzer->CanProve(step_val < 1)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "SliceScatter op requires the step (" << step_val << ") to be >= 1."); + } + + if (analyzer->CanProve(stop_val < start_val)) { + ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter op requires start (" << start_val + << ") <= end (" << stop_val << ")."); + } + + int axis = NormalizeAxis(call, ctx, ndim, attrs->axis); + + const auto* data_shape_node = data_sinfo->shape.as(); + const auto* src_shape_node = src_sinfo->shape.as(); + + if (data_shape_node && src_shape_node && !src_sinfo->IsUnknownNdim()) { + ICHECK_EQ(data_shape_node->values.size(), static_cast(ndim)) + << "Internal error: data_shape_node rank mismatch with data_sinfo->ndim for call " << call; + ICHECK_EQ(src_shape_node->values.size(), static_cast(src_sinfo->ndim)) + << "Internal error: src_shape_node rank mismatch with src_sinfo->ndim for call " << call; + + PrimExpr num_elem = tvm::floordiv((stop_val - start_val + step_val - PrimExpr(1)), step_val); + + for (int i = 0; i < ndim; i++) { + if (i != axis) { + if (analyzer->CanProve(data_shape_node->values[i] != src_shape_node->values[i])) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "SliceScatter op requires the data tensor to have the same shape as the " + "src tensor except at the scatter axis (" + << axis << "). Mismatch at dimension " << i << ". " + << "data shape: " << data_sinfo->GetShape().value() + << ", src shape: " << src_sinfo->GetShape().value()); + } + } + } + + if (analyzer->CanProve(src_shape_node->values[axis] != num_elem)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "SliceScatter op requires the src tensor's dimension at scatter axis (" + << axis << ") to match the number of elements in the slice. " + << "Actual src dimension at axis " << axis << ": " + << src_shape_node->values[axis] + << ", Expected elements in slice (num_elem): " << num_elem); + } + } + + if (data_sinfo->shape.defined()) { + return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice); + } + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.slice_scatter") + .set_attrs_type() + .set_num_inputs(5) + .add_argument("input", "Tensor", "The input tensor.") + .add_argument("src", "Tensor", "The source tensor to scatter.") + .add_argument("start", "PrimValue", "The starting index of the slice (inclusive).") + .add_argument("end", "PrimValue", "The ending index of the slice (exclusive).") + .add_argument("step", "PrimValue", "The step of the slice.") + .set_attr("FInferStructInfo", InferStructInfoSliceScatter) + .set_attr("FPurity", Bool(true)); + /* relax.one_hot */ TVM_REGISTER_NODE_TYPE(OneHotAttrs); Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) { @@ -2467,7 +2622,7 @@ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, i return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } // namespace relax -TVM_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot); +TVM_FFI_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot); StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 9f0c86abd2fb..cc15d5d4ab76 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -24,8 +24,8 @@ #ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_ #define TVM_RELAX_OP_TENSOR_MANIPULATE_H_ +#include #include -#include #include "../op_common.h" #include "tvm/relax/expr.h" @@ -41,7 +41,7 @@ Expr broadcast_to(Expr x, Expr shape); * \param tensors An Expr in Tuple type, containing the tensors to be concatenated, * or a list of tensors * \param axis The axis along which the tensors are concatenated. - * If it is `NullOpt`, the input tensor is required to be flattened before concatenation. + * If it is `std::nullopt`, the input tensor is required to be flattened before concatenation. * \return The concatenated tensor. */ Expr concat(Expr tensors, Optional axis); @@ -74,7 +74,7 @@ Expr flatten(Expr x); */ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional pad_value, Optional> axis_separators, - Optional> input_axis_separators = NullOpt); + Optional> input_axis_separators = std::nullopt); /*! * \brief Permutes the dimensions of an array. @@ -113,7 +113,7 @@ Expr split(Expr x, Variant> indices_or_sections, int axis) * \brief Squeeze axes in the array. * \param x The input data to the operator. * \param axis The set of axes to remove. - * If it is `NullOpt`, remove all axis of dimensions 1. + * If it is `std::nullopt`, remove all axis of dimensions 1. * If any specified axis has dimension that does not equal 1, it is an error. * \return The squeezed result. */ @@ -154,7 +154,7 @@ Expr collapse_sum_to(Expr data, Expr shape); * from the backward. By default, use the flattened input array, and return a flat output array. * \return The computed result. */ -Expr repeat(Expr data, int repeats, Optional axis = NullOpt); +Expr repeat(Expr data, int repeats, Optional axis = std::nullopt); /*! * \brief Construct an array by repeating data the number of times given by reps. @@ -273,6 +273,18 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re */ Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction); +/*! + * \brief Embeds the values of the src tensor into input at the given dimension. + * \param input The input tensor to be updated. + * \param src The tensor to embed into input. + * \param dim The dimension to insert the slice into. + * \param start The start index of where to insert the slice. + * \param end The end index of where to insert the slice. + * \param step The how many elements to skip in + * \return The computed result tensor with the same shape as `data`. + */ +Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step); + /*! * \brief Returns a one-hot tensor. * \param indices The indices to set to `on_value`. diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 0189ef96780d..78ba6fec34ac 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -44,7 +44,7 @@ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dty return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize); +TVM_FFI_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize); StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); @@ -128,7 +128,7 @@ Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_d return Call(op, {std::move(data), std::move(scale), std::move(zero_point)}, Attrs(attrs)); } -TVM_REGISTER_GLOBAL("relax.op.dequantize").set_body_typed(dequantize); +TVM_FFI_REGISTER_GLOBAL("relax.op.dequantize").set_body_typed(dequantize); StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 35ee4c486b1d..80bbf48fd4f9 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -43,7 +43,8 @@ Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indice Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.multinomial_from_uniform").set_body_typed(multinomial_from_uniform); +TVM_FFI_REGISTER_GLOBAL("relax.op.multinomial_from_uniform") + .set_body_typed(multinomial_from_uniform); StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) { CheckNumArguments(call, ctx); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 4df166215414..83e0e246b1bf 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -36,7 +36,7 @@ Expr where(Expr condition, Expr x1, Expr x2) { return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); +TVM_FFI_REGISTER_GLOBAL("relax.op.where").set_body_typed(where); StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); @@ -195,7 +195,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs(attrs)); \ } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index e2aef8005e78..e321b326d24e 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -46,7 +46,7 @@ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_i return call; } -TVM_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); +TVM_FFI_REGISTER_GLOBAL("relax.op.unique").set_body_typed(unique); StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = Downcast(call->args[0]->struct_info_); @@ -130,10 +130,10 @@ TVM_REGISTER_OP("relax.unique") "original input ended up in the returned unique list.") .add_argument("return_counts", "Tensor", "Whether to return an additional tensor with counts of each unique elements") - .add_argument( - "axis", "Tensor", - "The dimension to apply unique. If it is NullOpt, the unique values of the flattened input " - "are returned.") + .add_argument("axis", "Tensor", + "The dimension to apply unique. If it is std::nullopt, the unique values of the " + "flattened input " + "are returned.") .set_attr("FInferStructInfo", InferStructInfoUnique) .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", Bool(true)); @@ -144,7 +144,7 @@ Expr nonzero(Expr x) { return Call(op, {std::move(x)}); } -TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); +TVM_FFI_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 9f8545e9b3a2..1cd061084e1a 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -41,7 +41,7 @@ Expr sort(Expr data, int axis, bool descending) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.sort").set_body_typed(sort); +TVM_FFI_REGISTER_GLOBAL("relax.op.sort").set_body_typed(sort); StructInfo InferStructInfoSort(const Call& call, const BlockBuilder& ctx) { return GetUnaryInputTensorStructInfo(call, ctx); @@ -67,7 +67,7 @@ Expr argsort(Expr data, int axis, bool descending, DataType dtype) { return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.argsort").set_body_typed(argsort); +TVM_FFI_REGISTER_GLOBAL("relax.op.argsort").set_body_typed(argsort); StructInfo InferStructInfoArgsort(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); @@ -101,7 +101,7 @@ Expr topk(Expr data, int k, int axis, String ret_type, bool largest, DataType dt return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.topk").set_body_typed(topk); +TVM_FFI_REGISTER_GLOBAL("relax.op.topk").set_body_typed(topk); StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index e4765c8ddb3c..73e74578fc06 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -186,7 +186,7 @@ Expr cumprod(Expr data, Optional axis, Optional dtype, Bool e return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.cumprod").set_body_typed(cumprod); +TVM_FFI_REGISTER_GLOBAL("relax.op.cumprod").set_body_typed(cumprod); TVM_REGISTER_OP("relax.cumprod") .set_attrs_type() @@ -206,7 +206,7 @@ Expr cumsum(Expr data, Optional axis, Optional dtype, Bool ex return Call(op, {std::move(data)}, Attrs{attrs}, {}); } -TVM_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum); +TVM_FFI_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(cumsum); TVM_REGISTER_OP("relax.cumsum") .set_attrs_type() diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index c527ba363894..331562454efe 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -50,7 +50,7 @@ namespace relax { static const Op& op = Op::Get("relax." #OpName); \ return Call(op, {std::move(x)}, Attrs{attrs}, {}); \ } \ - TVM_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ + TVM_FFI_REGISTER_GLOBAL("relax.op." #OpName).set_body_typed(OpName); \ TVM_REGISTER_OP("relax." #OpName) \ .set_num_inputs(1) \ .add_argument("x", "Tensor", "The input data tensor") \ @@ -61,12 +61,10 @@ namespace relax { /*! * \brief Computes the maximum value of tensor elements over given axes. * \param x The input data tensor - * \param axis Axis or axes along which a max is performed. Being `NullOpt` means to max all the - * elements of the input tensor - * \param keepdims If this is set to True, the axes which are reduced are left in the result as - * dimensions with size one. With this option, the result will broadcast correctly against the - * input tensor. - * \return The result after reduction. + * \param axis Axis or axes along which a max is performed. Being `std::nullopt` means to max all + * the elements of the input tensor \param keepdims If this is set to True, the axes which are + * reduced are left in the result as dimensions with size one. With this option, the result will + * broadcast correctly against the input tensor. \return The result after reduction. */ Expr max(Expr x, Optional> axis, bool keepdims); @@ -98,8 +96,8 @@ Expr sum(Expr x, Optional> axis, bool keepdims); * \return The computed * result. */ -Expr cumprod(Expr data, Optional axis = NullOpt, Optional dtype = NullOpt, - Bool exclusive = Bool(false)); +Expr cumprod(Expr data, Optional axis = std::nullopt, + Optional dtype = std::nullopt, Bool exclusive = Bool(false)); /*! * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along @@ -113,8 +111,8 @@ Expr cumprod(Expr data, Optional axis = NullOpt, Optional dty * which the first element is not included. * \return The computed result. */ -Expr cumsum(Expr data, Optional axis = NullOpt, Optional dtype = NullOpt, - Bool exclusive = Bool(false)); +Expr cumsum(Expr data, Optional axis = std::nullopt, + Optional dtype = std::nullopt, Bool exclusive = Bool(false)); /*! \brief Computes the variance of tensor elements over given axes. */ Expr variance(Expr x, Optional> axis, bool keepdims); diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index cc265ad9a160..91a6e8d0ae04 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -143,7 +143,7 @@ Expr ewise_fma(Expr x1, Expr x2, Expr x3) { return Call(op, {x1, x2, x3}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); +TVM_FFI_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index f95eb721fc70..828a91dde21d 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -85,7 +85,7 @@ Expr clip(Expr x, Expr min, Expr max) { return Call(op, {std::move(x), std::move(min), std::move(max)}); } -TVM_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); +TVM_FFI_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); /***************** Check operators *****************/ diff --git a/src/relax/testing/transform.cc b/src/relax/testing/transform.cc index eed2329e3d3a..c4e41d5afc1f 100644 --- a/src/relax/testing/transform.cc +++ b/src/relax/testing/transform.cc @@ -35,7 +35,7 @@ tvm::transform::Pass ApplyEmptyCppMutator() { "relax.testing.ApplyEmptyCppMutator", {}); } -TVM_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator") .set_body_typed(ApplyEmptyCppMutator); } // namespace testing diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 101217c15463..cb44339f1969 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -77,7 +77,7 @@ class AppendLossMutator : private ExprMutator { loss_function_->params.end()); Expr new_body = this->VisitExpr(func->body); - return Function(new_params, new_body, NullOpt, func->is_pure, func->attrs); + return Function(new_params, new_body, std::nullopt, func->is_pure, func->attrs); } Expr VisitExpr_(const SeqExprNode* seq_expr) final { @@ -215,7 +215,7 @@ Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outpu /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); +TVM_FFI_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss); } // namespace transform diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h index f280308f9d51..1bfb20da3521 100644 --- a/src/relax/training/utils.h +++ b/src/relax/training/utils.h @@ -51,7 +51,7 @@ namespace transform { * \return The Pass. */ TVM_DLL Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outputs = 1, - Optional new_func_name = NullOpt); + Optional new_func_name = std::nullopt); } // namespace transform } // namespace relax diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 79749cb41693..46dc803018ea 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -106,7 +106,7 @@ std::tuple)>> Crea if (sinfo) { return sinfo->GetShape(); } else { - return NullOpt; + return std::nullopt; } }; @@ -122,15 +122,15 @@ std::tuple)>> Crea auto shape_c = opt_shape_c.value(); if (matches.count(pat_permuted_matmul_on_lhs)) { - expr_a = permute_dims(expr_a, NullOpt); - expr_b = permute_dims(expr_b, NullOpt); + expr_a = permute_dims(expr_a, std::nullopt); + expr_b = permute_dims(expr_b, std::nullopt); CHECK_EQ(shape_a.size(), 2); CHECK_EQ(shape_b.size(), 2); shape_a = {shape_a[1], shape_a[0]}; shape_b = {shape_b[1], shape_b[0]}; } else if (matches.count(pat_permuted_matmul_on_rhs)) { - expr_b = permute_dims(expr_b, NullOpt); - expr_c = permute_dims(expr_c, NullOpt); + expr_b = permute_dims(expr_b, std::nullopt); + expr_c = permute_dims(expr_c, std::nullopt); CHECK_EQ(shape_b.size(), 2); CHECK_EQ(shape_c.size(), 2); shape_b = {shape_b[1], shape_b[0]}; @@ -213,7 +213,7 @@ Pass AdjustMatmulOrder() { return CreateFunctionPass(pass_func, 1, "AdjustMatmulOrder", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AdjustMatmulOrder").set_body_typed(AdjustMatmulOrder); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AdjustMatmulOrder").set_body_typed(AdjustMatmulOrder); } // namespace transform } // namespace relax diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index c8a2cef400b6..763a009a24b2 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -201,7 +201,7 @@ Pass AllocateWorkspace() { return CreateModulePass(pass_func, 0, "AllocateWorkspace", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 61d7725e6c13..63521e4e8fe1 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -438,7 +438,7 @@ Pass AlterOpImpl(const Map& op_impl_map, /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl); } // namespace transform } // namespace relax diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc index 5127b3df0c26..e2b0fc2c2877 100644 --- a/src/relax/transform/annotate_tir_op_pattern.cc +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -47,7 +47,8 @@ Pass AnnotateTIROpPattern() { return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern") + .set_body_typed(AnnotateTIROpPattern); } // namespace transform diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc index 3593b22c10ab..cef74890806d 100644 --- a/src/relax/transform/attach_attr_layout_free_buffers.cc +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -105,7 +105,7 @@ Pass AttachAttrLayoutFreeBuffers() { return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers"); } -TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers") +TVM_FFI_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers") .set_body_typed(AttachAttrLayoutFreeBuffers); } // namespace transform } // namespace relax diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 97226491a809..905d2bcd838d 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -79,7 +79,7 @@ Pass AttachGlobalSymbol() { return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {}); } -TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); +TVM_FFI_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 6a871edeaa9d..2a5c6f525d50 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -196,7 +196,7 @@ IRModule BindParam(IRModule m, String func_name, Map bind_ return GetRef(new_module); } -TVM_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams); +TVM_FFI_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams); namespace transform { @@ -207,7 +207,7 @@ Pass BindParams(String func_name, Map params) { return CreateModulePass(pass_func, 0, "BindParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams); } // namespace transform diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 2df9ed1f01a3..49af21c10755 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -148,7 +148,7 @@ IRModule ModuleBindSymbolicVars(IRModule mod, Map binding_m } } // namespace -TVM_REGISTER_GLOBAL("relax.FunctionBindSymbolicVars").set_body_typed(FunctionBindSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.FunctionBindSymbolicVars").set_body_typed(FunctionBindSymbolicVars); namespace transform { @@ -170,7 +170,7 @@ Pass BindSymbolicVars(Map binding_map, Optional fun return tvm::transform::CreateModulePass(pass_func, 1, "relax.BindSymbolicVars", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BindSymbolicVars").set_body_typed(BindSymbolicVars); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BindSymbolicVars").set_body_typed(BindSymbolicVars); } // namespace transform } // namespace relax diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index a011841c1316..982e1ac0c323 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -115,7 +115,7 @@ Pass BundleModelParams(Optional param_tuple_name) { return CreateModulePass(pass_func, 1, "BundleModelParams", {}); } -TVM_REGISTER_GLOBAL("relax.transform.BundleModelParams").set_body_typed(BundleModelParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.BundleModelParams").set_body_typed(BundleModelParams); } // namespace transform } // namespace relax diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index e3bb2bcbae46..25b4abadc7ff 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -183,7 +183,7 @@ Pass CallTIRRewrite() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); +TVM_FFI_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); } // namespace transform diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 807914075e8d..ecbb9e77518e 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -122,7 +122,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { if (auto it = known_values_.find(var); it != known_values_.end()) { return it->second.expr; } else { - return NullOpt; + return std::nullopt; } }); if (output.same_as(expr)) { @@ -301,22 +301,22 @@ class CanonicalizePlanner : public ExprVisitor { auto earlier_tuple = [&]() -> Optional { auto expr_tuple = expr.as(); if (!expr_tuple) { - return NullOpt; + return std::nullopt; } if (expr_tuple->fields.empty()) { - return NullOpt; + return std::nullopt; } auto first_element = recursively_unwrap_var(expr_tuple->fields[0]).as(); if (!first_element) { - return NullOpt; + return std::nullopt; } auto earlier_tuple_size = Downcast(GetStructInfo(first_element->tuple))->fields.size(); if (earlier_tuple_size != expr_tuple->fields.size()) { - return NullOpt; + return std::nullopt; } Expr earlier_tuple = recursively_unwrap_var(first_element->tuple); @@ -324,16 +324,16 @@ class CanonicalizePlanner : public ExprVisitor { for (size_t i = 0; i < expr_tuple->fields.size(); i++) { auto element = recursively_unwrap_var(expr_tuple->fields[i]).as(); if (!element) { - return NullOpt; + return std::nullopt; } if (static_cast(element->index) != i) { - return NullOpt; + return std::nullopt; } auto source_of_element = recursively_unwrap_var(element->tuple); if (!earlier_tuple.same_as(source_of_element)) { - return NullOpt; + return std::nullopt; } } @@ -343,7 +343,7 @@ class CanonicalizePlanner : public ExprVisitor { return earlier_tuple.value(); } - return NullOpt; + return std::nullopt; } void VisitBinding(const Binding& binding) override { @@ -591,7 +591,8 @@ Pass CanonicalizeBindings() { "CanonicalizeBindings"); } -TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings); +TVM_FFI_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings") + .set_body_typed(CanonicalizeBindings); } // namespace transform diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index cc55eaff0721..620186320342 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -38,8 +38,6 @@ namespace tvm { namespace relax { -using runtime::Map; - using FCheck = ffi::TypedFunction, Array, Map)>; /*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose batch sizes @@ -160,7 +158,7 @@ ffi::TypedFunction(Map, Map)> GetRewri std::vector splits; for (auto index : indices) { Var rhs = matchings[patterns.rhs[index]]; - Optional bias = NullOpt; + Optional bias = std::nullopt; if (branch_info.bias_dim.has_value()) { bias = matchings[patterns.bias[index]]; } @@ -389,7 +387,8 @@ Pass CombineParallelMatmul(FCheck check) { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul").set_body_typed(CombineParallelMatmul); +TVM_FFI_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul") + .set_body_typed(CombineParallelMatmul); } // namespace transform diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index 25a1d1b0ede6..e6db2eb73f3a 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -86,7 +86,7 @@ Pass ComputePrimValue() { return CreateModulePass(pass_func, 0, "ComputePrimValue", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue); } // namespace transform diff --git a/src/relax/transform/convert_dataflow.cc b/src/relax/transform/convert_dataflow.cc index 5fb3683c40a2..c359afdebc28 100644 --- a/src/relax/transform/convert_dataflow.cc +++ b/src/relax/transform/convert_dataflow.cc @@ -91,7 +91,7 @@ class DataflowBlockExtractor : public ExprMutator { } dataflow_bindings = {}; - input_dataflow_block = NullOpt; + input_dataflow_block = std::nullopt; }; for (auto block : seq->blocks) { @@ -159,7 +159,7 @@ Pass ConvertToDataflow(int min_size) { return tvm::transform::Sequential({pass, CanonicalizeBindings()}); } -TVM_REGISTER_GLOBAL("relax.transform.ConvertToDataflow").set_body_typed(ConvertToDataflow); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ConvertToDataflow").set_body_typed(ConvertToDataflow); } // namespace transform diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index b41ad9ea29c4..0c06cac75d19 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -200,7 +200,7 @@ class LayoutConvertMutator : public ExprMutator { const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { const OpNode* op_node = call_node->op.as(); - if (op_node == nullptr) return NullOpt; + if (op_node == nullptr) return std::nullopt; Op op = Downcast(GetRef(op_node)); const auto attr_map = Op::GetAttrMap("FRelaxInferLayout"); if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) { @@ -209,7 +209,7 @@ class LayoutConvertMutator : public ExprMutator { return f(GetRef(call_node), desired_layouts, var_layout_map); } else { // Otherwise, we use the default policy. - return NullOpt; + return std::nullopt; } } @@ -217,7 +217,7 @@ class LayoutConvertMutator : public ExprMutator { Optional res = GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_); ObjectPtr new_call = make_object(*call_node); - new_call->struct_info_ = NullOpt; + new_call->struct_info_ = std::nullopt; if (!res.defined() || (!IsNestedTensor(binding->var) && !binding->var->IsInstance())) { // Default policy: use the initial layout. @@ -350,7 +350,7 @@ Pass ConvertLayout(Map> desired_layouts) { return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ConvertLayout").set_body_typed(ConvertLayout); } // namespace transform } // namespace relax diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index aee2c015fc81..51ab6bb23068 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -1016,13 +1016,13 @@ Array> DataflowInplaceAnalysis(const DataflowBlock& bl } // these are exposed only for testing -TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowLivenessAnalysis") .set_body_typed(DataflowLivenessAnalysis); -TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowAliasAnalysis") .set_body_typed(DataflowAliasAnalysis); -TVM_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.DataflowInplaceAnalysis") .set_body_typed(DataflowInplaceAnalysis); -TVM_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") +TVM_FFI_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") .set_body_typed([](const IRModule& mod, const Call& call, const Array& inplace_indices) -> Array { ModuleInplaceTransformer transformer(mod); @@ -1031,7 +1031,7 @@ TVM_REGISTER_GLOBAL("relax.testing.transform.SingleInplaceCall") }); // actually exposed -TVM_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") +TVM_FFI_REGISTER_GLOBAL("relax.transform.DataflowUseInplaceCalls") .set_body_typed(DataflowUseInplaceCalls); } // namespace transform diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index f1dc3908fc8e..7de1da329f88 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -90,7 +90,7 @@ IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set return mod; } -IRModule DeadCodeElimination(const IRModule& arg_mod, Array entry_function_names) { +IRModule DeadCodeElimination(const IRModule& arg_mod, Array entry_function_names) { IRModule mod = arg_mod; // S0: Make a list of all user-specified entry functions and @@ -133,14 +133,14 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array ent namespace transform { -Pass DeadCodeElimination(Array entry_functions) { +Pass DeadCodeElimination(Array entry_functions) { auto pass_func = [=](IRModule m, PassContext pc) { return relax::DeadCodeElimination(m, entry_functions); }; return CreateModulePass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); +TVM_FFI_REGISTER_GLOBAL("relax.transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); } // namespace transform } // namespace relax diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 1a4cd216256b..eec27f3b7888 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -250,10 +250,10 @@ Pass DecomposeOpsForTraining(Optional func_name) { } } -TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference") +TVM_FFI_REGISTER_GLOBAL("relax.transform.DecomposeOpsForInference") .set_body_typed(DecomposeOpsForInference); -TVM_REGISTER_GLOBAL("relax.transform.DecomposeOpsForTraining") +TVM_FFI_REGISTER_GLOBAL("relax.transform.DecomposeOpsForTraining") .set_body_typed(DecomposeOpsForTraining); } // namespace transform diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 59960b47d73c..8a5ce1db04de 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -47,7 +47,7 @@ namespace { */ struct ReplacementKey { tvm::relax::Expr bound_value; - tvm::Optional match_cast = tvm::NullOpt; + tvm::Optional match_cast = std::nullopt; explicit ReplacementKey(const tvm::relax::Binding& binding) : bound_value(GetBoundValue(binding)) { @@ -221,7 +221,7 @@ Pass EliminateCommonSubexpr(bool call_only) { return CreateFunctionPass(pass_func, 1, "EliminateCommonSubexpr", {}); } -TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") +TVM_FFI_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") .set_body_typed(EliminateCommonSubexpr); } // namespace transform diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 134eca557264..d7bf2dd95ffb 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -104,7 +104,7 @@ Pass ExpandMatmulOfSum() { return CreateFunctionPass(pass_func, 1, "ExpandMatmulOfSum", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ExpandMatmulOfSum").set_body_typed(ExpandMatmulOfSum); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ExpandMatmulOfSum").set_body_typed(ExpandMatmulOfSum); } // namespace transform } // namespace relax diff --git a/src/relax/transform/expand_tuple_arguments.cc b/src/relax/transform/expand_tuple_arguments.cc index 550409f82800..1a9afadf7e48 100644 --- a/src/relax/transform/expand_tuple_arguments.cc +++ b/src/relax/transform/expand_tuple_arguments.cc @@ -33,13 +33,13 @@ using PMap = std::unordered_map; Optional ExpandParams(Function func) { bool is_exposed = func->attrs.GetAttr(tvm::attr::kGlobalSymbol).defined(); - if (is_exposed) return NullOpt; + if (is_exposed) return std::nullopt; bool has_tuple_param = std::any_of( func->params.begin(), func->params.end(), [](const Var& param) -> bool { return param->struct_info_.as(); }); - if (!has_tuple_param) return NullOpt; + if (!has_tuple_param) return std::nullopt; Array params; Array bindings; @@ -178,7 +178,8 @@ Pass ExpandTupleArguments() { "ExpandTupleArguments"); } -TVM_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments").set_body_typed(ExpandTupleArguments); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ExpandTupleArguments") + .set_body_typed(ExpandTupleArguments); } // namespace transform diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 3df818cf3cea..4ccf6c25abc8 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -55,11 +55,11 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& /*target=*/target, /*space_generator=*/ meta_schedule::SpaceGenerator::PostOrderApply(/*f_block_filter=*/nullptr, - /*sch_rules=*/NullOpt, - /*postprocs=*/NullOpt, - /*mutator_probs=*/NullOpt), + /*sch_rules=*/std::nullopt, + /*postprocs=*/std::nullopt, + /*mutator_probs=*/std::nullopt), /*search_strategy=*/meta_schedule::SearchStrategy::ReplayTrace(/*max_fail_count=*/100), - /*task_name=*/NullOpt, + /*task_name=*/std::nullopt, /*num_threads=*/num_threads, // use all available local threads /*rand_state=*/-1, // -1 means use random seed /*logger=*/nullptr); @@ -67,8 +67,8 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& task->search_strategy.value()->PreTuning( /*max_trials=*/valid_count, /*num_trials_per_iter=*/valid_count, /*design_spaces=*/task->space_generator.value()->GenerateDesignSpace(mod), - /*database=*/NullOpt, - /*cost_model=*/NullOpt); + /*database=*/std::nullopt, + /*cost_model=*/std::nullopt); int fail_count = 0, max_fail_count = 100; while (valid_count > 0 && fail_count < max_fail_count) { Optional> candidates = @@ -172,7 +172,7 @@ Pass FewShotTuning(int valid_count, bool benchmark) { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.FewShotTuning").set_body_typed(FewShotTuning); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FewShotTuning").set_body_typed(FewShotTuning); } // namespace transform } // namespace relax diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 2cce9c8d7c26..7fec51086514 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -45,14 +46,14 @@ class ConstantFolder : public ExprMutator { * constant shape and get runtime shape tuple from it. * \param struct_info The given struct info whose shape inside is to be casted. * \return The runtime shape tuple, or nullopt if it is not a constant shape. - * \note Only TensorStructInfo is supported at this moment. Return NullOpt + * \note Only TensorStructInfo is supported at this moment. Return std::nullopt * if the input struct info is not TensorStructInfo. */ - static Optional MatchConstShape(const StructInfo& struct_info) { + static Optional MatchConstShape(const StructInfo& struct_info) { // Only support single output for call_tir at this moment. const auto* tensor_sinfo = struct_info.as(); if (tensor_sinfo == nullptr) { - return NullOpt; + return std::nullopt; } const auto* shape = tensor_sinfo->shape.as(); @@ -61,10 +62,10 @@ class ConstantFolder : public ExprMutator { std::vector shape_values; for (const auto v : shape->values) { auto* ptr = v.as(); - if (!ptr) return NullOpt; + if (!ptr) return std::nullopt; shape_values.push_back(ptr->value); } - return runtime::ShapeTuple(shape_values.begin(), shape_values.end()); + return ffi::Shape(shape_values.begin(), shape_values.end()); } /*! @@ -75,7 +76,7 @@ class ConstantFolder : public ExprMutator { Array res; for (auto arg : args) { auto* ptr = arg.as(); - if (!ptr) return NullOpt; + if (!ptr) return std::nullopt; res.push_back(ptr->data); } return res; @@ -92,7 +93,7 @@ class ConstantFolder : public ExprMutator { if (auto* pfunc = base_func.as()) { return GetRef(pfunc); } - return NullOpt; + return std::nullopt; } /*! @@ -108,7 +109,7 @@ class ConstantFolder : public ExprMutator { if (it != func_build_cache_.end()) { return it->second; } - Optional build_func = NullOpt; + Optional build_func = std::nullopt; try { // Not all the primfunc can be directly built via llvm, for example, if a function is @@ -141,12 +142,12 @@ class ConstantFolder : public ExprMutator { } // Try constant evaluate the function call - // if failed return NullOpt + // if failed return std::nullopt Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, Array arr_args, - runtime::ShapeTuple shape, DataType ret_type) { + ffi::Shape shape, DataType ret_type) { // obtain function from the cache. Optional func = GetCachedBuild(tir_func); - if (!func) return NullOpt; + if (!func) return std::nullopt; // here the vector size has an additional + 1 because we need to put ret_tensor at the end std::vector packed_args(arr_args.size() + 1); @@ -180,7 +181,7 @@ class ConstantFolder : public ExprMutator { Optional> arr_args = MatchConstArrayArgs(call->args[1].as()->fields); ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; - Optional shape = MatchConstShape(call->sinfo_args[0]); + Optional shape = MatchConstShape(call->sinfo_args[0]); bool output_not_tuple = call->sinfo_args.size() == 1; // Pattern 0: call constant function, const argument with const shape. if (func && arr_args && shape && output_not_tuple) { @@ -326,7 +327,7 @@ Pass FoldConstant() { return CreateFunctionPass(pass_func, 0, "FoldConstant", {}); } -TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant); } // namespace transform diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 4819cefb9ac3..f9ffcd930283 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -537,26 +537,26 @@ class FunctionCreator : public ExprMutator { // If the result is not used outside LOG(WARNING) << "There are dead codes in the current IRModule, please run the " "DeadCodeElimination Pass before FuseOps"; - function_ = NullOpt; + function_ = std::nullopt; } else { Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs); body = builder_->Normalize(body); body = builder_->Normalize(SeqExpr({new_block}, body)); group_attrs.Set(tvm::relax::attr::kPrimitive, true); - Function function = Function(/*params=*/params_, // - /*body=*/body, // - /*ret_struct_info=*/NullOpt, // - /*is_pure=*/true, // + Function function = Function(/*params=*/params_, // + /*body=*/body, // + /*ret_struct_info=*/std::nullopt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); Array free_vars = FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); arguments_.push_back(ShapeExpr(free_vars)); - function = Function(/*params=*/params_, // - /*body=*/body, // - /*ret_struct_info=*/NullOpt, // - /*is_pure=*/true, // + function = Function(/*params=*/params_, // + /*body=*/body, // + /*ret_struct_info=*/std::nullopt, // + /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); } function_ = SymbolicVarRenewMutator::Renew(function); @@ -572,7 +572,7 @@ class FunctionCreator : public ExprMutator { /*! \brief The name for the fused function */ String name_hint_ = "fused"; /*! \brief The constructed Relax function */ - Optional function_ = NullOpt; + Optional function_ = std::nullopt; private: std::optional GetOutputIndex(Var v) { @@ -1395,7 +1395,7 @@ FusionPattern::FusionPattern(String name, DFPattern pattern, } TVM_REGISTER_NODE_TYPE(FusionPatternNode); -TVM_REGISTER_GLOBAL("relax.transform.FusionPattern") +TVM_FFI_REGISTER_GLOBAL("relax.transform.FusionPattern") .set_body_typed([](String name, DFPattern pattern, Map annotation_patterns, Optional check, Optional attrs_getter) { return FusionPattern(name, pattern, annotation_patterns, check, attrs_getter); @@ -1429,7 +1429,7 @@ Pass FuseOps(int fuse_opt_level) { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps); Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants, bool annotate_codegen, const Array& entry_function_names) { @@ -1444,7 +1444,7 @@ Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_const /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern); } // namespace transform diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 4f4a3ec2c73a..05b7bf4218dd 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -900,7 +900,7 @@ class FusedTIRConstructor : public ExprVisitor { */ static void CollectPrimFuncParams(const Var& relax_param, std::vector>* out, - const tvm::runtime::Optional& tir_buffer_param) { + const Optional& tir_buffer_param) { auto struct_info = GetStructInfo(relax_param); CHECK(!struct_info.as()) @@ -963,7 +963,7 @@ class FusedTIRConstructor : public ExprVisitor { tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies)); body = subst.Substitute(body); - body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers); + body = tir::Block({}, {}, {}, "root", std::move(body), std::nullopt, alloc_buffers); body = tir::BlockRealize({}, Bool(true), Downcast(body)); tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, DictAttrs(attr_map)); @@ -1262,7 +1262,7 @@ Pass FuseTIR() { "FuseTIR"); } -TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR); +TVM_FFI_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR); } // namespace transform diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index fa96f31cf5a8..9998b6da93f3 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -668,7 +668,7 @@ class GradientMutator : private ExprMutator { orig_params_ = func->params; Expr new_body = this->VisitExpr(func->body); - return Function(func->params, new_body, NullOpt, func->is_pure, func->attrs); + return Function(func->params, new_body, std::nullopt, func->is_pure, func->attrs); } Expr VisitExpr_(const SeqExprNode* seq_expr) final { @@ -787,7 +787,7 @@ Pass Gradient(String func_name, Optional> require_grads, int target_i /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.Gradient").set_body_typed(Gradient); +TVM_FFI_REGISTER_GLOBAL("relax.transform.Gradient").set_body_typed(Gradient); } // namespace transform diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index 3440afb7211d..a3a86dd2e0c3 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -38,9 +38,8 @@ namespace tvm { namespace relax { -using runtime::DLDataTypeToString; -using runtime::String; -using runtime::StringToDLDataType; +using ffi::DLDataTypeToString; +using ffi::StringToDLDataType; enum MixedPrecisionPolicyKind : int { kAlways = 0, kFollow = 1, kNever = 2 }; diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 457d6d9ead46..26b106373ff0 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -85,7 +85,7 @@ class FunctionInliner : public ExprMutator { } else if (auto opt = replacements_.Get(gvar->name_hint)) { return opt; } else { - return NullOpt; + return std::nullopt; } } @@ -164,7 +164,7 @@ Function FunctionInlineFunctions(Function func, return Downcast(mutator(std::move(func))); } -TVM_REGISTER_GLOBAL("relax.FunctionInlineFunctions").set_body_typed(FunctionInlineFunctions); +TVM_FFI_REGISTER_GLOBAL("relax.FunctionInlineFunctions").set_body_typed(FunctionInlineFunctions); namespace transform { @@ -219,7 +219,7 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "InlinePrivateFunctions", {}); } -TVM_REGISTER_GLOBAL("relax.transform.InlinePrivateFunctions") +TVM_FFI_REGISTER_GLOBAL("relax.transform.InlinePrivateFunctions") .set_body_typed(InlinePrivateFunctions); } // namespace transform diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 20ec5eb4348f..730f65f701ba 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -265,7 +265,7 @@ Pass KillAfterLastUse() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "KillAfterLastUse", {}); } -TVM_REGISTER_GLOBAL("relax.transform.KillAfterLastUse").set_body_typed(KillAfterLastUse); +TVM_FFI_REGISTER_GLOBAL("relax.transform.KillAfterLastUse").set_body_typed(KillAfterLastUse); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index ee19872e5af5..e5e28cb55375 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -166,7 +166,7 @@ class LambdaNameCollector : ExprVisitor { if (auto it = lifted_with_global_symbol_.find(func); it != lifted_with_global_symbol_.end()) { return it->second; } else { - return NullOpt; + return std::nullopt; } }); @@ -181,7 +181,7 @@ class LambdaNameCollector : ExprVisitor { // 3. Try concatenating the entire path together. Don't include // paths of length 2, as they would already be attempted earlier. attempt_name_generation([&](const FunctionNode*, const auto& location) -> Optional { - if (location.size() == 2) return NullOpt; + if (location.size() == 2) return std::nullopt; std::stringstream stream; bool is_first = true; @@ -485,7 +485,7 @@ class LambdaLifter : public ExprMutator { std::unordered_map nested_closure_map_; std::unordered_map rebind_map_; std::unordered_set, ObjectPtrHash, ObjectPtrEqual> closures_; - Optional current_lambda_var_ = NullOpt; + Optional current_lambda_var_ = std::nullopt; IRModule mod_; std::unordered_map lifted_names_; @@ -503,7 +503,7 @@ Pass LambdaLift() { return tvm::transform::CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 66aa35887f08..32f63e1e141b 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -80,7 +80,7 @@ class LazyInputMutator : public ExprMutator { if (externally_visible_vars.count(var)) { return var; } else { - return NullOpt; + return std::nullopt; } }); @@ -259,7 +259,7 @@ Pass LazyGetInput() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput); Pass LazySetOutput() { auto pass_func = [](Function func, IRModule, PassContext) -> Function { @@ -274,7 +274,7 @@ Pass LazySetOutput() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput); } // namespace transform } // namespace relax diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index b66132154aca..a0ac6fffb62c 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -149,7 +149,7 @@ class LegalizeMutator : public ExprMutator { return GetTarget(tup_sinfo->fields); } } - return NullOpt; + return std::nullopt; } Expr BindTarget(Expr expr) { @@ -404,7 +404,7 @@ Pass LegalizeOps(Optional> cmap, bool enable_warning) /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.LegalizeOps").set_body_typed(LegalizeOps); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LegalizeOps").set_body_typed(LegalizeOps); } // namespace transform diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 180c0be8910d..9013737df5e4 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -703,7 +703,8 @@ std::vector> GetTargetFunctions( const IRModule& mod, const Variant>& shared_transform) { std::vector> target_functions; if (shared_transform.as>().value_or(Array{}).size()) { - for (const auto& name : shared_transform.as>().value()) { + auto names = shared_transform.as>().value(); + for (const auto& name : names) { auto gvar = mod->global_var_map_.Get(name); CHECK(gvar) << "When LiftTransformParams is called with a list of function names, " << "all function names must occur within the IRModule. " @@ -866,7 +867,7 @@ Pass LiftTransformParams(Variant> shared_transform) { "LiftTransformParams"); } -TVM_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LiftTransformParams").set_body_typed(LiftTransformParams); } // namespace transform } // namespace relax diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 13705c0908cc..3bdbfd0b94a9 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -99,7 +99,7 @@ Pass LowerAllocTensor() { return CreateFunctionPass(pass_func, /*opt_level=*/0, "LowerAllocTensor", {}); } -TVM_REGISTER_GLOBAL("relax.transform.LowerAllocTensor").set_body_typed(LowerAllocTensor); +TVM_FFI_REGISTER_GLOBAL("relax.transform.LowerAllocTensor").set_body_typed(LowerAllocTensor); } // namespace transform } // namespace relax diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index d814629026d5..ffeddd08c401 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -168,13 +168,13 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { Optional GetCodegenName(const Expr& callee) { auto const* gvar = callee.as(); if (!gvar) { - return NullOpt; + return std::nullopt; } auto composite_name_opt = mod_->Lookup(GetRef(gvar))->GetAttr(attr::kComposite); if (!composite_name_opt) { - return NullOpt; + return std::nullopt; } return relax::GetCodegenName(composite_name_opt.value()); @@ -421,7 +421,7 @@ Pass MergeCompositeFunctions() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MergeCompositeFunctions") .set_body_typed(MergeCompositeFunctions); } // namespace transform diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 875d26ea47a1..cf7b9fc03a50 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -173,8 +173,8 @@ Pass MetaScheduleApplyDatabase(Optional work_dir, bool enable_warning = Pass MetaScheduleTuneIRMod(Map params, String work_dir, Integer max_trials_global, - Optional max_trials_per_task = NullOpt, - Optional> op_names = NullOpt) { + Optional max_trials_per_task = std::nullopt, + Optional> op_names = std::nullopt) { Target target = Target::Current(false); auto pass_func = [=](IRModule m, PassContext ctx) { auto max_trials_task = max_trials_per_task.value_or(max_trials_global); @@ -191,7 +191,8 @@ Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { Target target = Target::Current(false); ffi::TypedFunction pass_func = [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { - return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_global, NullOpt) + return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_global, + std::nullopt) .TuneTIR(f, ctx); }; return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, @@ -200,10 +201,11 @@ Pass MetaScheduleTuneTIR(String work_dir, Integer max_trials_global) { /*traceable*/ true); } -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") +TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") .set_body_typed(MetaScheduleApplyDatabase); -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod").set_body_typed(MetaScheduleTuneIRMod); -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); +TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod") + .set_body_typed(MetaScheduleTuneIRMod); +TVM_FFI_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); } // namespace transform } // namespace relax } // namespace tvm diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index e81cd16207e9..07ca6a1133e7 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -35,7 +35,7 @@ namespace relax { // TODO(@altanh): LCA binding lifting class NormalizeMutator : public ExprMutatorBase { public: - NormalizeMutator() { builder_ = BlockBuilder::Create(NullOpt); } + NormalizeMutator() { builder_ = BlockBuilder::Create(std::nullopt); } Expr VisitExpr(const Expr& expr) override { return builder_->Normalize(ExprMutatorBase::VisitExpr(expr)); @@ -63,7 +63,7 @@ class NormalizeMutator : public ExprMutatorBase { } } - Expr VisitWithNewScope(const Expr& expr, Optional> params = NullOpt) { + Expr VisitWithNewScope(const Expr& expr, Optional> params = std::nullopt) { builder_->BeginBindingBlock(); if (params.defined()) { builder_->BeginScope(params); @@ -279,7 +279,7 @@ Pass Normalize() { return CreateFunctionPass(pass_func, 1, "Normalize", {}); } -TVM_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); +TVM_FFI_REGISTER_GLOBAL("relax.transform.Normalize").set_body_typed(Normalize); Pass NormalizeGlobalVar() { auto pass_func = [=](IRModule mod, PassContext pc) { @@ -290,7 +290,7 @@ Pass NormalizeGlobalVar() { /*pass_name=*/"NormalizeGlobalVar", /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.NormalizeGlobalVar").set_body_typed(NormalizeGlobalVar); +TVM_FFI_REGISTER_GLOBAL("relax.transform.NormalizeGlobalVar").set_body_typed(NormalizeGlobalVar); } // namespace transform diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 9e6ebbbc2632..ee4773fb3a24 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -76,7 +76,7 @@ class VDeviceLookup { } private: - Optional> opt_vdevices_ = NullOpt; + Optional> opt_vdevices_ = std::nullopt; }; class DeviceHintCollector : ExprVisitor { @@ -183,7 +183,7 @@ class DeviceHintCollector : ExprVisitor { return bound.value(); } } - return NullOpt; + return std::nullopt; } // A lookup to identify the VDevice from the IRModule attributes, @@ -254,7 +254,7 @@ class VDeviceSetCollector : ExprVisitor { } } - Optional current_binding_ = NullOpt; + Optional current_binding_ = std::nullopt; // Lookup from relax variable to the set of relax variables which // must be located on the same device. For example, a trivial @@ -415,7 +415,7 @@ Pass RealizeVDevice() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.RealizeVDevice").set_body_typed(RealizeVDevice); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RealizeVDevice").set_body_typed(RealizeVDevice); } // namespace transform } // namespace relax diff --git a/src/relax/transform/remove_purity_checking.cc b/src/relax/transform/remove_purity_checking.cc index a88f5e0f5629..31e771d2adec 100644 --- a/src/relax/transform/remove_purity_checking.cc +++ b/src/relax/transform/remove_purity_checking.cc @@ -88,7 +88,8 @@ Pass RemovePurityChecking() { return CreateFunctionPass(pass_func, 0, "RemovePurityChecking", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RemovePurityChecking").set_body_typed(RemovePurityChecking); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RemovePurityChecking") + .set_body_typed(RemovePurityChecking); } // namespace transform diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index aafd1941ec90..e170588f60c6 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -143,7 +143,7 @@ class PartialTupleUsageCollector : ExprVisitor { return known_binding.value(); } } - return NullOpt; + return std::nullopt; }; while (auto unwrapped = get_bound_value(expr)) { @@ -336,7 +336,7 @@ Pass RemoveUnusedOutputs() { "RemoveUnusedOutputs"); } -TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedOutputs").set_body_typed(RemoveUnusedOutputs); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RemoveUnusedOutputs").set_body_typed(RemoveUnusedOutputs); } // namespace transform diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index bc7fa325ccc7..911e427935be 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -250,7 +250,7 @@ Pass RemoveUnusedParameters() { return CreateModulePass(pass_func, 0, "RemoveUnusedParameters", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters") +TVM_FFI_REGISTER_GLOBAL("relax.transform.RemoveUnusedParameters") .set_body_typed(RemoveUnusedParameters); } // namespace transform diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index a2023a068aa2..2016c6766c08 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -173,7 +173,7 @@ Pass ReorderPermuteDimsAfterConcat() { return CreateFunctionPass(pass_func, 1, "ReorderPermuteDimsAfterConcat", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat") +TVM_FFI_REGISTER_GLOBAL("relax.transform.ReorderPermuteDimsAfterConcat") .set_body_typed(ReorderPermuteDimsAfterConcat); } // namespace transform diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index 28480a2296f3..4c87cbe8b7e3 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -156,7 +156,7 @@ Pass ReorderTakeAfterMatmul() { return CreateFunctionPass(pass_func, 1, "ReorderTakeAfterMatmul", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul") +TVM_FFI_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul") .set_body_typed(ReorderTakeAfterMatmul); } // namespace transform diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index aebf722bcfd2..14e98ecad152 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -88,7 +88,7 @@ struct LiftedFunctionRewritePlan { // The corresponding binding vars in the original function of the inputs of the lifted function std::vector inputs; // The tir vars in the original function that are propagated to the lifted function - Optional propogated_tir_vars = NullOpt; + Optional propogated_tir_vars = std::nullopt; }; /*! \brief Builder of the lifted function for cuda graph capturing or allocations */ @@ -123,7 +123,7 @@ class FuncBuilder : public ExprMutator { /*! \brief Build the new function */ Function Build() { Array params; - Optional shape_expr = NullOpt; + Optional shape_expr = std::nullopt; if (shape_expr_inputs_.size()) { Array tir_vars; for (const auto* var : shape_expr_inputs_) { @@ -871,8 +871,8 @@ class CUDAGraphRewriter : public ExprMutator { int index_alloc_ = 0; int index_capture_ = 0; support::Arena arena_; - Optional gv_global_alloc_ = NullOpt; - Optional current_func_ = NullOpt; + Optional gv_global_alloc_ = std::nullopt; + Optional current_func_ = std::nullopt; }; IRModule RewriteCUDAGraph(IRModule mod) { @@ -897,7 +897,7 @@ Pass RewriteCUDAGraph() { return CreateModulePass(pass_func, 0, "RewriteCUDAGraph", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RewriteCUDAGraph").set_body_typed(RewriteCUDAGraph); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RewriteCUDAGraph").set_body_typed(RewriteCUDAGraph); } // namespace transform diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 690f1b723279..a13c23387821 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -165,7 +165,7 @@ Pass RewriteDataflowReshape() { return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") +TVM_FFI_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") .set_body_typed(RewriteDataflowReshape); } // namespace transform diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 3940ecd70bd5..d29bdaacb9b0 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "../../support/ordered_set.h" #include "utils.h" @@ -219,7 +220,7 @@ Pass RunCodegen(Optional>> target_options, return CreateModulePass(pass_func, 0, "RunCodegen", {}); } -TVM_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); +TVM_FFI_REGISTER_GLOBAL("relax.transform.RunCodegen").set_body_typed(RunCodegen); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 14d1d7d2fb92..276ba448cc4b 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -88,7 +88,7 @@ class ForMatcher : public TensorizeComparator { return it->second; } } - return NullOpt; + return std::nullopt; } bool VisitExpr(const PrimExpr& lhs, const PrimExpr& rhs) final { @@ -568,7 +568,7 @@ std::pair> SplitFunctions(PrimFunc func, Array match_results = TIRPatternMatcher::Match(patterns, func->body.as()->block->body); if (match_results.empty()) { - return {func, NullOpt}; + return {func, std::nullopt}; } Array codegen_result = f_codegen(match_results); ICHECK(codegen_result.size() == 3); @@ -576,12 +576,12 @@ std::pair> SplitFunctions(PrimFunc func, int num_matched_ops = Downcast(codegen_result[1])->value; Array func1_args = Downcast>(codegen_result[2]); if (num_matched_ops == 0) { - return {func, NullOpt}; + return {func, std::nullopt}; } FunctionPartitioner partitioner(num_matched_ops); partitioner(body); if (partitioner.fail) { - return {func, NullOpt}; + return {func, std::nullopt}; } bool has_second_func = false; for (const auto& pr : partitioner.block_partition) { @@ -592,7 +592,7 @@ std::pair> SplitFunctions(PrimFunc func, } if (!has_second_func) { // No need to split the function. - return {WithAttr(func, kLibraryKernel, library_code), NullOpt}; + return {WithAttr(func, kLibraryKernel, library_code), std::nullopt}; } // Step 2. Split the function into two functions. Stmt body1 = BlockRemover::RemoveBlockByPartition(func->body, partitioner.block_partition, @@ -660,7 +660,7 @@ void StringReplace(std::string* subject, const std::string& search, const std::s tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, String global_symbol) { using namespace tvm::tir; - Optional library_code = pf->attrs.GetAttr(kLibraryKernel); + Optional library_code = pf->attrs.GetAttr(kLibraryKernel); if (!library_code.defined()) { return GetRef(pf); } @@ -774,7 +774,8 @@ Pass SplitCallTIRByPattern(Array patterns, FCodegen fcodegen) { /*pass_name=*/"SplitCallTIRByPattern", // /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern").set_body_typed(SplitCallTIRByPattern); +TVM_FFI_REGISTER_GLOBAL("relax.transform.SplitCallTIRByPattern") + .set_body_typed(SplitCallTIRByPattern); } // namespace transform diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 41213859f8ce..7990beb04b2e 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -41,7 +41,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { if (layout_rewrite_preproc_stmts_.size() > 0) { return std::make_tuple(create_layout_rewrite_preproc_func(), create_compute_func()); } else { - return std::make_tuple(NullOpt, func); + return std::make_tuple(std::nullopt, func); } } @@ -124,7 +124,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { /*block=*/ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"root", body, - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/alloc_buffers)); Map dict; @@ -340,7 +340,7 @@ Pass SplitLayoutRewritePreproc() { return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()}, "SplitLayoutRewritePreproc"); } -TVM_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc") +TVM_FFI_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc") .set_body_typed(SplitLayoutRewritePreproc); } // namespace transform } // namespace tvm diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index dcbd3077b1ea..0a51e9cd4acb 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -166,7 +166,7 @@ class TokenAllocator1D { * \brief Request a storage token from the available token pool for a * given prototype, or report no appropriate available token in the pool. * \param prototype The requesting prototype storage token. - * \return The request result token. Return NullOpt if there is no + * \return The request result token. Return std::nullopt if there is no * appropriate available token in the pool. */ Optional RequestReuse(StorageToken prototype) { @@ -175,7 +175,7 @@ class TokenAllocator1D { // If the prototype has no reference at all, feel free to allocate new storage. // The unused binding can be removed by cleaning passes. if (prototype->ref_counter == 0) { - return NullOpt; + return std::nullopt; } // Step 1. Get the available pool of the token dtype. @@ -197,7 +197,7 @@ class TokenAllocator1D { return available_token; } } - return NullOpt; + return std::nullopt; } // Step 2. Get the range of memory blocks in [size / match_range_, size * match_range_) auto begin = pool.lower_bound(size / match_range_); @@ -228,8 +228,9 @@ class TokenAllocator1D { pool.erase(mid); return available_token; } - // Return `NullOpt` indicating that no satisfiable storage token is found in the available pool. - return NullOpt; + // Return `std::nullopt` indicating that no satisfiable storage token is found in the available + // pool. + return std::nullopt; } /*! @@ -982,7 +983,8 @@ Pass StaticPlanBlockMemory() { return CreateModulePass(pass_func, /*opt_level=*/0, "StaticPlanBlockMemory", {}); } -TVM_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory").set_body_typed(StaticPlanBlockMemory); +TVM_FFI_REGISTER_GLOBAL("relax.transform.StaticPlanBlockMemory") + .set_body_typed(StaticPlanBlockMemory); } // namespace transform } // namespace relax diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index 9482f500d942..531ecefd5d66 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -38,8 +38,6 @@ namespace tvm { namespace relax { -using runtime::String; - int GetMixedPrecisionInfo(const CallNode* call_node) { const OpNode* op_node = call_node->op.as(); if (op_node == nullptr) { @@ -511,7 +509,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (opt_new_dtype) { auto new_dtype = opt_new_dtype.value(); new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype); - new_call.CopyOnWrite()->struct_info_ = NullOpt; + new_call.CopyOnWrite()->struct_info_ = std::nullopt; new_value = builder_->Normalize(Call(new_call)); @@ -535,7 +533,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } ObjectPtr new_tuple = make_object(*tuple_node); new_tuple->fields = std::move(RemapArgs(tuple_node->fields)); - new_tuple->struct_info_ = NullOpt; + new_tuple->struct_info_ = std::nullopt; Expr new_value = builder_->Normalize(Tuple(new_tuple)); if (!binding->var->IsInstance()) { // Global var: store the tensors to the original dtype @@ -555,7 +553,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { ObjectPtr new_tuple_get_item = make_object(*tuple_get_item_node); new_tuple_get_item->tuple = RemapArgs({tuple_get_item_node->tuple})[0]; - new_tuple_get_item->struct_info_ = NullOpt; + new_tuple_get_item->struct_info_ = std::nullopt; Expr new_value = TupleGetItem(new_tuple_get_item); if (!binding->var->IsInstance()) { // Global var: store the tensors to the original dtype @@ -620,7 +618,7 @@ Pass ToMixedPrecision(const DataType& out_dtype, Optional> fp16_in return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ToMixedPrecision").set_body_typed(ToMixedPrecision); } // namespace transform diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index b18ece65a6db..ef1616c83ed8 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -61,7 +61,7 @@ Pass ToNonDataflow() { return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {}); } -TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); +TVM_FFI_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); } // namespace transform diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index 24ed53948e71..1ba78cdc5e2c 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -98,7 +98,7 @@ class BindingOrderCollector : ExprVisitor { // If there is a variable without any inputs (e.g. `R.const(1)`) // or an unused variable, these must be handled somewhere, to // ensure they are visited corrected. It's easiest to perform the - // depth/breadth-first search if handled here, with `NullOpt` + // depth/breadth-first search if handled here, with `std::nullopt` // acting as a special value, so that the later traversal doesn't // need to check for this special case. std::vector zero_input_bindings; @@ -247,7 +247,7 @@ class TopologicalSorter : public ExprMutator { std::unordered_set visited; - // Given a variable that has just been defined (or NullOpt for the + // Given a variable that has just been defined (or std::nullopt for the // function's output), mark nodes as ready to visit. auto push_descendents_to_stack = [&](const DataflowNode& var) { auto it = forward_edge_lookup.find(var); @@ -342,7 +342,7 @@ Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) { return relax::transform::CreateFunctionPass(pass_func, 0, "TopologicalSort", {}); } -TVM_REGISTER_GLOBAL("relax.transform.TopologicalSort") +TVM_FFI_REGISTER_GLOBAL("relax.transform.TopologicalSort") .set_body_typed([](String order_str, String direction_str) -> Pass { TraversalOrder order = [&]() { if (order_str == "depth-first") { diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc index 87d9a76cfbee..fedc61019b06 100644 --- a/src/relax/transform/tuning_api/database.cc +++ b/src/relax/transform/tuning_api/database.cc @@ -49,7 +49,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj) { Trace trace{nullptr}; Optional> run_secs{nullptr}; try { - const ArrayObj* json_array = json_obj.as(); + const ffi::ArrayObj* json_array = json_obj.as(); CHECK(json_array && json_array->size() == 2); // Load json[0] => trace { @@ -256,7 +256,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { const ObjectRef& json_obj = json_objs[task_id].cast(); try { - const ArrayObj* arr = json_obj.as(); + const ffi::ArrayObj* arr = json_obj.as(); ICHECK_EQ(arr->size(), 3); workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); targets[task_id] = Target(Downcast>(arr->at(1))); @@ -288,7 +288,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { const ObjectRef& json_obj = json_objs[task_id].cast(); try { - const ArrayObj* arr = json_obj.as(); + const ffi::ArrayObj* arr = json_obj.as(); ICHECK_EQ(arr->size(), 3); workload_idxs[task_id] = Downcast(arr->at(0)).IntValue(); targets[task_id] = Target(Downcast>(arr->at(1))); @@ -311,32 +311,34 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(TuningRecordNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecord") .set_body_typed([](Trace trace, Optional> run_secs) { return TuningRecord(trace, run_secs); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecordAsJSON") .set_body_method(&TuningRecordNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TuningRecordFromJSON") + .set_body_typed(TuningRecord::FromJSON); TVM_REGISTER_OBJECT_TYPE(DatabaseNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasWorkload") .set_body_method(&DatabaseNode::HasWorkload); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasMeasurementRecord") .set_body_method(&DatabaseNode::HasMeasurementRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseHasTuningRecord") .set_body_method(&DatabaseNode::HasTuningRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitMeasurementRecord") .set_body_method(&DatabaseNode::CommitMeasurementRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitWorkload") .set_body_method(&DatabaseNode::CommitWorkload); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseCommitTuningRecord") .set_body_method(&DatabaseNode::CommitTuningRecord); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetTopK").set_body_method(&DatabaseNode::GetTopK); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseGetMeasurementRecord") .set_body_method(&DatabaseNode::GetMeasurementRecord); TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.DatabaseJSONDatabase") + .set_body_typed(Database::JSONDatabase); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc index cdad8e58c3c3..5f53b5166725 100644 --- a/src/relax/transform/tuning_api/primitives.cc +++ b/src/relax/transform/tuning_api/primitives.cc @@ -64,12 +64,12 @@ Choice Choice::FromJSON(const ObjectRef& json) { String transform_func_key, constr_func_key; Array transform_func_args, constr_func_args; try { - const ArrayObj* arr = json.as(); + const ffi::ArrayObj* arr = json.as(); ICHECK(arr && arr->size() == 4); const auto* arr0 = arr->at(0).as(); - const auto* arr1 = arr->at(1).as(); + const auto* arr1 = arr->at(1).as(); const auto* arr2 = arr->at(2).as(); - const auto* arr3 = arr->at(3).as(); + const auto* arr3 = arr->at(3).as(); ICHECK(arr0 && arr1 && arr2 && arr3); transform_func_key = GetRef(arr0); { @@ -123,10 +123,10 @@ Knob Knob::FromJSON(const ObjectRef& json) { String name; Map choices; try { - const ArrayObj* arr = json.as(); + const ffi::ArrayObj* arr = json.as(); ICHECK(arr && arr->size() == 2); const auto* arr0 = arr->at(0).as(); - const auto* arr1 = arr->at(1).as(); + const auto* arr1 = arr->at(1).as(); ICHECK(arr0 && arr1); name = GetRef(arr0); for (auto const& x : GetRef>(arr1)) { @@ -198,12 +198,12 @@ Trace Trace::FromJSON(const ObjectRef& json) { Array knobs; Array decisions; try { - const ArrayObj* arr = json.as(); + const ffi::ArrayObj* arr = json.as(); // A trace will have 2 or 3 entries depending on `include_irmod` parameter. ICHECK(arr && (arr->size() == 2 || arr->size() == 3)); - const auto* arr0 = arr->at(0).as(); - const auto* arr1 = arr->at(1).as(); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); ICHECK(arr0 && arr1); for (const Any& elem : *arr0) { @@ -231,41 +231,42 @@ Trace Trace::FromJSON(const ObjectRef& json) { /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ChoiceNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.Choice") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Choice") .set_body_typed([](String transform_func_key, Array transform_func_args, String constr_func_key, Array constr_func_args) { return Choice(transform_func_key, transform_func_args, constr_func_key, constr_func_args); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceAsJSON").set_body_method(&ChoiceNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceFromJSON").set_body_typed(Choice::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetTransformFunc") .set_body_method(&ChoiceNode::GetTransformFunc); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceGetConstrFunc") .set_body_method(&ChoiceNode::GetConstrFunc); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceApplyTransformFunc") .set_body_method(&ChoiceNode::ApplyTransformFunc); -TVM_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr").set_body_method(&ChoiceNode::CheckConstr); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.ChoiceCheckConstr") + .set_body_method(&ChoiceNode::CheckConstr); TVM_REGISTER_NODE_TYPE(KnobNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.Knob") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Knob") .set_body_typed([](String name, Map choices) { return Knob(name, choices); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobAsJSON").set_body_method(&KnobNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobFromJSON").set_body_typed(Knob::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobIsValidDecision") .set_body_method(&KnobNode::IsValidDecision); -TVM_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.KnobApply").set_body_method(&KnobNode::Apply); TVM_REGISTER_NODE_TYPE(TraceNode); -TVM_REGISTER_GLOBAL("relax.tuning_api.Trace") +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.Trace") .set_body_typed([](IRModule in_mod, Array knobs, Array decisions) { return Trace(in_mod, knobs, decisions); }); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod").set_body_method(&TraceNode::SetOutMod); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceVerify").set_body_method(&TraceNode::Verify); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceAdd").set_body_method(&TraceNode::Add); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceSetPerf").set_body_method(&TraceNode::SetPerf); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceSetOutMod").set_body_method(&TraceNode::SetOutMod); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); -TVM_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("relax.tuning_api.TraceFromJSON").set_body_typed(Trace::FromJSON); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index 062ac97a35f7..472f454bc11a 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -104,7 +104,8 @@ Pass UpdateParamStructInfo(ffi::TypedFunction(Var)> sinfo_f return tvm::transform::CreateModulePass(pass_func, 1, "UpdateParamStructInfo", {}); } -TVM_REGISTER_GLOBAL("relax.transform.UpdateParamStructInfo").set_body_typed(UpdateParamStructInfo); +TVM_FFI_REGISTER_GLOBAL("relax.transform.UpdateParamStructInfo") + .set_body_typed(UpdateParamStructInfo); } // namespace transform } // namespace relax diff --git a/src/relax/transform/update_vdevice.cc b/src/relax/transform/update_vdevice.cc index 5a8346578e7c..d2a1f85be853 100644 --- a/src/relax/transform/update_vdevice.cc +++ b/src/relax/transform/update_vdevice.cc @@ -106,7 +106,7 @@ Pass UpdateVDevice(VDevice new_vdevice, int64_t index) { /*pass_name=*/"UpdateVDevice", /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.UpdateVDevice").set_body_typed(UpdateVDevice); +TVM_FFI_REGISTER_GLOBAL("relax.transform.UpdateVDevice").set_body_typed(UpdateVDevice); } // namespace transform } // namespace relax diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 67d5fd4875e0..edd953e3126e 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -115,7 +115,7 @@ class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor entry_funcs); +TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, Array entry_funcs); /*! * \brief Get the external symbol of the Relax function name. @@ -434,7 +434,7 @@ Expr CanonicalizeBindings(Expr expr); * * \ret The updated function. */ -Function BundleModelParams(const Function& func, Optional param_tuple_name = NullOpt); +Function BundleModelParams(const Function& func, Optional param_tuple_name = std::nullopt); /*! \brief Compose two functions * diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 96fd5578e40a..ab270c08a65d 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -245,7 +245,7 @@ Expr GetBoundValue(const Binding& b) { */ Function CopyWithNewVars(Function func) { return FunctionCopier().Copy(func); } -TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars); +TVM_FFI_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars); } // namespace relax } // namespace tvm diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index ba3dddb7ae49..7f7d416f88d9 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -22,22 +22,22 @@ * \brief Functions for conversion between fp32 and fp16 */ #include -#include +#include extern "C" { // disable under msvc #ifndef _MSC_VER -TVM_DLL TVM_WEAK uint16_t __gnu_f2h_ieee(float a) { +TVM_DLL TVM_FFI_WEAK uint16_t __gnu_f2h_ieee(float a) { return __truncXfYf2__(a); } -TVM_DLL TVM_WEAK float __gnu_h2f_ieee(uint16_t a) { +TVM_DLL TVM_FFI_WEAK float __gnu_h2f_ieee(uint16_t a) { return __extendXfYf2__(a); } -TVM_DLL TVM_WEAK uint16_t __truncdfhf2(double a) { +TVM_DLL TVM_FFI_WEAK uint16_t __truncdfhf2(double a) { return __truncXfYf2__(a); } diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc deleted file mode 100644 index b76687522d69..000000000000 --- a/src/runtime/c_runtime_api.cc +++ /dev/null @@ -1,805 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file c_runtime_api.cc - * \brief Device specific implementations - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "object_internal.h" -#include "runtime_base.h" - -namespace tvm { -namespace runtime { - -std::string GetCustomTypeName(uint8_t type_code) { - const auto f = tvm::ffi::Function::GetGlobalRequired("runtime._datatype_get_type_name"); - return f(type_code).cast(); -} - -uint8_t GetCustomTypeCode(const std::string& type_name) { - const auto f = tvm::ffi::Function::GetGlobalRequired("runtime._datatype_get_type_code"); - return f(type_name).cast(); -} - -bool GetCustomTypeRegistered(uint8_t type_code) { - const auto f = tvm::ffi::Function::GetGlobalRequired("runtime._datatype_get_type_registered"); - return f(type_code).cast(); -} - -uint8_t ParseCustomDatatype(const std::string& s, const char** scan) { - ICHECK(s.substr(0, 6) == "custom") << "Not a valid custom datatype string"; - - auto tmp = s.c_str(); - - ICHECK(s.c_str() == tmp); - *scan = s.c_str() + 6; - ICHECK(s.c_str() == tmp); - if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s; - ICHECK(s.c_str() == tmp); - *scan += 1; - ICHECK(s.c_str() == tmp); - size_t custom_name_len = 0; - ICHECK(s.c_str() == tmp); - while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']') - ++custom_name_len; - ICHECK(s.c_str() == tmp); - if (*(*scan + custom_name_len) != ']') - LOG(FATAL) << "expected closing brace after 'custom' type in" << s; - ICHECK(s.c_str() == tmp); - *scan += custom_name_len + 1; - ICHECK(s.c_str() == tmp); - - auto type_name = s.substr(7, custom_name_len); - ICHECK(s.c_str() == tmp); - return GetCustomTypeCode(type_name); -} - -class DeviceAPIManager { - public: - static const int kMaxDeviceAPI = TVMDeviceExtType_End; - // Get API - static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); } - static DeviceAPI* Get(int dev_type, bool allow_missing = false) { - return Global()->GetAPI(dev_type, allow_missing); - } - - private: - std::array api_; - DeviceAPI* rpc_api_{nullptr}; - std::mutex mutex_; - // constructor - DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } - // Global static variable. - static DeviceAPIManager* Global() { - static DeviceAPIManager* inst = new DeviceAPIManager(); - return inst; - } - // Get or initialize API. - DeviceAPI* GetAPI(int type, bool allow_missing) { - if (type < kRPCSessMask) { - if (api_[type] != nullptr) return api_[type]; - std::lock_guard lock(mutex_); - if (api_[type] != nullptr) return api_[type]; - api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing); - return api_[type]; - } else { - if (rpc_api_ != nullptr) return rpc_api_; - std::lock_guard lock(mutex_); - if (rpc_api_ != nullptr) return rpc_api_; - rpc_api_ = GetAPI("rpc", allow_missing); - return rpc_api_; - } - } - DeviceAPI* GetAPI(const std::string name, bool allow_missing) { - std::string factory = "device_api." + name; - const auto f = tvm::ffi::Function::GetGlobal(factory); - if (!f.has_value()) { - ICHECK(allow_missing) << "Device API " << name << " is not enabled."; - return nullptr; - } - void* ptr = (*f)().cast(); - return static_cast(ptr); - } -}; - -DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) { - return DeviceAPIManager::Get(static_cast(dev.device_type), allow_missing); -} - -void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { - return AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint); -} - -static size_t GetDataAlignment(const DLDataType dtype) { - size_t align = (dtype.bits / 8) * dtype.lanes; - if (align < kAllocAlignment) return kAllocAlignment; - return align; -} - -size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { - if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { - size_t size = 1; - for (tvm_index_t i = 0; i < arr.ndim; ++i) { - size *= static_cast(arr.shape[i]); - } - size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; - return size; - } - LOG(FATAL) << "Device does not support physical mem computation with " - << "specified memory scope: " << mem_scope.value(); - return 0; -} - -void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope) { - if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == "global") { - // by default, we can always redirect to the flat memory allocations - DLTensor temp; - temp.data = nullptr; - temp.device = dev; - temp.ndim = ndim; - temp.dtype = dtype; - temp.shape = const_cast(shape); - temp.strides = nullptr; - temp.byte_offset = 0; - size_t size = GetDataSize(temp); - size_t alignment = GetDataAlignment(temp.dtype); - return AllocDataSpace(dev, size, alignment, dtype); - } - LOG(FATAL) << "Device does not support allocate data space with " - << "specified memory scope: " << mem_scope.value(); - return nullptr; -} - -void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { - // by default, we can always redirect to the flat memory copy operation. - size_t nbytes = GetDataSize(*from); - ICHECK_EQ(nbytes, GetDataSize(*to)); - - ICHECK(IsContiguous(*from) && IsContiguous(*to)) - << "CopyDataFromTo only support contiguous array for now"; - CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device, - to->device, from->dtype, stream); -} - -void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, - size_t num_bytes, Device dev_from, Device dev_to, - DLDataType type_hint, TVMStreamHandle stream) { - LOG(FATAL) << "Device does not support CopyDataFromTo."; -} - -void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } - -TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } - -void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} - -TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; } - -void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { -} - -//-------------------------------------------------------- -// Error handling mechanism -// ------------------------------------------------------- -// Standard error message format, {} means optional -//-------------------------------------------------------- -// {error_type:} {message0} -// {message1} -// {message2} -// {Stack trace:} // stack traces follow by this line -// {trace 0} // two spaces in the beginning. -// {trace 1} -// {trace 2} -//-------------------------------------------------------- -/*! - * \brief Normalize error message - * - * Parse them header generated by LOG(FATAL) and ICHECK - * and reformat the message into the standard format. - * - * This function will also merge all the stack traces into - * one trace and trim them. - * - * \param err_msg The error message. - * \return normalized message. - */ -std::string NormalizeError(std::string err_msg) { - // ------------------------------------------------------------------------ - // log with header, {} indicates optional - //------------------------------------------------------------------------- - // [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0} - // {message1} - // Stack trace: - // {stack trace 0} - // {stack trace 1} - //------------------------------------------------------------------------- - // Normalzied version - //------------------------------------------------------------------------- - // error_type: check_msg message0 - // {message1} - // Stack trace: - // File file_name, line lineno - // {stack trace 0} - // {stack trace 1} - //------------------------------------------------------------------------- - // LEGACY-COMPACT: - // skip python-style error style - // TODO(tqchen) move to new FFI handling - if (err_msg.find("Traceback (most recent call last)") != std::string::npos) { - return err_msg; - } - int line_number = 0; - std::istringstream is(err_msg); - std::string line, file_name, error_type, check_msg; - - // Parse log header and set the fields, - // Return true if it the log is in correct format, - // return false if something is wrong. - auto parse_log_header = [&]() { - // skip timestamp - if (is.peek() != '[') { - getline(is, line); - return true; - } - if (!(is >> line)) return false; - // get filename - while (is.peek() == ' ') is.get(); -#ifdef _MSC_VER // handle volume separator ":" in Windows path - std::string drive; - if (!getline(is, drive, ':')) return false; - if (!getline(is, file_name, ':')) return false; - file_name = drive + ":" + file_name; -#else - if (!getline(is, file_name, ':')) return false; -#endif - // get line number - if (!(is >> line_number)) return false; - // get rest of the message. - while (is.peek() == ' ' || is.peek() == ':') is.get(); - if (!getline(is, line)) return false; - // detect check message, rewrite to remote extra : - if (line.compare(0, 13, "Check failed:") == 0) { - std::string ending = ": "; - size_t end_pos = line.find(ending, 13); - if (end_pos == std::string::npos) return false; - check_msg = line.substr(0, end_pos + ending.size()); - line = line.substr(end_pos + ending.size()); - } - return true; - }; - // if not in correct format, do not do any rewrite. - if (!parse_log_header()) return err_msg; - // Parse error type. - { - size_t start_pos = 0, end_pos; - for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) { - } - for (end_pos = start_pos; end_pos < line.length(); ++end_pos) { - char ch = line[end_pos]; - if (ch == ':') { - error_type = line.substr(start_pos, end_pos - start_pos); - break; - } - // [A-Z0-9a-z_.] - if (!std::isalpha(ch) && !std::isdigit(ch) && ch != '_' && ch != '.') break; - } - if (error_type.length() != 0) { - // if we successfully detected error_type: trim the following space. - for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' '; - ++start_pos) { - } - line = line.substr(start_pos); - } else { - // did not detect error_type, use default value. - line = line.substr(start_pos); - error_type = "TVMError"; - } - } - // Separate out stack trace. - std::ostringstream os; - os << error_type << ": " << check_msg << line << '\n'; - - bool trace_mode = true; - std::vector stack_trace; - while (getline(is, line)) { - if (trace_mode) { - if (line.compare(0, 2, " ") == 0) { - stack_trace.push_back(line); - } else { - trace_mode = false; - // remove EOL trailing stacktrace. - if (line.length() == 0) continue; - } - } - if (!trace_mode) { - if (line.compare(0, 11, "Stack trace") == 0) { - trace_mode = true; - } else { - os << line << '\n'; - } - } - } - if (stack_trace.size() != 0 || file_name.length() != 0) { - os << "Stack trace:\n"; - if (file_name.length() != 0) { - os << " File \"" << file_name << "\", line " << line_number << "\n"; - } - // Print out stack traces, optionally trim the c++ traces - // about the frontends (as they will be provided by the frontends). - bool ffi_boundary = false; - for (const auto& line : stack_trace) { - // Heuristic to detect python ffi. - if (line.find("libffi.so") != std::string::npos || - line.find("core.cpython") != std::string::npos) { - ffi_boundary = true; - } - // If the backtrace is not c++ backtrace with the prefix " [bt]", - // then we can stop trimming. - if (ffi_boundary && line.compare(0, 6, " [bt]") != 0) { - ffi_boundary = false; - } - if (!ffi_boundary) { - os << line << '\n'; - } - // The line after TVMFuncCall cound be in FFI. - if (line.find("(TVMFuncCall") != std::string::npos) { - ffi_boundary = true; - } - } - } - return os.str(); -} - -} // namespace runtime -} // namespace tvm - -using namespace tvm::runtime; - -struct WrappedPythonError : Error { - WrappedPythonError() : Error("WrappedPythonError", "", TVM_FFI_TRACEBACK_HERE) {} - explicit WrappedPythonError(WrappedPythonObject obj) - : Error("WrappedPythonError", "", TVM_FFI_TRACEBACK_HERE), obj(std::move(obj)) {} - - WrappedPythonObject obj; -}; - -struct TVMRuntimeEntry { - std::string ret_str; - TVMByteArray ret_bytes; - - std::variant last_error; - std::string last_error_formatted; -}; - -typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; - -const char* TVMGetLastError() { - auto* store = TVMAPIRuntimeStore::Get(); - const auto& last_error = store->last_error; - if (const auto* message = std::get_if(&last_error)) { - return message->c_str(); - } else if (const auto* internal = std::get_if(&last_error)) { - // Use last_error_formatted to store the formatted error message, to avoid - // dangling pointer. - store->last_error_formatted = internal->what(); - return store->last_error_formatted.c_str(); - } else { - return nullptr; - } -} - -void* TVMGetLastPythonError() { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - if (auto* wrapped = std::get_if(&last_error)) { - return wrapped->obj.raw_pointer(); - } else { - return nullptr; - } -} - -const char* TVMGetLastBacktrace() { - const auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - static thread_local std::string traceback; - if (const auto* wrapped = std::get_if(&last_error)) { - traceback = wrapped->traceback(); - return traceback.c_str(); - } else if (const auto* wrapped = std::get_if(&last_error)) { - traceback = wrapped->traceback(); - return traceback.c_str(); - } else { - return nullptr; - } -} - -void TVMDropLastPythonError() { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - if (std::get_if(&last_error)) { - last_error = ""; - } -} - -int TVMAPIHandleException(const std::exception& e) { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - - if (const auto* wrapped = dynamic_cast(&e)) { - last_error = *wrapped; - } else if (const auto* internal = dynamic_cast(&e)) { - last_error = *internal; - } else { - last_error = NormalizeError(e.what()); - } - return -1; -} - -void TVMAPISetLastPythonError(void* obj) { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - last_error = WrappedPythonError(WrappedPythonObject(obj)); -} - -void TVMThrowLastError() { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - if (auto* wrapped = std::get_if(&last_error)) { - WrappedPythonError wrapped_err; - std::swap(wrapped_err, *wrapped); - throw wrapped_err; - } else if (auto* internal = std::get_if(&last_error)) { - throw *internal; - } else { - // redirect to tvm-ffi error handling. - throw ::tvm::ffi::details::MoveFromSafeCallRaised(); - } -} - -void TVMAPISetLastError(const char* msg) { - auto& last_error = TVMAPIRuntimeStore::Get()->last_error; - last_error = msg; -} - -int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { - API_BEGIN(); - tvm::ffi::Any ret; - ret = Module::LoadFromFile(file_name, format); - TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(ret)); - *out = val.v_obj; - API_END(); -} - -int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { - API_BEGIN(); - ObjectInternal::GetModuleNode(mod)->Import(GetRef(ObjectInternal::GetModuleNode(dep))); - API_END(); -} - -int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, - TVMFunctionHandle* func) { - API_BEGIN(); - tvm::ffi::Function pf = - ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); - if (pf != nullptr) { - tvm::ffi::Any ret = pf; - TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(ret)); - *func = val.v_obj; - } else { - *func = nullptr; - } - API_END(); -} - -int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } - -int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { - API_BEGIN(); - *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name))->get(); - API_END(); -} - -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, - int dtype_bits_hint) { - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - - DLDataType type_hint; - type_hint.code = static_cast(dtype_code_hint); - type_hint.bits = static_cast(dtype_bits_hint); - type_hint.lanes = 1; - - return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, static_cast(size), type_hint); -} - -int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr); - return 0; -} - -int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { - if (*handle == nullptr) { - *handle = reinterpret_cast(1); - return (*f)(cdata); - } - return 0; -} - -int TVMFuncFree(TVMFunctionHandle func) { return TVMObjectFree(func); } - -int TVMByteArrayFree(TVMByteArray* arr) { - if (arr == &TVMAPIRuntimeStore::Get()->ret_bytes) { - return 0; // Thread-local storage does not need explicit deleting. - } - - delete arr; - return 0; -} - -int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, - TVMValue* ret_val, int* ret_type_code) { - API_BEGIN(); - tvm::ffi::Any rv; - tvm::ffi::FunctionObj* ffi_func = static_cast(func); - std::vector args_vec(num_args); - tvm::runtime::LegacyTVMArgsToPackedArgs(args, arg_type_codes, num_args, args_vec.data()); - ffi_func->CallPacked(args_vec.data(), args_vec.size(), &rv); - // special handle of certain return types. - if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIDataType || - rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIBytes || - rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIStr) { - // TODO(tvm-team): handle bytes return type here - TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); - if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIDataType) { - e->ret_str = DLDataTypeToString(rv.cast()); - *ret_type_code = kTVMStr; - ret_val->v_str = e->ret_str.c_str(); - } else if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIBytes) { - e->ret_str = rv.cast(); - e->ret_bytes.data = e->ret_str.c_str(); - e->ret_bytes.size = e->ret_str.length(); - *ret_type_code = kTVMBytes; - ret_val->v_handle = &(e->ret_bytes); - } else if (rv.type_index() == tvm::ffi::TypeIndex::kTVMFFIStr) { - e->ret_str = rv.cast(); - *ret_type_code = kTVMStr; - ret_val->v_str = e->ret_str.c_str(); - } - } else { - MoveAnyToLegacyTVMValue(std::move(rv), ret_val, ret_type_code); - } - API_END(); -} - -int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) { - API_BEGIN(); - ICHECK_EQ(num_ret, 1); - tvm::ffi::Any* rv = static_cast(ret); - *rv = LegacyTVMArgValueToAnyView(value[0], type_code[0]); - API_END(); -} - -int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin, - TVMFunctionHandle* out) { - API_BEGIN(); - if (fin == nullptr) { - tvm::ffi::Any ret; - ret = tvm::ffi::Function::FromPacked( - [func, resource_handle](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - // run ABI translation - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); - int ret = func(values.data(), type_codes.data(), args.size(), rv, resource_handle); - if (ret != 0) { - TVMThrowLastError(); - } - }); - TVMValue val; - int type_code; - MoveAnyToLegacyTVMValue(std::move(ret), &val, &type_code); - *out = val.v_handle; - } else { - // wrap it in a shared_ptr, with fin as deleter. - // so fin will be called when the lambda went out of scope. - std::shared_ptr rpack(resource_handle, fin); - tvm::ffi::Any ret; - ret = - tvm::ffi::Function::FromPacked([func, rpack](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { - // run ABI translation - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); - int ret = func(values.data(), type_codes.data(), args.size(), rv, rpack.get()); - - if (ret != 0) { - TVMThrowLastError(); - } - }); - TVMValue val; - val.v_handle = nullptr; - int type_code; - MoveAnyToLegacyTVMValue(std::move(ret), &val, &type_code); - *out = val.v_handle; - } - API_END(); -} - -int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - *out = DeviceAPIManager::Get(dev)->CreateStream(dev); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamCreate").set_body_typed([](DLDevice dev) { - return reinterpret_cast(DeviceAPIManager::Get(dev)->CreateStream(dev)); -}); - -int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->FreeStream(dev, stream); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamFree").set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast(stream)); -}); - -int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->SetStream(dev, stream); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_SetStream").set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast(stream)); -}); - -int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->StreamSync(dev, stream); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamSync").set_body_typed([](DLDevice dev, int64_t stream) { - DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast(stream)); -}); - -int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, - TVMStreamHandle dst) { - API_BEGIN(); - DLDevice dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, src, dst); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.Device_StreamSyncFromTo") - .set_body_typed([](DLDevice dev, int64_t src, int64_t dst) { - DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast(src), - reinterpret_cast(dst)); - }); - -int TVMCbArgToReturn(TVMValue* value, int* code) { - API_BEGIN(); - AnyView arg = LegacyTVMArgValueToAnyView(*value, *code); - Any rv; - if (auto opt_rv = arg.as>()) { - rv = *std::move(*std::move(opt_rv)); - } else { - rv = arg; - } - MoveAnyToLegacyTVMValue(std::move(rv), value, code); - API_END(); -} - -int TVMDeviceAllocDataSpace(DLDevice dev, size_t nbytes, size_t alignment, DLDataType type_hint, - void** out_data) { - API_BEGIN(); - out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, nbytes, alignment, type_hint); - API_END(); -} - -int TVMDeviceAllocDataSpaceWithScope(DLDevice dev, int ndim, const int64_t* shape, DLDataType dtype, - const char* mem_scope, void** out_data) { - API_BEGIN(); - Optional scope; - if (mem_scope != nullptr) { - scope = String(std::string(mem_scope)); - } - out_data[0] = DeviceAPIManager::Get(dev)->AllocDataSpace(dev, ndim, shape, dtype, scope); - API_END(); -} - -int TVMDeviceFreeDataSpace(DLDevice dev, void* ptr) { - API_BEGIN(); - DeviceAPIManager::Get(dev)->FreeDataSpace(dev, ptr); - API_END(); -} - -int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { - API_BEGIN(); - DLDevice dev_from = from->device; - DLDevice dev_to = to->device; - DLDevice dev = dev_from.device_type != kDLCPU ? dev_from : dev_to; - DeviceAPIManager::Get(dev)->CopyDataFromTo(from, to, stream); - API_END(); -} - -// set device api -TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { - DLDevice dev; - dev.device_type = static_cast(args[0].cast()); - dev.device_id = args[1].cast(); - DeviceAPIManager::Get(dev)->SetDevice(dev); - }); - -// set device api -TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr") - .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { - DLDevice dev; - dev.device_type = static_cast(args[0].cast()); - dev.device_id = args[1].cast(); - - DeviceAttrKind kind = static_cast(args[2].cast()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); - if (api != nullptr) { - api->GetAttr(dev, kind, ret); - } else { - *ret = 0; - } - } else { - DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); - } - }); - -TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream); diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index 2536847726c8..84cd4943c552 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -27,12 +27,12 @@ * code and constants significantly reduces the efforts for handling external * codegen and runtimes. */ -#include -#include -#include +#include +#include +#include +#include +#include #include -#include -#include #include @@ -143,7 +143,7 @@ class ConstLoaderModuleNode : public ModuleNode { // Initialize the module with constants. int ret = init(md).cast(); // Report the error if initialization is failed. - ICHECK_EQ(ret, 0) << TVMGetLastError(); + ICHECK_EQ(ret, 0); break; } } @@ -247,7 +247,7 @@ Module ConstLoaderModuleCreate( return Module(n); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_const_loader") .set_body_typed(ConstLoaderModuleNode::LoadFromBinary); } // namespace runtime diff --git a/src/runtime/container.cc b/src/runtime/container.cc deleted file mode 100644 index 004feb3afba8..000000000000 --- a/src/runtime/container.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/runtime/container.cc - * \brief Implementations of common containers. - */ -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace runtime { - -// Array -TVM_REGISTER_OBJECT_TYPE(ArrayObj); - -TVM_REGISTER_GLOBAL("runtime.Array").set_body_packed([](ffi::PackedArgs args, Any* ret) { - Array result; - for (int i = 0; i < args.size(); ++i) { - result.push_back(args[i]); - } - *ret = result; -}); - -TVM_REGISTER_GLOBAL("runtime.ArrayGetItem") - .set_body_typed([](const ffi::ArrayObj* n, int64_t i) -> Any { return n->at(i); }); - -TVM_REGISTER_GLOBAL("runtime.ArraySize").set_body_typed([](const ffi::ArrayObj* n) -> int64_t { - return static_cast(n->size()); -}); - -// String -TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { - return String(std::move(str)); -}); - -TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { - return std::string(str); -}); - -// Map -TVM_REGISTER_GLOBAL("runtime.Map").set_body_packed([](ffi::PackedArgs args, Any* ret) { - ICHECK_EQ(args.size() % 2, 0); - Map data; - for (int i = 0; i < args.size(); i += 2) { - data.Set(args[i], args[i + 1]); - } - *ret = data; -}); - -TVM_REGISTER_GLOBAL("runtime.MapSize").set_body_typed([](const ffi::MapObj* n) -> int64_t { - return static_cast(n->size()); -}); - -TVM_REGISTER_GLOBAL("runtime.MapGetItem") - .set_body_typed([](const ffi::MapObj* n, const Any& k) -> Any { return n->at(k); }); - -TVM_REGISTER_GLOBAL("runtime.MapCount") - .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }); - -TVM_REGISTER_GLOBAL("runtime.MapItems").set_body_typed([](const ffi::MapObj* n) -> Array { - Array rkvs; - for (const auto& kv : *n) { - rkvs.push_back(kv.first); - rkvs.push_back(kv.second); - } - return rkvs; -}); - -// ShapeTuple -TVM_REGISTER_OBJECT_TYPE(ShapeTupleObj); - -TVM_REGISTER_GLOBAL("runtime.ShapeTuple").set_body_packed([](ffi::PackedArgs args, Any* ret) { - std::vector shape; - for (int i = 0; i < args.size(); ++i) { - shape.push_back(args[i].cast()); - } - *ret = ShapeTuple(shape); -}); - -TVM_REGISTER_GLOBAL("runtime.GetShapeTupleSize").set_body_typed([](ShapeTuple shape) { - return static_cast(shape.size()); -}); - -TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple shape, int idx) { - ICHECK_LT(idx, shape.size()); - return shape[idx]; -}); - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc index 72225f39954f..1eb63a10fa4c 100644 --- a/src/runtime/contrib/amx/amx_config.cc +++ b/src/runtime/contrib/amx/amx_config.cc @@ -21,8 +21,7 @@ * \file src/runtime/contrib/amx/amx_config.cc * \brief extraction of AMX configuration on x86 platforms */ -#include -#include +#include namespace tvm { namespace runtime { @@ -76,7 +75,7 @@ void init_tile_config(__tilecfg_u* dst, uint16_t cols, uint8_t rows) { _tile_loadconfig(dst->a); } -TVM_REGISTER_GLOBAL("runtime.amx_tileconfig") +TVM_FFI_REGISTER_GLOBAL("runtime.amx_tileconfig") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int rows = args[0].cast(); int cols = args[1].cast(); @@ -90,7 +89,7 @@ TVM_REGISTER_GLOBAL("runtime.amx_tileconfig") }); // register a global packed function in c++,to init the system for AMX config -TVM_REGISTER_GLOBAL("runtime.amx_init").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("runtime.amx_init").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // -----------Detect and request for AMX control---------------------- uint64_t bitmask = 0; int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); diff --git a/src/runtime/contrib/arm_compute_lib/acl_allocator.h b/src/runtime/contrib/arm_compute_lib/acl_allocator.h index d4e72a73314f..a755393209ec 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_allocator.h +++ b/src/runtime/contrib/arm_compute_lib/acl_allocator.h @@ -28,9 +28,9 @@ #include #include #include +#include #include #include -#include #include diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 5687e687cfb6..eeca2fcdf347 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for Arm Compute Library. */ +#include #include -#include #include "../json/json_node.h" #include "../json/json_runtime.h" @@ -593,8 +593,8 @@ runtime::Module ACLRuntimeCreate(const String& symbol_name, const String& graph_ return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.arm_compute_lib_runtime_create").set_body_typed(ACLRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_arm_compute_lib") +TVM_FFI_REGISTER_GLOBAL("runtime.arm_compute_lib_runtime_create").set_body_typed(ACLRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_arm_compute_lib") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index cb921aa729a1..aed0080589e0 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -22,9 +22,9 @@ * \brief Simple JSON runtime for Apple BNNS primitives */ +#include #include #include -#include #include #include @@ -562,9 +562,9 @@ runtime::Module BNNSJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate").set_body_typed(BNNSJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.BNNSJSONRuntimeCreate").set_body_typed(BNNSJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_bnns_json") .set_body_typed(BNNSJSONRuntime::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 155e1f05f197..4d04d8263447 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -20,9 +20,9 @@ /*! * \file Use external cblas library call. */ +#include #include #include -#include extern "C" { #include @@ -123,7 +123,7 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -134,7 +134,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") CallGemm(args, ret, CblasDgemmOp()); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") } }); -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); diff --git a/src/runtime/contrib/cblas/dnnl_blas.cc b/src/runtime/contrib/cblas/dnnl_blas.cc index 18840eb55db1..68819d015326 100644 --- a/src/runtime/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/contrib/cblas/dnnl_blas.cc @@ -20,9 +20,9 @@ /*! * \file Use external cblas library call. */ +#include #include #include -#include extern "C" { #include @@ -46,7 +46,7 @@ struct DNNLSgemmOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.dnnl.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 14b74d4736fc..a44cf1b365ec 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -25,8 +25,8 @@ #ifndef TVM_RUNTIME_CONTRIB_CBLAS_GEMM_COMMON_H_ #define TVM_RUNTIME_CONTRIB_CBLAS_GEMM_COMMON_H_ +#include #include -#include #include #include diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index f98df0c6d624..33b52e5e375d 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -20,9 +20,9 @@ /*! * \file Use external mkl library call. */ +#include #include #include -#include extern "C" { #include @@ -154,7 +154,7 @@ struct MKLDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -166,7 +166,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul") }); // integer matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); @@ -177,7 +177,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.matmul_u8s8s32") CallU8S8S32Gemm(args, ret, MKLGemmU8S8S32Op()); }); -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -188,7 +188,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul") } }); -TVM_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul_iterative") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mkl.batch_matmul_iterative") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 21584997a76f..5ee90e29b009 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -291,7 +291,7 @@ class CLMLRuntime : public JSONRuntimeBase { // Dump tensor to CPU std::vector shape = node.GetOpShape()[0]; DLDataType tvm_dtype = node.GetOpDataType()[0]; - NDArray narr = NDArray::Empty(ShapeTuple(shape), tvm_dtype, {kDLCPU, 0}); + NDArray narr = NDArray::Empty(ffi::Shape(shape), tvm_dtype, {kDLCPU, 0}); CopyDataFromCLMLTensor(clml_desc, narr.operator->()->data); // Naming convention @@ -1830,8 +1830,8 @@ runtime::Module CLMLRuntimeCreate(const String& symbol_name, const String& graph return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.clml_runtime_create").set_body_typed(CLMLRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_clml") +TVM_FFI_REGISTER_GLOBAL("runtime.clml_runtime_create").set_body_typed(CLMLRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_clml") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/clml/clml_runtime.h b/src/runtime/contrib/clml/clml_runtime.h index faada2ddeeb5..4431b63cafcc 100644 --- a/src/runtime/contrib/clml/clml_runtime.h +++ b/src/runtime/contrib/clml/clml_runtime.h @@ -32,9 +32,9 @@ #include #include #include +#include #include #include -#include #include #include diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h index b3f7e846e0ec..5f5eec1d03ca 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.h +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -29,8 +29,8 @@ #import #include +#include #include -#include #include #include diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 7b2733c4312e..f98c97f68b12 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -20,7 +20,7 @@ /*! * \file coreml_runtime.cc */ -#include +#include #include "coreml_runtime.h" @@ -192,7 +192,7 @@ Module CoreMLRuntimeCreate(const std::string& symbol, const std::string& model_p return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create") +TVM_FFI_REGISTER_GLOBAL("tvm.coreml_runtime.create") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = CoreMLRuntimeCreate(args[0], args[1]); }); @@ -249,7 +249,8 @@ Module CoreMLRuntimeLoadFromBinary(void* strm) { return Module(exec); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_coreml").set_body_typed(CoreMLRuntimeLoadFromBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_coreml") + .set_body_typed(CoreMLRuntimeLoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index e3222e3adc40..19d83e624d91 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -20,9 +20,9 @@ /*! * \file Use external cblas library call. */ +#include #include #include -#include #include "../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" @@ -514,7 +514,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t } // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); @@ -539,7 +539,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") }); #if CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); @@ -557,7 +557,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") }); #endif // CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index c9c6cf85c6ba..8f7b6ac1f188 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for CUBLAS. */ +#include #include -#include #include #include @@ -153,9 +153,9 @@ runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 5844f802fd84..53e00fe14199 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -23,7 +23,7 @@ #include "cublas_utils.h" #include -#include +#include #include "../../cuda/cuda_common.h" diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc index 52c69c81cf08..a19fc192efd1 100644 --- a/src/runtime/contrib/cudnn/conv_backward.cc +++ b/src/runtime/contrib/cudnn/conv_backward.cc @@ -20,9 +20,9 @@ /*! * \file cuDNN kernel calls for backward algorithms. */ +#include #include #include -#include #include "cudnn_utils.h" @@ -185,7 +185,7 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c ret[0] = static_cast(best_algo); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -206,7 +206,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); @@ -225,7 +225,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo") data_dtype, conv_dtype, verbose, ret); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -246,7 +246,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter") dw, conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 87e6121e74c7..856d796e9038 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -20,9 +20,9 @@ /*! * \file cuDNN kernel calls for the forward algorithm. */ +#include #include #include -#include #include "cudnn_utils.h" @@ -153,7 +153,7 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid ret[0] = static_cast(best_algo); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -174,7 +174,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -198,7 +198,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward") dilation_v, x, w, y, bias, conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int mode = args[0].cast(); int format = args[1].cast(); @@ -219,7 +219,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { int format = args[0].cast(); int dims = args[1].cast(); diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc index f8b170fe2052..dffce6738907 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc @@ -24,8 +24,8 @@ #include "./attention.h" +#include #include -#include #include "../../../cuda/cuda_common.h" #include "../cudnn_utils.h" diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h index 4d0309fb3ba6..ae11764ce02c 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h @@ -26,7 +26,7 @@ #define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_ #include -#include +#include #include #include diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 08909a3150c2..eda3b694d7f0 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for CUDNN. */ +#include #include -#include #include #include @@ -237,9 +237,9 @@ runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.cuDNNJSONRuntimeCreate").set_body_typed(cuDNNJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.cuDNNJSONRuntimeCreate").set_body_typed(cuDNNJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cudnn_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cudnn_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 191fb9af325e..8e2e85c67524 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -24,8 +24,8 @@ #include "cudnn_utils.h" #include +#include #include -#include #include #include @@ -177,7 +177,7 @@ void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int g entry_ptr->conv_entry.tensor_format = static_cast(format); // Set Data Type entry_ptr->conv_entry.data_type = - CuDNNDataType::DLTypeToCuDNNType(runtime::StringToDLDataType(conv_dtype)); + CuDNNDataType::DLTypeToCuDNNType(ffi::StringToDLDataType(conv_dtype)); cudnnDataType_t cudnn_data_type = CuDNNDataType::DLTypeToCuDNNType(data_dtype); @@ -265,7 +265,7 @@ SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_des SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool { +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.exists").set_body_typed([]() -> bool { return CuDNNThreadEntry::ThreadLocal(false)->exists(); }); diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index c2b3ac3db84c..aa37acd2c3a9 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -21,8 +21,8 @@ * \file src/runtime/contrib/cudnn/softmax.cc * \brief Use external cudnn softmax function */ +#include #include -#include #include "cudnn_utils.h" @@ -77,12 +77,12 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* r entry_ptr->softmax_entry.shape_desc, y->data)); } -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); }); diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc index e8aaf31fc5f3..e31c5fdfebf8 100644 --- a/src/runtime/contrib/curand/curand.cc +++ b/src/runtime/contrib/curand/curand.cc @@ -17,8 +17,8 @@ * under the License. */ #include -#include -#include +#include +#include #include "../../cuda/cuda_common.h" #include "./helper_cuda_kernels.h" @@ -112,7 +112,7 @@ void RandomFill(DLTensor* tensor) { TVMSynchronize(tensor->device.device_type, tensor->device.device_type, nullptr); } -TVM_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill); +TVM_FFI_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill); } // namespace curand } // namespace runtime diff --git a/src/runtime/contrib/curand/helper_cuda_kernels.h b/src/runtime/contrib/curand/helper_cuda_kernels.h index 582162579a3a..6df29ee69056 100644 --- a/src/runtime/contrib/curand/helper_cuda_kernels.h +++ b/src/runtime/contrib/curand/helper_cuda_kernels.h @@ -20,7 +20,7 @@ #define TVM_RUNTIME_CONTRIB_CURAND_HELPER_CUDA_KERNELS_H_ #include -#include +#include namespace tvm { namespace runtime { diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu b/src/runtime/contrib/cutlass/fp16_group_gemm.cu index f09925ceecd6..dffe7dc4ffed 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cu @@ -20,8 +20,8 @@ #include #include #include -#include -#include +#include +#include #include "group_gemm_runner.cuh" @@ -60,7 +60,7 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDAr static_cast(out->data), stream); } -TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90") .set_body_typed(tvm_cutlass_group_gemm_sm90); } // namespace runtime diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu index bf257622f1f9..b8732357c7bd 100644 --- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu @@ -20,8 +20,8 @@ #include #include #include -#include -#include +#include +#include #include "../cublas/cublas_utils.h" #include "blockwise_scaled_gemm_runner.cuh" @@ -153,9 +153,9 @@ void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a } } -TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn") +TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn") .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm); -TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn") +TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn") .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm); } // namespace runtime diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu index 485929570592..4ee31e73abca 100644 --- a/src/runtime/contrib/cutlass/fp8_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -20,8 +20,8 @@ #include #include #include -#include -#include +#include +#include #include "../cublas/cublas_utils.h" #include "gemm_runner.cuh" @@ -77,15 +77,15 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray } } -TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16") .set_body_typed( tvm_cutlass_fp8_gemm); -TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_gemm); -TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_gemm); diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu b/src/runtime/contrib/cutlass/fp8_group_gemm.cu index fd528a22cc1a..62a91dec1809 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm.cu @@ -20,8 +20,8 @@ #include #include #include -#include -#include +#include +#include #include "group_gemm_runner.cuh" @@ -66,15 +66,15 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr static_cast(out->data), stream); } -TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16") .set_body_typed( tvm_cutlass_fp8_group_gemm); -TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_group_gemm); -TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16") +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16") .set_body_typed( tvm_cutlass_fp8_group_gemm); diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc index 5fded82762a3..5fece6166158 100644 --- a/src/runtime/contrib/cutlass/weight_preprocess.cc +++ b/src/runtime/contrib/cutlass/weight_preprocess.cc @@ -17,9 +17,8 @@ * under the License. */ +#include #include -#include -#include #include "cutlass_kernels/cutlass_preprocessors.h" @@ -35,7 +34,7 @@ namespace runtime { // black box. // // The preprocessing functions are defined in C++, so we need to copy the input weight to CPU. -TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight") +TVM_FFI_REGISTER_GLOBAL("cutlass.ft_preprocess_weight") .set_body_typed([](NDArray packed_weight, int sm, bool is_int4) { bool is_2d = packed_weight->ndim == 2; int num_experts = is_2d ? 1 : packed_weight->shape[0]; diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 3e73b19116ee..9cc053ec7ca4 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -348,7 +348,7 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ } // DNNL Conv2d single OP -TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto weights = args[1].cast(); diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index b06b17c17d8e..154ee12790f7 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for DNNL. */ +#include #include -#include #include #include @@ -927,9 +927,9 @@ runtime::Module DNNLJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.DNNLJSONRuntimeCreate").set_body_typed(DNNLJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_dnnl_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index 04e06d9c9e94..f12467a67e64 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -25,9 +25,9 @@ #ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ -#include +#include +#include #include -#include #include diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 2a2462786327..5d706836e6ce 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include namespace tvm { namespace runtime { @@ -68,7 +68,7 @@ Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, Device dev) { return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") +TVM_FFI_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = EdgeTPURuntimeCreate(args[0], args[1]); }); diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index c85f15cc743a..fb4e394e7fc2 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -20,9 +20,9 @@ /*! * \file Use external hipblas library call. */ +#include #include #include -#include #include "../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" @@ -300,8 +300,8 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl) << "leading dimension must divide 4 for int8 gemm"; ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - double alpha = args.size() > 5 ? args[5] : 1.0; - double beta = args.size() > 6 ? args[6] : 0.0; + double alpha = args.size() > 5 ? args[5].cast() : 1.0; + double beta = args.size() > 6 ? args[6].cast() : 0.0; hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype); hipblasDatatype_t hip_out_type = GetHipBlasDataType(C->dtype); @@ -359,8 +359,8 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t << "leading dimension must divide 4 for int8 gemm"; ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - double alpha = args.size() > 5 ? args[5] : 1.0; - double beta = args.size() > 6 ? args[6] : 0.0; + double alpha = args.size() > 5 ? args[5].cast() : 1.0; + double beta = args.size() > 6 ? args[6].cast() : 0.0; int A_stride = A->shape[1] * A->shape[2]; int B_stride = B->shape[1] * B->shape[2]; @@ -407,7 +407,7 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t } // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); @@ -430,7 +430,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.matmul") } }); -TVM_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hipblas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto C = args[2].cast(); diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 2cd1223bc654..60e439125c10 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -22,8 +22,8 @@ * \brief A simple JSON runtime for HIPBLAS. */ +#include #include -#include #include #include @@ -72,12 +72,10 @@ class HipblasJSONRuntime : public JSONRuntimeBase { for (size_t i = 0; i < static_cast(args.size()); i++) { auto eid = i < input_var_eid_.size() ? input_var_eid_[i] : EntryID(outputs_[i - input_var_eid_.size()]); - ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle) - << "Expect NDArray or DLTensor as inputs"; const DLTensor* arg; - if (args[i].IsObjectRef()) { - NDArray arr = args[i]; + if (auto opt_nd = args[i].as()) { + NDArray arr = opt_nd.value(); arg = arr.operator->(); } else { arg = args[i].cast(); @@ -141,9 +139,10 @@ runtime::Module HipblasJSONRuntimeCreate(String symbol_name, String graph_json, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.HipblasJSONRuntimeCreate").set_body_typed(HipblasJSONRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.HipblasJSONRuntimeCreate") + .set_body_typed(HipblasJSONRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hipblas_json") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hipblas_json") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc index 02d91646518c..6facbb232b2c 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -23,7 +23,7 @@ #include "hipblas_utils.h" #include -#include +#include #include "../../rocm/rocm_common.h" diff --git a/src/runtime/contrib/json/json_node.h b/src/runtime/contrib/json/json_node.h index 88834a7c01ea..b0f2bb582142 100644 --- a/src/runtime/contrib/json/json_node.h +++ b/src/runtime/contrib/json/json_node.h @@ -149,7 +149,7 @@ class JSONGraphNode { reader->Read(&tmp); ICHECK(!reader->NextArrayItem()); for (const auto& it : tmp) { - dtype_.push_back(tvm::runtime::StringToDLDataType(it)); + dtype_.push_back(ffi::StringToDLDataType(it)); } } else if (key == "shape") { reader->BeginArray(); diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index 3f42e109f839..025e85263ebc 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -124,7 +124,7 @@ class JSONRuntimeBase : public ModuleNode { // Bind argument tensors to data entries. this->SetInputOutputBuffers(args); - if (auto opt_str = rv->as()) { + if (auto opt_str = rv->try_cast()) { String purpose = std::move(opt_str.value()); if ("debug_dump" == purpose) { *rv = this->DebugDump(); diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index 19eec4a0a026..247863c56a99 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -20,9 +20,9 @@ /*! * \file Use external miopen utils function */ +#include #include #include -#include #include @@ -34,7 +34,7 @@ namespace miopen { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const int mode = args[0].cast(); const int dtype = args[1].cast(); @@ -148,7 +148,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ret[0] = static_cast(best_algo); }); -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const int mode = args[0].cast(); const int dtype = args[1].cast(); diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index b750e56c7e81..bb091fdf7aa1 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -23,7 +23,7 @@ #include "miopen_utils.h" #include -#include +#include #include #include diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc index 021d0387defb..10289f22bdda 100644 --- a/src/runtime/contrib/miopen/softmax.cc +++ b/src/runtime/contrib/miopen/softmax.cc @@ -21,8 +21,8 @@ * \file src/runtime/contrib/miopen/softmax.cc * \brief Use external miopen softmax function */ +#include #include -#include #include "miopen_utils.h" @@ -79,12 +79,12 @@ void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t entry_ptr->softmax_entry.shape_desc, y->data, alg, mode)); } -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_ACCURATE); }); -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.miopen.log_softmax.forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { softmax_impl(args, ret, MIOPEN_SOFTMAX_LOG); }); diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index 4200477b2713..dbbb92dd05f7 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -24,7 +24,7 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto buf = args[0].cast(); auto img = args[1].cast(); @@ -57,7 +57,7 @@ imageIndex:0]; }); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto img = args[0].cast(); auto buf = args[1].cast(); @@ -76,7 +76,7 @@ buf -> dtype, nullptr); }); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { // MPS-NHWC auto data = args[0].cast(); diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index 77eb6dd03dd3..51285251c82e 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -24,7 +24,7 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.mps.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); diff --git a/src/runtime/contrib/mps/mps_utils.h b/src/runtime/contrib/mps/mps_utils.h index c2b7e3c7aa99..1dd1a2c1e3fc 100644 --- a/src/runtime/contrib/mps/mps_utils.h +++ b/src/runtime/contrib/mps/mps_utils.h @@ -26,10 +26,10 @@ #import #include +#include #include #include #include -#include #include diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index 01cfb385c7f5..3b21ba0e5dc5 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -23,9 +23,9 @@ */ #include +#include #include #include -#include #include #include @@ -476,9 +476,9 @@ bool MarvellHardwareModuleNode::use_dpdk_cb = false; ml_tvmc_cb MarvellHardwareModuleNode::tvmc_cb_ = {}; ml_dpdk_cb MarvellHardwareModuleNode::dpdk_cb_ = {}; -TVM_REGISTER_GLOBAL("runtime.mrvl_hw_runtime_create") +TVM_FFI_REGISTER_GLOBAL("runtime.mrvl_hw_runtime_create") .set_body_typed(MarvellHardwareModuleRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_hw") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_hw") .set_body_typed(MarvellHardwareModuleNode::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index 186cc3b3a859..701ae6ed8dcd 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -24,9 +24,9 @@ #include #include +#include #include #include -#include #include #include @@ -149,9 +149,9 @@ runtime::Module MarvellSimulatorModuleRuntimeCreate(const String& symbol_name, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.mrvl_runtime_create") +TVM_FFI_REGISTER_GLOBAL("runtime.mrvl_runtime_create") .set_body_typed(MarvellSimulatorModuleRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_sim") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_mrvl_sim") .set_body_typed(MarvellSimulatorModuleNode::LoadFromBinary); } // namespace contrib } // namespace runtime diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc index b062a50dccb5..c63bafcd0089 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc @@ -25,8 +25,8 @@ #include "mrvl_sw_runtime_lib.h" #include +#include #include -#include #include #include diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h index c5242a4e7bde..10114802b4bf 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.h @@ -25,7 +25,7 @@ #ifndef TVM_RUNTIME_CONTRIB_MRVL_MRVL_SW_RUNTIME_LIB_H_ #define TVM_RUNTIME_CONTRIB_MRVL_MRVL_SW_RUNTIME_LIB_H_ -#include +#include #include #include diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 7ddbcb34ad02..8819cfd2fc4a 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -23,8 +23,8 @@ */ #include +#include #include -#include #include #include @@ -348,9 +348,10 @@ runtime::Module MSCTensorRTRuntimeCreate(const String& symbol_name, const String return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.msc_tensorrt_runtime_create").set_body_typed(MSCTensorRTRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.msc_tensorrt_runtime_create") + .set_body_typed(MSCTensorRTRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_msc_tensorrt") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_msc_tensorrt") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/mscclpp/allreduce.cu b/src/runtime/contrib/mscclpp/allreduce.cu index 7ead504340be..66a6a097f650 100644 --- a/src/runtime/contrib/mscclpp/allreduce.cu +++ b/src/runtime/contrib/mscclpp/allreduce.cu @@ -18,8 +18,8 @@ */ #include -#include -#include +#include +#include #include "msccl.cuh" diff --git a/src/runtime/contrib/nnapi/nnapi_ops.cc b/src/runtime/contrib/nnapi/nnapi_ops.cc index 8e2da6489fde..bea3712fab3e 100644 --- a/src/runtime/contrib/nnapi/nnapi_ops.cc +++ b/src/runtime/contrib/nnapi/nnapi_ops.cc @@ -275,7 +275,7 @@ void CastOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& n const auto dtype_attr = node.GetAttr>("astype_dtype"); ICHECK(dtype_attr.size() == 1); const auto dtype_str = dtype_attr[0]; - const DLDataType dtype = runtime::StringToDLDataType(dtype_str); + const DLDataType dtype = StringToDLDataType(dtype_str); ICHECK(outputs.size() == 1); const auto output_tensor_type = outputs[0].GetTensorType(); ICHECK(TensorTypeFromDLDataType(dtype) == output_tensor_type) diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index c63098873da1..0fcf9fded0a8 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include #include @@ -240,9 +240,9 @@ runtime::Module NNAPIRuntimeCreate(const String& symbol_name, const String& grap return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.nnapi_runtime_create").set_body_typed(NNAPIRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.nnapi_runtime_create").set_body_typed(NNAPIRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_nnapi") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_nnapi") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 33a787b5b9f3..7b4a617a2501 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -19,16 +19,15 @@ #include #include #include +#include #include -#include -#include #include "../../cuda/cuda_common.h" namespace tvm { namespace runtime { -ShapeTuple InitNVSHMEMUID() { +ffi::Shape InitNVSHMEMUID() { nvshmemx_uniqueid_t uid; nvshmemx_get_uniqueid(&uid); std::vector uid_64; @@ -36,10 +35,10 @@ ShapeTuple InitNVSHMEMUID() { for (int i = 0; i < UNIQUEID_PADDING; ++i) { uid_64.push_back(static_cast(uid.internal[i])); } - return ShapeTuple(uid_64); + return ffi::Shape(uid_64); } -void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) { +void InitNVSHMEM(ffi::Shape uid_64, int num_workers, int worker_id_start) { DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker; int worker_id; if (worker == nullptr) { @@ -99,7 +98,7 @@ void InitNVSHMEMWrapper(String args) { uid_vector.push_back(elem.get()); } - ShapeTuple uid_64(uid_vector); + ffi::Shape uid_64(uid_vector); int num_workers = static_cast(obj["npes"].get()); int worker_id_start = static_cast(obj["pe_start"].get()); @@ -107,11 +106,11 @@ void InitNVSHMEMWrapper(String args) { InitNVSHMEM(uid_64, num_workers, worker_id_start); } -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper") .set_body_typed(InitNVSHMEMWrapper); } // namespace runtime diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu index cf3a9958f895..2dad73707df7 100644 --- a/src/runtime/contrib/nvshmem/kv_transfer.cu +++ b/src/runtime/contrib/nvshmem/kv_transfer.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include template __device__ int64_t calc_flattened_index(int shape[dim], int index[dim]) { @@ -329,5 +329,5 @@ int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages, return 0; } -TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer); -TVM_REGISTER_GLOBAL("nvshmem.KVTransferPageToPage").set_body_typed(_KVTransferPageToPage); +TVM_FFI_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer); +TVM_FFI_REGISTER_GLOBAL("nvshmem.KVTransferPageToPage").set_body_typed(_KVTransferPageToPage); diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc index 4380c7e65dd8..facfc9521741 100644 --- a/src/runtime/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -18,9 +18,8 @@ */ #include #include +#include #include -#include -#include #include @@ -57,7 +56,7 @@ class NVSHMEMAllocator final : public PooledAllocator { return allocator; } - NDArray Empty(ShapeTuple shape, DataType dtype, Device device) { + NDArray Empty(ffi::Shape shape, DataType dtype, Device device) { NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, device); container->SetDeleter([](Object* obj) { auto* ptr = static_cast(obj); @@ -88,18 +87,18 @@ class NVSHMEMAllocator final : public PooledAllocator { void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); } }; -NDArray NVSHMEMEmpty(ShapeTuple shape, DataType dtype, Device device) { +NDArray NVSHMEMEmpty(ffi::Shape shape, DataType dtype, Device device) { return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.empty").set_body_typed(NVSHMEMEmpty); void NVSHMEMFinalize() { NVSHMEMAllocator::Global()->Clear(); nvshmem_finalize(); } -TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.finalize_nvshmem").set_body_typed(NVSHMEMFinalize); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc index 9f98890b93ac..882cee36b246 100644 --- a/src/runtime/contrib/papi/papi.cc +++ b/src/runtime/contrib/papi/papi.cc @@ -290,7 +290,7 @@ MetricCollector CreatePAPIMetricCollector(Map> metr TVM_REGISTER_OBJECT_TYPE(PAPIEventSetNode); TVM_REGISTER_OBJECT_TYPE(PAPIMetricCollectorNode); -TVM_REGISTER_GLOBAL("runtime.profiling.PAPIMetricCollector") +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.PAPIMetricCollector") .set_body_typed([](Map> metrics) { return PAPIMetricCollector(metrics); }); diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index ed4e1a3fad38..8f05a7241b02 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -21,9 +21,9 @@ * \file External random functions for tensor. */ #include +#include #include #include -#include #include #include @@ -69,7 +69,7 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { return RandomThreadLocalStore::Get(); } -TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.randint") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); int64_t low = args[0].cast(); @@ -103,7 +103,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") }) }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.uniform") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); double low = args[0].cast(); @@ -112,7 +112,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") entry->random_engine.SampleUniform(out, low, high); }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.normal") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); double loc = args[0].cast(); @@ -121,14 +121,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") entry->random_engine.SampleNormal(out, loc, scale); }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.random_fill") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); auto out = args[0].cast(); entry->random_engine.RandomFill(out); }); -TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) -> void { const auto curand = tvm::ffi::Function::GetGlobal("runtime.contrib.curand.RandomFill"); auto out = args[0].cast(); diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 88c6071e1efd..2969d7fd0e5e 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -23,9 +23,9 @@ #include "rocblas.h" #include +#include #include #include -#include namespace tvm { namespace contrib { @@ -65,7 +65,7 @@ struct RocBlasThreadEntry { typedef dmlc::ThreadLocalStore RocBlasThreadStore; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); @@ -103,7 +103,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") ldc)); }); -TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); auto B = args[1].cast(); diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index f413af696661..62639e684055 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -22,7 +22,9 @@ */ #include -#include +#include +#include +#include #include #include @@ -77,7 +79,7 @@ struct float16 { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto sort_num = args[1].cast(); @@ -216,7 +218,7 @@ void sort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.argsort") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto output = args[1].cast(); @@ -229,8 +231,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") "input ndim " << input->ndim; - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = DLDataTypeToString(output->dtype); + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = ffi::DLDataTypeToString(output->dtype); if (data_dtype == "float32") { if (out_dtype == "int32") { @@ -312,7 +314,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.sort") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.sort") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); auto output = args[1].cast(); @@ -442,7 +444,7 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.sort.topk") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto input = args[0].cast(); DLTensor* values_out = nullptr; @@ -467,8 +469,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") ICHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = (indices_out == nullptr) ? "int64" : DLDataTypeToString(indices_out->dtype); + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = + (indices_out == nullptr) ? "int64" : ffi::DLDataTypeToString(indices_out->dtype); if (data_dtype == "float32") { if (out_dtype == "int32") { diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index e1f205e22f10..a8bd43127258 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -23,8 +23,8 @@ */ #include +#include #include -#include #include #include @@ -524,9 +524,9 @@ runtime::Module TensorRTRuntimeCreate(const String& symbol_name, const String& g return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.tensorrt_runtime_create").set_body_typed(TensorRTRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.tensorrt_runtime_create").set_body_typed(TensorRTRuntimeCreate); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_tensorrt") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_tensorrt") .set_body_typed(JSONRuntimeBase::LoadFromBinary); } // namespace contrib diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 990475069574..74cfcad3e650 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include namespace tvm { namespace runtime { @@ -183,11 +183,11 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") +TVM_FFI_REGISTER_GLOBAL("tvm.tflite_runtime.create") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = TFLiteRuntimeCreate(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); +TVM_FFI_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index 6557fa07975e..5e8751a01281 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -27,8 +27,9 @@ #include #include +#include +#include #include -#include #include #include diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index aa1befaeef32..6b6b9df834ab 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -31,7 +31,8 @@ #include #include #include -#include +#include +#include #include #include @@ -232,25 +233,25 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") -.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK_GE(args.num_args, 4); - auto input = args[0].cast(); - auto values_out = args[1].cast(); - auto indices_out = args[2].cast(); - bool is_ascend = args[3].cast(); - DLTensor* workspace = nullptr; - if (args.num_args == 5) { - workspace = args[4]; - } +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + ICHECK_GE(args.size(), 4); + auto input = args[0].cast(); + auto values_out = args[1].cast(); + auto indices_out = args[2].cast(); + bool is_ascend = args[3].cast(); + DLTensor* workspace = nullptr; + if (args.size() == 5) { + workspace = args[4].cast(); + } - auto data_dtype = DLDataTypeToString(input->dtype); - auto out_dtype = DLDataTypeToString(indices_out->dtype); + auto data_dtype = ffi::DLDataTypeToString(input->dtype); + auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype); - int n_values = input->shape[input->ndim - 1]; - thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype, - workspace); -}); + int n_values = input->shape[input->ndim - 1]; + thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype, + workspace); + }); template void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out, @@ -279,21 +280,21 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* thrust::stable_sort_by_key(policy, keys_out_ptr, keys_out_ptr + size, values_out_ptr); } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK_GE(args.num_args, 5); + ICHECK_GE(args.size(), 5); auto keys_in = args[0].cast(); auto values_in = args[1].cast(); auto keys_out = args[2].cast(); auto values_out = args[3].cast(); bool for_scatter = args[4].cast(); DLTensor* workspace = nullptr; - if (args.num_args == 6) { - workspace = args[5]; + if (args.size() == 6) { + workspace = args[5].cast(); } - auto key_dtype = DLDataTypeToString(keys_in->dtype); - auto value_dtype = DLDataTypeToString(values_in->dtype); + auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype); + auto value_dtype = ffi::DLDataTypeToString(values_in->dtype); if (key_dtype == "int32") { if (value_dtype == "int32") { @@ -394,83 +395,83 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor } } -TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") -.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4); - auto data = args[0].cast(); - auto output = args[1].cast(); - bool exclusive = false; - DLTensor* workspace = nullptr; - - if (args.num_args >= 3) { - exclusive = args[2]; - } +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4); + auto data = args[0].cast(); + auto output = args[1].cast(); + bool exclusive = false; + DLTensor* workspace = nullptr; - if (args.num_args == 4) { - workspace = args[3]; - } + if (args.size() >= 3) { + exclusive = args[2].cast(); + } - auto in_dtype = DLDataTypeToString(data->dtype); - auto out_dtype = DLDataTypeToString(output->dtype); + if (args.size() == 4) { + workspace = args[3].cast(); + } - if (in_dtype == "bool") { - if (out_dtype == "int32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "int64") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32, and float64"; - } - } else if (in_dtype == "int32") { - if (out_dtype == "int32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "int64") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int32, int64, float32, and float64"; - } - } else if (in_dtype == "int64") { - if (out_dtype == "int64") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are int64, float32, and float64"; - } - } else if (in_dtype == "float32") { - if (out_dtype == "float32") { - thrust_scan(data, output, exclusive, workspace); - } else if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtypes are float32, and float64"; - } - } else if (in_dtype == "float64") { - if (out_dtype == "float64") { - thrust_scan(data, output, exclusive, workspace); - } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype - << ". Supported output dtype is float64"; - } - } else { - LOG(FATAL) << "Unsupported input dtype: " << in_dtype - << ". Supported input dtypes are bool, int32, int64, float32, and float64"; - } -}); + auto in_dtype = ffi::DLDataTypeToString(data->dtype); + auto out_dtype = ffi::DLDataTypeToString(output->dtype); + + if (in_dtype == "bool") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; + } + } else if (in_dtype == "int32") { + if (out_dtype == "int32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "int64") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int32, int64, float32, and float64"; + } + } else if (in_dtype == "int64") { + if (out_dtype == "int64") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are int64, float32, and float64"; + } + } else if (in_dtype == "float32") { + if (out_dtype == "float32") { + thrust_scan(data, output, exclusive, workspace); + } else if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtypes are float32, and float64"; + } + } else if (in_dtype == "float64") { + if (out_dtype == "float64") { + thrust_scan(data, output, exclusive, workspace); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype + << ". Supported output dtype is float64"; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << in_dtype + << ". Supported input dtypes are bool, int32, int64, float32, and float64"; + } + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/vllm/attention_kernels.cu b/src/runtime/contrib/vllm/attention_kernels.cu index 2b59044f844c..9221f4672511 100644 --- a/src/runtime/contrib/vllm/attention_kernels.cu +++ b/src/runtime/contrib/vllm/attention_kernels.cu @@ -19,8 +19,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -735,7 +735,7 @@ void single_query_cached_kv_attention_v2( } } -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") .set_body_typed([](const DLTensor* query, const DLTensor* key_cache, const DLTensor* value_cache, const DLTensor* block_tables, const DLTensor* context_lens, int block_size, @@ -759,10 +759,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention") }); // Expose for testing -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v1") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v1") .set_body_typed(single_query_cached_kv_attention_v1); -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v2") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.single_query_cached_kv_attention_v2") .set_body_typed(single_query_cached_kv_attention_v2); } // namespace runtime diff --git a/src/runtime/contrib/vllm/cache_alloc.cc b/src/runtime/contrib/vllm/cache_alloc.cc index aea50aa47a5c..dd2b7bd5bb37 100644 --- a/src/runtime/contrib/vllm/cache_alloc.cc +++ b/src/runtime/contrib/vllm/cache_alloc.cc @@ -17,8 +17,8 @@ * under the License. */ #include +#include #include -#include namespace tvm { namespace runtime { @@ -48,7 +48,7 @@ Array AllocateKVCache(int head_size, int num_layers, int num_heads, int return cache; } -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache); +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.allocate_kv_cache").set_body_typed(AllocateKVCache); } // namespace vllm } // namespace runtime diff --git a/src/runtime/contrib/vllm/cache_kernels.cu b/src/runtime/contrib/vllm/cache_kernels.cu index b53cd094c1aa..01320daac650 100644 --- a/src/runtime/contrib/vllm/cache_kernels.cu +++ b/src/runtime/contrib/vllm/cache_kernels.cu @@ -17,8 +17,8 @@ * under the License. */ #include -#include -#include +#include +#include #include #include @@ -130,7 +130,7 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, int64_t* value_cache namespace tvm { namespace runtime { -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") .set_body_typed([](NDArray key, NDArray value, NDArray key_cache, NDArray value_cache, NDArray slot_mapping) { int num_tokens = key->shape[0]; @@ -155,7 +155,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reshape_and_cache") return Array{key_cache, value_cache}; }); -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache") .set_body_typed([](NDArray key_cache, NDArray value_cache, NDArray slot_mapping) { int num_tokens = slot_mapping->shape[0]; int num_heads = value_cache->shape[1]; @@ -184,7 +184,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.vllm.reconstruct_from_cache") return Array{key, value}; }); -TVM_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.vllm.copy_blocks") .set_body_typed([](Array key_value_caches, NDArray block_mapping) { auto num_layers = key_value_caches.size() / 2; auto num_pairs = block_mapping->shape[0] / 2; diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index df2271e64732..68594f0769fe 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -21,9 +21,9 @@ * \file cpu_device_api.cc */ #include +#include #include #include -#include #include #include @@ -150,7 +150,7 @@ void CPUDeviceAPI::FreeWorkspace(Device dev, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, data); } -TVM_REGISTER_GLOBAL("device_api.cpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.cpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CPUDeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index 037dd9426209..a378e53c54a5 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -25,7 +25,7 @@ #define TVM_RUNTIME_CUDA_CUDA_COMMON_H_ #include -#include +#include #include diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 1dc928e77801..399312e19321 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -24,9 +24,9 @@ #include #include #include +#include #include #include -#include #include @@ -286,15 +286,16 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {} CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = CUDADeviceAPI::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.cuda_host") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global(); + *rv = static_cast(ptr); + }); class CUDATimerNode : public TimerNode { public: @@ -329,7 +330,7 @@ class CUDATimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(CUDATimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.cuda").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.cuda").set_body_typed([](Device dev) { return Timer(make_object()); }); @@ -342,9 +343,9 @@ TVM_DLL String GetCudaFreeMemory() { return ss.str(); } -TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); +TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); -TVM_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() { return static_cast(CUDAThreadEntry::ThreadLocal()->stream); }); @@ -354,7 +355,7 @@ TVM_DLL int GetCudaDeviceCount() { return count; } -TVM_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount); +TVM_FFI_REGISTER_GLOBAL("runtime.GetCudaDeviceCount").set_body_typed(GetCudaDeviceCount); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index db01b76cb531..acb2dc6cdf11 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include @@ -290,10 +290,10 @@ Module CUDAModuleLoadBinary(void* strm) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index ae7c057be0cc..726df80de8bc 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -19,8 +19,8 @@ #include "../../../3rdparty/nvbench/l2_cache_flush.h" #include +#include #include -#include #include "cuda_common.h" @@ -32,11 +32,12 @@ typedef dmlc::ThreadLocalStore L2FlushStore; L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); } -TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; - cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; - L2Flush::ThreadLocal()->Flush(stream); -}); +TVM_FFI_REGISTER_GLOBAL("l2_cache_flush_cuda") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; + cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream; + L2Flush::ThreadLocal()->Flush(stream); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc index 570c4cb4cb8d..4b22e2649462 100644 --- a/src/runtime/debug_compile.cc +++ b/src/runtime/debug_compile.cc @@ -22,23 +22,23 @@ * \brief File used for debug migration */ // #include +#include +#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include -#include -#include // #include // #include -// #include +// #include // #include // #include -// #include -// #include +// #include +// #include // #include namespace tvm { @@ -46,7 +46,7 @@ namespace debug { using namespace tvm::runtime; -// TVM_REGISTER_GLOBAL("tvm.debug.Test").set_body_typed([](PrimExpr value) { +// TVM_FFI_REGISTER_GLOBAL("tvm.debug.Test").set_body_typed([](PrimExpr value) { // LOG(INFO) << value; // return value; // }); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc new file mode 100644 index 000000000000..3e3145c32f5c --- /dev/null +++ b/src/runtime/device_api.cc @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file device_api.cc + * \brief Device specific implementations + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +class DeviceAPIManager { + public: + static const int kMaxDeviceAPI = TVMDeviceExtType_End; + // Get API + static DeviceAPI* Get(const Device& dev) { return Get(dev.device_type); } + static DeviceAPI* Get(int dev_type, bool allow_missing = false) { + return Global()->GetAPI(dev_type, allow_missing); + } + + private: + std::array api_; + DeviceAPI* rpc_api_{nullptr}; + std::mutex mutex_; + // constructor + DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } + // Global static variable. + static DeviceAPIManager* Global() { + static DeviceAPIManager* inst = new DeviceAPIManager(); + return inst; + } + // Get or initialize API. + DeviceAPI* GetAPI(int type, bool allow_missing) { + if (type < kRPCSessMask) { + if (api_[type] != nullptr) return api_[type]; + std::lock_guard lock(mutex_); + if (api_[type] != nullptr) return api_[type]; + api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing); + return api_[type]; + } else { + if (rpc_api_ != nullptr) return rpc_api_; + std::lock_guard lock(mutex_); + if (rpc_api_ != nullptr) return rpc_api_; + rpc_api_ = GetAPI("rpc", allow_missing); + return rpc_api_; + } + } + DeviceAPI* GetAPI(const std::string name, bool allow_missing) { + std::string factory = "device_api." + name; + const auto f = tvm::ffi::Function::GetGlobal(factory); + if (!f.has_value()) { + ICHECK(allow_missing) << "Device API " << name << " is not enabled."; + return nullptr; + } + void* ptr = (*f)().cast(); + return static_cast(ptr); + } +}; + +DeviceAPI* DeviceAPI::Get(Device dev, bool allow_missing) { + return DeviceAPIManager::Get(static_cast(dev.device_type), allow_missing); +} + +void* DeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { + return AllocDataSpace(dev, size, kTempAllocaAlignment, type_hint); +} + +static size_t GetDataAlignment(const DLDataType dtype) { + size_t align = (dtype.bits / 8) * dtype.lanes; + if (align < kAllocAlignment) return kAllocAlignment; + return align; +} + +size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional mem_scope) { + if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") { + size_t size = 1; + for (int i = 0; i < arr.ndim; ++i) { + size *= static_cast(arr.shape[i]); + } + size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; + return size; + } + LOG(FATAL) << "Device does not support physical mem computation with " + << "specified memory scope: " << mem_scope.value(); + return 0; +} + +void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + Optional mem_scope) { + if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == "global") { + // by default, we can always redirect to the flat memory allocations + DLTensor temp; + temp.data = nullptr; + temp.device = dev; + temp.ndim = ndim; + temp.dtype = dtype; + temp.shape = const_cast(shape); + temp.strides = nullptr; + temp.byte_offset = 0; + size_t size = GetDataSize(temp); + size_t alignment = GetDataAlignment(temp.dtype); + return AllocDataSpace(dev, size, alignment, dtype); + } + LOG(FATAL) << "Device does not support allocate data space with " + << "specified memory scope: " << mem_scope.value(); + return nullptr; +} + +void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { + // by default, we can always redirect to the flat memory copy operation. + size_t nbytes = GetDataSize(*from); + ICHECK_EQ(nbytes, GetDataSize(*to)); + + ICHECK(ffi::IsContiguous(*from) && ffi::IsContiguous(*to)) + << "CopyDataFromTo only support contiguous array for now"; + CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device, + to->device, from->dtype, stream); +} + +void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, Device dev_from, Device dev_to, + DLDataType type_hint, TVMStreamHandle stream) { + LOG(FATAL) << "Device does not support CopyDataFromTo."; +} + +void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } + +TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } + +void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} + +TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; } + +void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { +} + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamCreate").set_body_typed([](DLDevice dev) { + return reinterpret_cast(DeviceAPIManager::Get(dev)->CreateStream(dev)); +}); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamFree") + .set_body_typed([](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->FreeStream(dev, reinterpret_cast(stream)); + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_SetStream") + .set_body_typed([](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->SetStream(dev, reinterpret_cast(stream)); + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamSync") + .set_body_typed([](DLDevice dev, int64_t stream) { + DeviceAPIManager::Get(dev)->StreamSync(dev, reinterpret_cast(stream)); + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.Device_StreamSyncFromTo") + .set_body_typed([](DLDevice dev, int64_t src, int64_t dst) { + DeviceAPIManager::Get(dev)->SyncStreamFromTo(dev, reinterpret_cast(src), + reinterpret_cast(dst)); + }); + +// set device api +TVM_FFI_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + DLDevice dev; + dev.device_type = static_cast(args[0].cast()); + dev.device_id = args[1].cast(); + DeviceAPIManager::Get(dev)->SetDevice(dev); + }); + +// set device api +TVM_FFI_REGISTER_GLOBAL("runtime.GetDeviceAttr") + .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + DLDevice dev; + dev.device_type = static_cast(args[0].cast()); + dev.device_id = args[1].cast(); + + DeviceAttrKind kind = static_cast(args[2].cast()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(dev.device_type, true); + if (api != nullptr) { + api->GetAttr(dev, kind, ret); + } else { + *ret = 0; + } + } else { + DeviceAPIManager::Get(dev)->GetAttr(dev, kind, ret); + } + }); + +TVM_FFI_REGISTER_GLOBAL("runtime.TVMSetStream") + .set_body_typed([](int device_type, int device_id, void* stream) { + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + DeviceAPIManager::Get(dev)->SetStream(dev, stream); + }); +} // namespace runtime +} // namespace tvm + +using namespace tvm::runtime; + +int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFFIObjectHandle* func) { + TVM_FFI_SAFE_CALL_BEGIN(); + *func = const_cast( + static_cast(mod_node)->GetFuncFromEnv(func_name)->get()); + TVM_FFI_SAFE_CALL_END(); +} + +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, + int dtype_bits_hint) { + DLDevice dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + DLDataType type_hint; + type_hint.code = static_cast(dtype_code_hint); + type_hint.bits = static_cast(dtype_bits_hint); + type_hint.lanes = 1; + + return DeviceAPIManager::Get(dev)->AllocWorkspace(dev, static_cast(size), type_hint); +} + +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { + DLDevice dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + DeviceAPIManager::Get(dev)->FreeWorkspace(dev, ptr); + return 0; +} + +int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { + if (*handle == nullptr) { + *handle = reinterpret_cast(1); + return (*f)(cdata); + } + return 0; +} diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 81f10190e00b..46ecb49f50fc 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -18,9 +18,8 @@ */ #include "./bcast_session.h" +#include #include -#include -#include #include @@ -32,7 +31,7 @@ struct BcastSessionObj::Internal { static void TVM_ALWAYS_INLINE BroadcastUnpacked(BcastSessionObj* self, DiscoAction action, int64_t reg_id, Args&&... args) { constexpr int kNumArgs = 2 + sizeof...(Args); - AnyView packed_args[kNumArgs]; + ffi::AnyView packed_args[kNumArgs]; ffi::PackedArgs::Fill(packed_args, static_cast(action), reg_id, std::forward(args)...); self->BroadcastPacked(ffi::PackedArgs(packed_args, kNumArgs)); @@ -68,7 +67,7 @@ void BcastSessionObj::Shutdown() { BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kShutDown, 0); } -void BcastSessionObj::InitCCL(String ccl, IntTuple device_ids) { +void BcastSessionObj::InitCCL(String ccl, ffi::Shape device_ids) { const auto pf = tvm::ffi::Function::GetGlobal("runtime.disco." + ccl + ".init_ccl"); CHECK(pf.has_value()) << "ValueError: Cannot initialize CCL `" << ccl << "`, because cannot find function: runtime.disco." << ccl << ".init_ccl"; @@ -88,7 +87,7 @@ void BcastSessionObj::SyncWorker(int worker_id) { DRef BcastSessionObj::CallWithPacked(const ffi::PackedArgs& args) { // NOTE: this action is not safe unless we know args is not // used else where in this case it is oK - AnyView* args_vec = const_cast(args.data()); + ffi::AnyView* args_vec = const_cast(args.data()); // tranlsate args into remote calling convention int reg_id = AllocateReg(); { diff --git a/src/runtime/disco/bcast_session.h b/src/runtime/disco/bcast_session.h index bfb1ca24b565..f92369d85337 100644 --- a/src/runtime/disco/bcast_session.h +++ b/src/runtime/disco/bcast_session.h @@ -43,7 +43,7 @@ class BcastSessionObj : public SessionObj { void Shutdown() override; void InitCCL(String ccl, IntTuple device_ids) override; ffi::Any DebugGetFromRemote(int64_t reg_id, int worker_id) override = 0; - void DebugSetRegister(int64_t reg_id, AnyView value, int worker_id) override = 0; + void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) override = 0; protected: /*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */ diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index a58c840ea325..7c769b7dd081 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -17,12 +17,11 @@ * under the License. */ #include -#include +#include +#include #include #include #include -#include -#include #include #include @@ -66,7 +65,7 @@ Module LoadVMModule(std::string path, Device device) { return mod; } -NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device) { +NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device) { return NDArray::Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } @@ -121,53 +120,55 @@ void SyncWorker() { } } -TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule); -TVM_REGISTER_GLOBAL("runtime.disco.empty") - .set_body_typed([](ShapeTuple shape, DataType dtype, Device device, bool worker0_only, +TVM_FFI_REGISTER_GLOBAL("runtime.disco.empty") + .set_body_typed([](ffi::Shape shape, DataType dtype, Device device, bool worker0_only, bool in_group) -> Optional { int worker_id = WorkerId(); int group_size = DiscoWorker::ThreadLocal()->num_workers / DiscoWorker::ThreadLocal()->num_groups; bool is_worker0 = (worker_id == 0 && !in_group) || (in_group && worker_id % group_size == 0); if (worker0_only && !is_worker0) { - return NullOpt; + return std::nullopt; } else { return DiscoEmptyNDArray(shape, dtype, device); } }); -TVM_REGISTER_GLOBAL("runtime.disco.allreduce") - .set_body_typed([](NDArray send, ShapeTuple reduce_kind, bool in_group, NDArray recv) { - int kind = IntegerFromShapeTuple(reduce_kind); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.allreduce") + .set_body_typed([](NDArray send, ffi::Shape reduce_kind, bool in_group, NDArray recv) { + int kind = IntegerFromShape(reduce_kind); CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; AllReduce(send, static_cast(kind), in_group, recv); }); -TVM_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather); -TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(BroadcastFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup); -TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup); -TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker); -TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker); -TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple { - return ShapeTuple({WorkerId()}); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.allgather").set_body_typed(AllGather); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0") + .set_body_typed(BroadcastFromWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ffi::Shape { + return ffi::Shape({WorkerId()}); }); -TVM_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.worker_rank").set_body_typed([]() -> int64_t { return WorkerId(); }); -TVM_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.device").set_body_typed([]() -> Device { return DiscoWorker::ThreadLocal()->default_device; }); -TVM_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core").set_body_typed([](IntTuple cpu_ids) { - int worker_id = WorkerId(); - ICHECK_LT(worker_id, static_cast(cpu_ids.size())); - const auto f_set_thread_affinity = - tvm::ffi::Function::GetGlobalRequired("tvm.runtime.threading.set_current_thread_affinity"); - f_set_thread_affinity(IntTuple{cpu_ids[worker_id]}); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.bind_worker_to_cpu_core") + .set_body_typed([](ffi::Shape cpu_ids) { + int worker_id = WorkerId(); + ICHECK_LT(worker_id, static_cast(cpu_ids.size())); + const auto f_set_thread_affinity = tvm::ffi::Function::GetGlobalRequired( + "tvm.runtime.threading.set_current_thread_affinity"); + f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]}); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index eb22834d1a80..778ecc16e5a2 100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -18,9 +18,9 @@ */ #include +#include #include #include -#include #include "../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h" #include "../../cuda/cuda_common.h" @@ -200,7 +200,7 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { * \param dtype_hint The dtype of the storage to allocate. * \return The allocated storage object with internal CUDA IPC memory buffer. */ -memory::Storage IPCAllocStorage(ShapeTuple buffer_shape, DLDataType dtype_hint) { +memory::Storage IPCAllocStorage(ffi::Shape buffer_shape, DLDataType dtype_hint) { auto storage_obj = ffi::SimpleObjAllocator().make_object(); nccl::CCLThreadLocalContext* nccl_ctx = nccl::CCLThreadLocalContext::Get(); Device device{DLDeviceType::kDLCUDA, nccl_ctx->device_id}; @@ -212,11 +212,10 @@ memory::Storage IPCAllocStorage(ShapeTuple buffer_shape, DLDataType dtype_hint) return storage; } -TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.alloc_storage").set_body_typed(IPCAllocStorage); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.alloc_storage").set_body_typed(IPCAllocStorage); -TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear").set_body_typed([]() { - CUDAIPCMemoryAllocator::Global()->Clear(); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.cuda_ipc_memory_allocator_clear") + .set_body_typed([]() { CUDAIPCMemoryAllocator::Global()->Clear(); }); /******************** CUDAIPCMemoryObj ********************/ diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc index d969005f9476..fa7ef040f3ed 100644 --- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc @@ -18,9 +18,9 @@ */ #include +#include #include #include -#include #include "../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h" #include "../nccl/nccl_context.h" @@ -112,7 +112,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { ctx->GetDefaultStream()); } -TVM_REGISTER_GLOBAL("runtime.disco.cuda_ipc.custom_allreduce").set_body_typed(CustomAllReduce); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.cuda_ipc.custom_allreduce").set_body_typed(CustomAllReduce); } // namespace cuda_ipc } // namespace nccl diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc index 7f98feacd83b..8e63355283a8 100644 --- a/src/runtime/disco/disco_worker.cc +++ b/src/runtime/disco/disco_worker.cc @@ -16,11 +16,10 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include -#include -#include #include "../../support/process_id.h" #include "./protocol.h" @@ -34,7 +33,7 @@ TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() { return ret; } -void DiscoWorker::SetRegister(int reg_id, AnyView value) { +void DiscoWorker::SetRegister(int reg_id, ffi::AnyView value) { ICHECK(0 <= reg_id && reg_id < static_cast(register_file.size())); ffi::Any& rv = register_file.at(reg_id); if (rv.type_index() == ffi::TypeIndex::kTVMFFINDArray && @@ -95,7 +94,7 @@ struct DiscoWorker::Impl { } case DiscoAction::kDebugSetRegister: { int worker_id = args[2].cast(); - AnyView value = args[3]; + ffi::AnyView value = args[3]; DebugSetRegister(self, reg_id, worker_id, value); break; } @@ -139,7 +138,7 @@ struct DiscoWorker::Impl { static void SyncWorker(DiscoWorker* self, int worker_id) { if (worker_id == self->worker_id) { ::tvm::runtime::SyncWorker(); - AnyView packed_args[2]; + ffi::AnyView packed_args[2]; ffi::PackedArgs::Fill(packed_args, static_cast(DiscoAction::kSyncWorker), worker_id); self->channel->Reply(ffi::PackedArgs(packed_args, 2)); } @@ -151,17 +150,17 @@ struct DiscoWorker::Impl { if (rv.as()) { rv = DiscoDebugObject::Wrap(rv); } - AnyView packed_args[2]; + ffi::AnyView packed_args[2]; ffi::PackedArgs::Fill(packed_args, static_cast(DiscoAction::kDebugGetFromRemote), rv); self->channel->Reply(ffi::PackedArgs(packed_args, 2)); } } - static void DebugSetRegister(DiscoWorker* self, int reg_id, int worker_id, AnyView value) { + static void DebugSetRegister(DiscoWorker* self, int reg_id, int worker_id, ffi::AnyView value) { if (worker_id == self->worker_id) { ::tvm::runtime::SyncWorker(); self->SetRegister(reg_id, value); - AnyView packed_args[1]; + ffi::AnyView packed_args[1]; ffi::PackedArgs::Fill(packed_args, static_cast(DiscoAction::kDebugSetRegister)); self->channel->Reply(ffi::PackedArgs(packed_args, 1)); } @@ -171,7 +170,7 @@ struct DiscoWorker::Impl { const ffi::PackedArgs& args) { // NOTE: this action is not safe unless we know args is not // used else where in this case it is oK - AnyView* args_vec = const_cast(args.data()); + ffi::AnyView* args_vec = const_cast(args.data()); // translate args into remote calling convention for (int i = 0; i < args.size(); ++i) { if (auto opt_dref = args_vec[i].as()) { diff --git a/src/runtime/disco/disco_worker_thread.h b/src/runtime/disco/disco_worker_thread.h index 8d6b44396f4d..99960201b9e2 100644 --- a/src/runtime/disco/disco_worker_thread.h +++ b/src/runtime/disco/disco_worker_thread.h @@ -25,9 +25,9 @@ #ifndef TVM_RUNTIME_DISCO_DISCO_WORKER_THREAD_H_ #define TVM_RUNTIME_DISCO_DISCO_WORKER_THREAD_H_ +#include #include #include -#include #include #include diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index 9c25d4abb68e..6cd012b64e11 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include #include @@ -294,7 +294,7 @@ void RemoteSocketSessionEntryPoint(const String& server_host, int server_port, proxy.MainLoop(); } -TVM_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.RemoteSocketSession") .set_body_typed(RemoteSocketSessionEntryPoint); Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, const String& host, @@ -303,9 +303,9 @@ Session SocketSession(int num_nodes, int num_workers_per_node, int num_groups, c return Session(n); } -TVM_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SocketSession").set_body_typed(SocketSession); -TVM_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.socket_session_init_workers") .set_body_typed([](int num_nodes, int node_id, int num_groups, int num_workers_per_node) { LOG(INFO) << "Initializing worker group with " << num_nodes << " nodes, " << num_workers_per_node << " workers per node, and " << num_groups << " groups."; diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index 7a79f0c392ef..f93170d02f07 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -21,10 +21,9 @@ #define __STDC_FORMAT_MACROS #endif #include +#include #include #include -#include -#include #include #include @@ -45,7 +44,7 @@ using ParamRecord = NDArrayCacheMetadata::FileRecord::ParamRecord; struct ShardInfo { struct TensorInfo { - ShapeTuple shape; + ffi::Shape shape; DataType dtype; }; struct ShardFunc { @@ -78,7 +77,7 @@ ShardInfo::TensorInfo LoadTensorInfoFromJSON(const picojson::array& json_tensor_ shape.push_back(AsType(shape_json[i])); } std::string dtype = AsType(json_tensor_info[1]); - return ShardInfo::TensorInfo{ShapeTuple(std::move(shape)), DataType(StringToDLDataType(dtype))}; + return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), DataType(StringToDLDataType(dtype))}; } ShardInfo::ShardFunc LoadShardFuncFromJSON(const picojson::array& json_shard_func) { @@ -218,7 +217,7 @@ NDArray ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, NDArray o = NDArray::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device); ffi::Function f = this->shard_funcs_.at(shard_func.name); int n = static_cast(shard_func.params.size()); - std::vector packed_args(n + 2); + std::vector packed_args(n + 2); const DLTensor* w_in = param.operator->(); const DLTensor* w_out = o.operator->(); packed_args[0] = const_cast(w_in); @@ -226,7 +225,7 @@ NDArray ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, packed_args[i + 1] = shard_func.params[i]; } packed_args[n + 1] = const_cast(w_out); - Any rv; + ffi::Any rv; f.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv); return o; } @@ -314,13 +313,13 @@ NDArray ShardLoaderObj::Load(int weight_index) const { bool needs_sharding = !param_info.shard_info.funcs.empty(); if (needs_sharding) { - ShapeTuple shape = param_info.shard_info.funcs.back().output_info.shape; + ffi::Shape shape = param_info.shard_info.funcs.back().output_info.shape; DataType dtype = param_info.shard_info.funcs.back().output_info.dtype; ICHECK(shape.size() >= 1 && shape[0] == num_shards) << "ValueError: The first dimension of the " << "output shape must be equal to the " << "number of shards, but got: " << shape << " and num_shards = " << num_shards; - NDArray recv = NDArray::Empty(ShapeTuple(shape.begin() + 1, shape.end()), dtype, device); + NDArray recv = NDArray::Empty(ffi::Shape(shape.begin() + 1, shape.end()), dtype, device); if (worker_id == 0) { NDArray w = LoadDirect(weight_index); for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) { @@ -328,7 +327,7 @@ NDArray ShardLoaderObj::Load(int weight_index) const { } ScatterFromWorker0(w, /*in_group=*/false, recv); } else { - ScatterFromWorker0(NullOpt, /*in_group=*/false, recv); + ScatterFromWorker0(std::nullopt, /*in_group=*/false, recv); } return recv; } else { @@ -406,30 +405,31 @@ Array ShardLoaderObj::LoadAllPresharded() const { return params; } -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad") - .set_body_typed([](ObjectRef loader_obj, ShapeTuple weight_index) { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad") + .set_body_typed([](ObjectRef loader_obj, ffi::Shape weight_index) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); - return loader->Load(IntegerFromShapeTuple(weight_index)); + return loader->Load(IntegerFromShape(weight_index)); }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadPresharded") - .set_body_typed([](ObjectRef loader_obj, ShapeTuple weight_index) { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadPresharded") + .set_body_typed([](ObjectRef loader_obj, ffi::Shape weight_index) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); - return loader->LoadPresharded(IntegerFromShapeTuple(weight_index)); + return loader->LoadPresharded(IntegerFromShape(weight_index)); }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll").set_body_typed([](ObjectRef loader_obj) { - const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " - << loader_obj->GetTypeKey(); - return loader->LoadAll(); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAll") + .set_body_typed([](ObjectRef loader_obj) { + const auto* loader = loader_obj.as(); + CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " + << loader_obj->GetTypeKey(); + return loader->LoadAll(); + }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded") .set_body_typed([](ObjectRef loader_obj) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " @@ -437,7 +437,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadAllPresharded") return loader->LoadAllPresharded(); }); -TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadParamOnWorker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoadParamOnWorker0") .set_body_typed([](ObjectRef loader_obj, int param_index) { const auto* loader = loader_obj.as(); CHECK(loader != nullptr) << "TypeError: Expected ShardLoaderObj, but gets: " diff --git a/src/runtime/disco/message_queue.h b/src/runtime/disco/message_queue.h index 6b3600acbb97..e8286384ff50 100644 --- a/src/runtime/disco/message_queue.h +++ b/src/runtime/disco/message_queue.h @@ -37,31 +37,24 @@ class DiscoStreamMessageQueue : private dmlc::Stream, ~DiscoStreamMessageQueue() = default; void Send(const ffi::PackedArgs& args) { - // Run legacy ABI translation. - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); // TODO(tqchen): use native convention that do not need ABI translation. - RPCReference::ReturnPackedSeq(values.data(), type_codes.data(), args.size(), this); + RPCReference::ReturnPackedSeq(reinterpret_cast(args.data()), args.size(), + this); CommitSendAndNotifyEnqueue(); } ffi::PackedArgs Recv() { bool is_implicit_shutdown = DequeueNextPacket(); - AnyView* packed_args = nullptr; + ffi::AnyView* packed_args = nullptr; int num_args = 0; if (is_implicit_shutdown) { num_args = 2; - packed_args = reinterpret_cast(ArenaAlloc(num_args)); + packed_args = reinterpret_cast(ArenaAlloc(num_args)); packed_args[0] = static_cast(DiscoAction::kShutDown); packed_args[1] = 0; } else { - TVMValue* values = nullptr; - int* type_codes = nullptr; - RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); - packed_args = reinterpret_cast(ArenaAlloc(num_args)); - LegacyTVMArgsToPackedArgs(values, type_codes, num_args, packed_args); + RPCReference::RecvPackedSeq(reinterpret_cast(&packed_args), &num_args, this); } return ffi::PackedArgs(packed_args, num_args); } diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index 075068a336ac..8095cbeeea4a 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -52,7 +52,7 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) { throw; } -void InitCCL(Session sess, IntTuple device_ids) { +void InitCCL(Session sess, ffi::Shape device_ids) { DRef func = sess->GetGlobalFunc("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker"); DLOG(INFO) << "Initializing " TVM_DISCO_CCL_NAME " with devices: " << device_ids; ncclUniqueId id; @@ -60,7 +60,7 @@ void InitCCL(Session sess, IntTuple device_ids) { sess->CallPacked(func, device_ids, ffi::Bytes(id.internal, NCCL_UNIQUE_ID_BYTES)); } -void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { +void InitCCLPerWorker(ffi::Shape device_ids, std::string unique_id_bytes) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); DiscoWorker* worker = DiscoWorker::ThreadLocal(); ICHECK(worker != nullptr); @@ -116,7 +116,7 @@ void InitCCLPerWorker(IntTuple device_ids, std::string unique_id_bytes) { void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - ShapeTuple shape = send.Shape(); + ffi::Shape shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); DataType dtype = DataType(send->dtype); @@ -131,7 +131,7 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv void AllGather(NDArray send, bool in_group, NDArray recv) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - ShapeTuple shape = send.Shape(); + ffi::Shape shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllGather(send->data, recv->data, numel, @@ -325,41 +325,42 @@ void SyncWorker() { StreamSynchronize(stream); } -TVM_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("runtime.disco.compiled_ccl").set_body_typed([]() -> String { return TVM_DISCO_CCL_NAME; }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl").set_body_typed(InitCCL); +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker") .set_body_typed(InitCCLPerWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce") .set_body_typed([](NDArray send, int kind, bool in_group, NDArray recv) { CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; nccl::AllReduce(send, static_cast(kind), in_group, recv); }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".allgather") .set_body_typed([](NDArray send, bool in_group, NDArray recv) { nccl::AllGather(send, in_group, recv); }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".broadcast_from_worker0") .set_body_typed(BroadcastFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".scatter_from_worker0") .set_body_typed(ScatterFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0") .set_body_typed(GatherToWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0") .set_body_typed(RecvFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group") .set_body_typed(SendToNextGroup); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group") .set_body_typed(RecvFromPrevGroup); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker") .set_body_typed(SendToWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker") .set_body_typed(RecvFromWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker") + .set_body_typed(SyncWorker); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME - ".test_send_to_next_group_recv_from_prev_group") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME + ".test_send_to_next_group_recv_from_prev_group") .set_body_typed([](NDArray buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; @@ -373,7 +374,7 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME } }); -TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0") .set_body_typed([](NDArray buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index d70efebdc844..fff165bfdd04 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -21,10 +21,10 @@ #define TVM_RUNTIME_DISCO_NCCL_NCCL_CONTEXT_H_ #include -#include +#include +#include #include #include -#include #include "../../../support/process_id.h" #include "../utils.h" diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 4265bd21c43d..4563079c30b4 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -16,11 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include +#include #include #include -#include -#include #include #include @@ -70,7 +69,7 @@ class ProcessSessionObj final : public BcastSessionObj { read_fds.reserve(num_workers - 1); write_fds.reserve(num_workers - 1); for (int i = 1; i < num_workers; ++i) { - IntTuple fds = process_pool(i).cast(); + ffi::Shape fds = process_pool(i).cast(); CHECK_EQ(fds.size(), 2) << "ValueError: process_pool(" << i << ") should return a tuple of " << "size 2, but got a tuple of size " << fds.size() << "."; read_fds.push_back(fds[0]); @@ -100,7 +99,7 @@ class ProcessSessionObj final : public BcastSessionObj { return worker_0_->worker->register_file.at(reg_id); } { - AnyView packed_args[3]; + ffi::AnyView packed_args[3]; ffi::PackedArgs::Fill(packed_args, static_cast(DiscoAction::kDebugGetFromRemote), reg_id, worker_id); workers_[worker_id - 1]->Send(ffi::PackedArgs(packed_args, 3)); @@ -113,7 +112,7 @@ class ProcessSessionObj final : public BcastSessionObj { return result; } - void DebugSetRegister(int64_t reg_id, AnyView value, int worker_id) { + void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) { if (worker_id == 0) { this->SyncWorker(worker_id); worker_0_->worker->SetRegister(reg_id, value); @@ -125,7 +124,7 @@ class ProcessSessionObj final : public BcastSessionObj { value = wrapped; } { - AnyView packed_args[4]; + ffi::AnyView packed_args[4]; ffi::PackedArgs::Fill(packed_args, static_cast(DiscoAction::kDebugSetRegister), reg_id, worker_id, value); SendPacked(worker_id, ffi::PackedArgs(packed_args, 4)); @@ -197,8 +196,8 @@ void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_f worker.MainLoop(); } -TVM_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession); -TVM_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionProcess").set_body_typed(Session::ProcessSession); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.WorkerProcess").set_body_typed(WorkerProcess); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index 13fda2fecde4..30a1e6ed6609 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -21,10 +21,9 @@ #include #include -#include +#include +#include #include -#include -#include #include #include @@ -61,7 +60,7 @@ struct DiscoProtocol { inline void WriteObject(Object* obj); /*! \brief Read the object from stream. Used by RPCReference. */ - inline void ReadObject(int* tcode, TVMValue* value); + inline void ReadObject(TVMFFIAny* out); /*! \brief Callback method used when starting a new message. Used by RPCReference. */ void MessageStart(uint64_t packet_nbytes) {} @@ -124,15 +123,15 @@ template inline uint64_t DiscoProtocol::GetObjectBytes(Object* obj) { if (obj->IsInstance()) { return sizeof(uint32_t) + sizeof(int64_t); - } else if (obj->IsInstance()) { - uint64_t size = static_cast(obj)->size; + } else if (obj->IsInstance()) { + uint64_t size = static_cast(obj)->size; return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char); } else if (obj->IsInstance()) { uint64_t size = static_cast(obj)->size; return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char); - } else if (obj->IsInstance()) { - uint64_t ndim = static_cast(obj)->size; - return sizeof(uint32_t) + sizeof(uint64_t) + ndim * sizeof(ShapeTupleObj::index_type); + } else if (obj->IsInstance()) { + uint64_t ndim = static_cast(obj)->size; + return sizeof(uint32_t) + sizeof(uint64_t) + ndim * sizeof(ffi::ShapeObj::index_type); } else if (obj->IsInstance()) { return sizeof(uint32_t) + static_cast(obj)->GetObjectBytes(); } else { @@ -147,9 +146,9 @@ inline void DiscoProtocol::WriteObject(Object* obj) { int64_t reg_id = static_cast(obj)->reg_id; self->template Write(TypeIndex::kRuntimeDiscoDRef); self->template Write(reg_id); - } else if (obj->IsInstance()) { - StringObj* str = static_cast(obj); - self->template Write(TypeIndex::kRuntimeString); + } else if (obj->IsInstance()) { + ffi::StringObj* str = static_cast(obj); + self->template Write(ffi::TypeIndex::kTVMFFIStr); self->template Write(str->size); self->template WriteArray(str->data, str->size); } else if (obj->IsInstance()) { @@ -157,11 +156,11 @@ inline void DiscoProtocol::WriteObject(Object* obj) { self->template Write(ffi::TypeIndex::kTVMFFIBytes); self->template Write(bytes->size); self->template WriteArray(bytes->data, bytes->size); - } else if (obj->IsInstance()) { - ShapeTupleObj* shape = static_cast(obj); - self->template Write(TypeIndex::kRuntimeShapeTuple); + } else if (obj->IsInstance()) { + ffi::ShapeObj* shape = static_cast(obj); + self->template Write(ffi::TypeIndex::kTVMFFIShape); self->template Write(shape->size); - self->template WriteArray(shape->data, shape->size); + self->template WriteArray(shape->data, shape->size); } else if (obj->IsInstance()) { self->template Write(0); std::string str = static_cast(obj)->SaveToStr(); @@ -174,7 +173,7 @@ inline void DiscoProtocol::WriteObject(Object* obj) { } template -inline void DiscoProtocol::ReadObject(int* tcode, TVMValue* value) { +inline void DiscoProtocol::ReadObject(TVMFFIAny* out) { SubClassType* self = static_cast(this); ObjectRef result{nullptr}; uint32_t type_index; @@ -184,7 +183,7 @@ inline void DiscoProtocol::ReadObject(int* tcode, TVMValue* value) self->template Read(&dref->reg_id); dref->session = Session{nullptr}; result = ObjectRef(std::move(dref)); - } else if (type_index == TypeIndex::kRuntimeString) { + } else if (type_index == ffi::TypeIndex::kTVMFFIStr) { uint64_t size = 0; self->template Read(&size); std::string data(size, '\0'); @@ -196,12 +195,12 @@ inline void DiscoProtocol::ReadObject(int* tcode, TVMValue* value) std::string data(size, '\0'); self->template ReadArray(data.data(), size); result = ffi::Bytes(std::move(data)); - } else if (type_index == TypeIndex::kRuntimeShapeTuple) { + } else if (type_index == ffi::TypeIndex::kTVMFFIShape) { uint64_t ndim = 0; self->template Read(&ndim); - std::vector data(ndim); - self->template ReadArray(data.data(), ndim); - result = ShapeTuple(std::move(data)); + std::vector data(ndim); + self->template ReadArray(data.data(), ndim); + result = ffi::Shape(std::move(data)); } else if (type_index == 0) { uint64_t size = 0; self->template Read(&size); @@ -212,9 +211,7 @@ inline void DiscoProtocol::ReadObject(int* tcode, TVMValue* value) LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; } - // translate AnyView to legacy TVMValue and type_code - AnyView res_view = result; - AnyViewToLegacyTVMArgValue(res_view.CopyToTVMFFIAny(), value, tcode); + *reinterpret_cast(out) = result; object_arena_.push_back(result); } diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index 467888c65768..ed2d8575387f 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -16,10 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include -#include -#include namespace tvm { namespace runtime { @@ -32,27 +31,27 @@ struct SessionObj::FFI { TVM_REGISTER_OBJECT_TYPE(DRefObj); TVM_REGISTER_OBJECT_TYPE(SessionObj); -TVM_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession); -TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionThreaded").set_body_typed(Session::ThreadedSession); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote") .set_body_method(&DRefObj::DebugGetFromRemote); -TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom").set_body_method(&DRefObj::DebugCopyFrom); -TVM_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom").set_body_method(&DRefObj::DebugCopyFrom); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers") .set_body_method(&SessionObj::GetNumWorkers); -TVM_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc") .set_body_method(&SessionObj::GetGlobalFunc); -TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0") .set_body_method(&SessionObj::CopyFromWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCopyToWorker0") .set_body_method(&SessionObj::CopyToWorker0); -TVM_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker").set_body_method(&SessionObj::SyncWorker); -TVM_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") // +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionSyncWorker").set_body_method(&SessionObj::SyncWorker); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionInitCCL") // .set_body_method(&SessionObj::InitCCL); -TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked") +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionCallPacked") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Session self = args[0].cast(); *rv = SessionObj::FFI::CallWithPacked(self, args.Slice(1)); }); -TVM_REGISTER_GLOBAL("runtime.disco.SessionShutdown").set_body_method(&SessionObj::Shutdown); +TVM_FFI_REGISTER_GLOBAL("runtime.disco.SessionShutdown").set_body_method(&SessionObj::Shutdown); } // namespace runtime } // namespace tvm diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index f40fae007e50..f7e34e400c6a 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -17,7 +17,7 @@ * under the License. */ #include -#include +#include #include #include @@ -41,24 +41,16 @@ class DiscoThreadedMessageQueue : private dmlc::Stream, private DiscoProtocol { public: void Send(const ffi::PackedArgs& args) { - // Run legacy ABI translation. - std::vector values(args.size()); - std::vector type_codes(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values.data(), type_codes.data()); - // TODO(tqchen): use native convention that do not need ABI translation. - RPCReference::ReturnPackedSeq(values.data(), type_codes.data(), args.size(), this); + RPCReference::ReturnPackedSeq(reinterpret_cast(args.data()), args.size(), + this); CommitSendAndNotifyEnqueue(); } ffi::PackedArgs Recv() { DequeueNextPacket(); - TVMValue* values = nullptr; - int* type_codes = nullptr; + ffi::AnyView* packed_args = nullptr; int num_args = 0; - RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); - // Run legacy ABI translation. - AnyView* packed_args = reinterpret_cast(ArenaAlloc(num_args)); - LegacyTVMArgsToPackedArgs(values, type_codes, num_args, packed_args); + RPCReference::RecvPackedSeq(reinterpret_cast(&packed_args), &num_args, this); return ffi::PackedArgs(packed_args, num_args); } @@ -170,7 +162,7 @@ class ThreadedSessionObj final : public BcastSessionObj { return this->workers_.at(worker_id).worker->register_file.at(reg_id); } - void DebugSetRegister(int64_t reg_id, AnyView value, int worker_id) { + void DebugSetRegister(int64_t reg_id, ffi::AnyView value, int worker_id) { this->SyncWorker(worker_id); this->workers_.at(worker_id).worker->SetRegister(reg_id, value); } diff --git a/src/runtime/disco/utils.h b/src/runtime/disco/utils.h index 0c177e36e925..fa58c73aa787 100644 --- a/src/runtime/disco/utils.h +++ b/src/runtime/disco/utils.h @@ -39,7 +39,7 @@ inline Device UseDefaultDeviceIfNone(Device device) { * \note At the time of scaffolding Disco, RelaxVM has not provided mature support for standalone * integers. A common workaround is to use a 1-d shape tuple as an integer. */ -inline int64_t IntegerFromShapeTuple(const ShapeTuple& shape) { +inline int64_t IntegerFromShape(const ffi::Shape& shape) { CHECK_EQ(shape.size(), 1) << "ValueError: shape tuple must be 1-d to be converted to integer."; return shape[0]; } diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 185b066e724e..8a8666691300 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -21,10 +21,9 @@ * \file dso_libary.cc * \brief Create library module to load from dynamic shared library. */ -#include +#include +#include #include -#include -#include #include "library_module.h" @@ -149,7 +148,7 @@ ObjectPtr CreateDSOLibraryObject(std::string library_path) { return n; } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_so") .set_body_typed([](std::string library_path, std::string) { ObjectPtr n = CreateDSOLibraryObject(library_path); return CreateModuleFromLibrary(n); diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index d8d10f885234..2aa377b9f8bd 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -24,8 +24,8 @@ #include #include +#include #include -#include #include #include @@ -237,22 +237,23 @@ std::string SaveParams(const Map& params) { return bytes; } -TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map& params) { - std::string s = ::tvm::runtime::SaveParams(params); - return ffi::Bytes(std::move(s)); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.SaveParams") + .set_body_typed([](const Map& params) { + std::string s = ::tvm::runtime::SaveParams(params); + return ffi::Bytes(std::move(s)); + }); -TVM_REGISTER_GLOBAL("runtime.SaveParamsToFile") +TVM_FFI_REGISTER_GLOBAL("runtime.SaveParamsToFile") .set_body_typed([](const Map& params, const String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); SaveParams(&strm, params); }); -TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const ffi::Bytes& s) { +TVM_FFI_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const ffi::Bytes& s) { return ::tvm::runtime::LoadParams(s); }); -TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) { +TVM_FFI_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) { tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); return LoadParams(&strm); }); diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index 0f3dc135712f..b4da7adea813 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_FILE_UTILS_H_ #define TVM_RUNTIME_FILE_UTILS_H_ -#include -#include +#include +#include #include #include diff --git a/src/runtime/hexagon/hexagon_buffer.h b/src/runtime/hexagon/hexagon_buffer.h index 8cb8a3209514..986d6b6e5ec6 100644 --- a/src/runtime/hexagon/hexagon_buffer.h +++ b/src/runtime/hexagon/hexagon_buffer.h @@ -20,11 +20,11 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_BUFFER_H_ -#include +#include +#include #include #include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index c959e39e1d39..4c95d68b2dc3 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -22,9 +22,9 @@ */ #include "hexagon_common.h" +#include #include #include -#include #include #include @@ -56,7 +56,7 @@ class HexagonTimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(HexagonTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.hexagon").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.hexagon").set_body_typed([](Device dev) { return Timer(make_object()); }); } // namespace hexagon @@ -89,7 +89,7 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } } // namespace detail -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ObjectPtr n = CreateDSOLibraryObject(args[0].cast()); *rv = CreateModuleFromLibrary(n); diff --git a/src/runtime/hexagon/hexagon_common.h b/src/runtime/hexagon/hexagon_common.h index 5834093a9e43..1e68a93a8b2e 100644 --- a/src/runtime/hexagon/hexagon_common.h +++ b/src/runtime/hexagon/hexagon_common.h @@ -24,9 +24,10 @@ #define TVM_RUNTIME_HEXAGON_HEXAGON_COMMON_H_ #include +#include #include +#include #include -#include #if defined(__hexagon__) #include diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 40294324018b..0bc7e2b80194 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -24,9 +24,9 @@ #include "hexagon_device_api.h" #include +#include #include #include -#include #include #include @@ -190,7 +190,7 @@ void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto dst = args[0].cast(); auto src = args[1].cast(); @@ -209,7 +209,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_copy") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); void* dst = args[1].cast(); @@ -226,7 +226,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy") *rv = static_cast(ret); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_wait") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); int inflight = args[1].cast(); @@ -235,21 +235,21 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id); *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.dma_end_group") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id); *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -274,7 +274,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd") *rv = hexapi->AllocDataSpace(dev, ndim, shape, type_hint, String(scope)); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.free_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -291,28 +291,29 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.acquire_resources") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.acquire_resources") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); api->AcquireResources(); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.release_resources") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.release_resources") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); api->ReleaseResources(); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.vtcm_device_bytes") +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon.vtcm_device_bytes") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { HexagonDeviceAPI* api = HexagonDeviceAPI::Global(); *rv = static_cast(api->VtcmPool()->VtcmDeviceBytes()); }); -TVM_REGISTER_GLOBAL("device_api.hexagon").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.hexagon") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = HexagonDeviceAPI::Global(); + *rv = static_cast(ptr); + }); } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 6ed2a2757f68..a5a8de45357a 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -24,8 +24,8 @@ #include "hexagon_module.h" #include +#include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_thread_manager.h b/src/runtime/hexagon/hexagon_thread_manager.h index 9bf6bb6efe64..7ec3ac61506d 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.h +++ b/src/runtime/hexagon/hexagon_thread_manager.h @@ -20,9 +20,9 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_THREAD_MANAGER_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_THREAD_MANAGER_H_ -#include +#include +#include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.h b/src/runtime/hexagon/hexagon_vtcm_pool.h index 88b8f1470cf3..ece8454b859a 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.h +++ b/src/runtime/hexagon/hexagon_vtcm_pool.h @@ -20,11 +20,11 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_VTCM_POOL_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_VTCM_POOL_H_ -#include +#include +#include #include #include #include -#include #include #include diff --git a/src/runtime/hexagon/ops/conv2d.h b/src/runtime/hexagon/ops/conv2d.h index 79bd0217179b..5865d46117a0 100644 --- a/src/runtime/hexagon/ops/conv2d.h +++ b/src/runtime/hexagon/ops/conv2d.h @@ -17,7 +17,7 @@ * under the License. */ -#include +#include #include #include diff --git a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc index 5c764355aa58..5f171894d9cd 100644 --- a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc +++ b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -44,8 +44,7 @@ // 4: int stride_h // 5: int stride_w // 6: DLTensor output (NHWC) -extern "C" int conv2d_packed_fp16(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle); +extern "C" int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val); namespace tvm { namespace runtime { @@ -403,26 +402,27 @@ void conv_layer_fp16_hvx(DLTensor& cr_out, const DLTensor& cr_act, // NOLINT(*) } // namespace runtime } // namespace tvm -int conv2d_packed_fp16(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle) { +int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) { namespace conv_utils = tvm::runtime::hexagon::conv_utils; ICHECK_EQ(num_args, 7) << "Unexpected number of arguments"; - ICHECK_EQ(type_codes[0], kTVMDLTensorHandle) + ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) << "First argument is expected to be the input tensor"; // Input activations - ICHECK_EQ(type_codes[1], kTVMDLTensorHandle) + ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) << "Second argument is expected to be the weights tensor"; // Weights - ICHECK_EQ(type_codes[2], kDLInt) + ICHECK_EQ(args[2].type_index, kTVMFFIInt) << "Third argument is expected to be the pad_top offset"; // pad_top offset - ICHECK_EQ(type_codes[3], kDLInt) + ICHECK_EQ(args[3].type_index, kTVMFFIInt) << "Fourth argument is expected to be the pad_left offset"; // pad_left offset - ICHECK_EQ(type_codes[4], kDLInt) << "Fifth argument is expected to be the stride_h"; // stride_h - ICHECK_EQ(type_codes[5], kDLInt) << "Sixth argument is expected to be the stride_w"; // stride_w - ICHECK_EQ(type_codes[6], kTVMDLTensorHandle) + ICHECK_EQ(args[4].type_index, kTVMFFIInt) + << "Fifth argument is expected to be the stride_h"; // stride_h + ICHECK_EQ(args[5].type_index, kTVMFFIInt) + << "Sixth argument is expected to be the stride_w"; // stride_w + ICHECK_EQ(args[6].type_index, kTVMFFIDLTensorPtr) << "Seventh argument is expected to be the output tensor"; // output - auto* act_flat = static_cast(args[0].v_handle); - auto* wgt_flat = static_cast(args[1].v_handle); - auto* out_flat = static_cast(args[6].v_handle); + auto* act_flat = static_cast(args[0].v_ptr); + auto* wgt_flat = static_cast(args[1].v_ptr); + auto* out_flat = static_cast(args[6].v_ptr); // Temporary assertion until multiple batches are supported ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; diff --git a/src/runtime/hexagon/ops/conv2d_quant_hvx.cc b/src/runtime/hexagon/ops/conv2d_quant_hvx.cc index 99f7c245f557..30cba60cf1a8 100644 --- a/src/runtime/hexagon/ops/conv2d_quant_hvx.cc +++ b/src/runtime/hexagon/ops/conv2d_quant_hvx.cc @@ -20,13 +20,12 @@ #include #include #include -#include +#include #include #include "conv2d.h" -extern "C" int conv2d_packed_quant(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle); +extern "C" int conv2d_packed_quant(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val); namespace tvm { namespace runtime { @@ -230,30 +229,38 @@ void conv_layer_int8_hvx_whole(DLTensor& cr_out, const DLTensor& cr_act, // NOL } // namespace runtime } // namespace tvm -int conv2d_packed_quant(TVMValue* args, int* type_codes, int num_args, TVMValue* out_val, - int out_code, void* res_handle) { +int conv2d_packed_quant(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) { namespace conv_utils = tvm::runtime::hexagon::conv_utils; ICHECK_EQ(num_args, 13) << "Unexpected number of arguments"; - ICHECK_EQ(type_codes[0], kTVMDLTensorHandle) + ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) << "First argument is expected to be the input tensor"; // Input activations - ICHECK_EQ(type_codes[1], kTVMDLTensorHandle) + ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) << "Second argument is expected to be the weights tensor"; // Weights - ICHECK_EQ(type_codes[2], kDLFloat) << "Third argument is expected to be the activation scale"; - ICHECK_EQ(type_codes[3], kDLInt) << "Fourth argument is expected to be the activation zero point"; - ICHECK_EQ(type_codes[4], kDLFloat) << "Fifth argument is expected to be the weight scale"; - ICHECK_EQ(type_codes[5], kDLInt) << "Sixth argument is expected to be the weight zero point"; - ICHECK_EQ(type_codes[6], kDLFloat) << "Seventh argument is expected to be the output scale"; - ICHECK_EQ(type_codes[7], kDLInt) << "Eigth argument is expected to be the output zero point"; - ICHECK_EQ(type_codes[8], kDLInt) << "Nineth argument is expected to be the stride_h"; // stride_h - ICHECK_EQ(type_codes[9], kDLInt) << "Tenth argument is expected to be the stride_w"; // stride_w - ICHECK_EQ(type_codes[10], kDLInt) << "Eleventh argument is expected to be fixed final scale"; - ICHECK_EQ(type_codes[11], kDLInt) << "Twelfth argument is expected to be scale factor"; - ICHECK_EQ(type_codes[12], kTVMDLTensorHandle) + ICHECK_EQ(args[2].type_index, kTVMFFIFloat) + << "Third argument is expected to be the activation scale"; + ICHECK_EQ(args[3].type_index, kTVMFFIInt) + << "Fourth argument is expected to be the activation zero point"; + ICHECK_EQ(args[4].type_index, kTVMFFIFloat) + << "Fifth argument is expected to be the weight scale"; + ICHECK_EQ(args[5].type_index, kTVMFFIInt) + << "Sixth argument is expected to be the weight zero point"; + ICHECK_EQ(args[6].type_index, kTVMFFIFloat) + << "Seventh argument is expected to be the output scale"; + ICHECK_EQ(args[7].type_index, kTVMFFIInt) + << "Eigth argument is expected to be the output zero point"; + ICHECK_EQ(args[8].type_index, kTVMFFIInt) + << "Nineth argument is expected to be the stride_h"; // stride_h + ICHECK_EQ(args[9].type_index, kTVMFFIInt) + << "Tenth argument is expected to be the stride_w"; // stride_w + ICHECK_EQ(args[10].type_index, kTVMFFIInt) + << "Eleventh argument is expected to be fixed final scale"; + ICHECK_EQ(args[11].type_index, kTVMFFIInt) << "Twelfth argument is expected to be scale factor"; + ICHECK_EQ(args[12].type_index, kTVMFFIDLTensorPtr) << "Thirteenth argument is expected to be the output tensor"; // output - auto* act_flat = static_cast(args[0].v_handle); - auto* wgt_flat = static_cast(args[1].v_handle); - auto* out_flat = static_cast(args[12].v_handle); + auto* act_flat = static_cast(args[0].v_ptr); + auto* wgt_flat = static_cast(args[1].v_ptr); + auto* out_flat = static_cast(args[12].v_ptr); // Temporary assertion until multiple batches are supported ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index 265e5bb12e57..0f71f7265024 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -21,7 +21,7 @@ * \file hexagon_session.cc */ -#include +#include extern "C" { #include @@ -109,7 +109,7 @@ class HexagonTransportChannel : public RPCChannel { remote_handle64 _handle = AEE_EUNKNOWN; }; -TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() >= 4) << args.size() << " is less than 4"; diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index 78d65fb8deeb..7880018ff8e8 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -27,9 +27,8 @@ extern "C" { #include #include +#include #include -#include -#include #include #include @@ -329,14 +328,14 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } -TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.load_module") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); *rv = CreateModuleFromLibrary(n); }); -TVM_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto profiling_mode = args[0].cast(); auto out_file = args[1].cast(); @@ -354,7 +353,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { fs.write(&data[0], data.length()); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto file_name = args[0].cast(); auto data = args[1].cast(); diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index a98abe634e8b..2301ffc13d17 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -32,8 +32,8 @@ #include "../../hexagon_common.h" #include "../../profiler/prof_utils.h" #include "hexagon_sim_proto.h" +#include "tvm/ffi/function.h" #include "tvm/runtime/packed_func.h" -#include "tvm/runtime/registry.h" namespace tvm { namespace runtime { @@ -332,14 +332,14 @@ __attribute__((weak)) void _Get_eh_data() {} __attribute__((weak)) void _Parse_fde_instr() {} } -TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.load_module") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto soname = args[0].cast(); tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); *rv = CreateModuleFromLibrary(n); }); -TVM_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") +TVM_FFI_REGISTER_GLOBAL("tvm.hexagon.get_profile_output") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto profiling_mode = args[0].cast(); auto out_file = args[1].cast(); @@ -357,7 +357,7 @@ void SaveBinaryToFile(const std::string& file_name, const std::string& data) { fs.write(&data[0], data.length()); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") .set_body_packed([](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) { auto file_name = args[0].cast(); auto data = args[1].cast(); diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index 7366371b491a..5eb7beab0f57 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -18,8 +18,7 @@ */ #include -#include -#include +#include // POSIX includes #include #include @@ -1370,7 +1369,7 @@ std::optional SimulatorRPCChannel::to_nullptr(const detail::Mayb .Default(std::nullopt); } -TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() >= 4) << args.size() << " is less than 4"; diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index f580a6d667f1..18f973daf159 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -25,8 +25,8 @@ #include #include +#include #include -#include #include #include diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index ccc0b3193b87..60ce95e2369b 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ #define TVM_RUNTIME_LIBRARY_MODULE_H_ +#include #include -#include #include #include diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc index 2d4164ce4425..45e83f33e2da 100644 --- a/src/runtime/logging.cc +++ b/src/runtime/logging.cc @@ -121,7 +121,7 @@ int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int li if (filename) { // Stack frames for TVM FFI if (strstr(filename, "include/tvm/runtime/packed_func.h") || - strstr(filename, "include/tvm/runtime/registry.h") || + strstr(filename, "include/tvm/ffi/function.h") || strstr(filename, "src/runtime/c_runtime_api.cc")) { return true; } diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index 9b6eae80394a..b6c2a098d474 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -21,8 +21,8 @@ * \file tvm/runtime/memory/memory_manager.cc * \brief Allocate and manage memory for the runtime. */ +#include #include -#include #include #include @@ -59,7 +59,7 @@ inline size_t GetDataAlignment(const DLDataType& dtype) { return align; } -NDArray StorageObj::AllocNDArrayScoped(int64_t offset, ShapeTuple shape, DLDataType dtype, +NDArray StorageObj::AllocNDArrayScoped(int64_t offset, ffi::Shape shape, DLDataType dtype, String scope) { if (scope == "global" || scope.empty()) { return AllocNDArray(offset, shape, dtype); @@ -90,7 +90,7 @@ NDArray StorageObj::AllocNDArrayScoped(int64_t offset, ShapeTuple shape, DLDataT this->buffer.device, shape, scope, offset); } -NDArray StorageObj::AllocNDArray(int64_t offset, ShapeTuple shape, DLDataType dtype) { +NDArray StorageObj::AllocNDArray(int64_t offset, ffi::Shape shape, DLDataType dtype) { VerifyDataType(dtype); size_t needed_size = ffi::GetDataSize(shape.Product(), dtype); @@ -212,7 +212,7 @@ void MemoryManager::Clear() { } } -NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev, +NDArray Allocator::Empty(ffi::Shape shape, DLDataType dtype, DLDevice dev, Optional mem_scope) { VerifyDataType(dtype); @@ -245,7 +245,7 @@ bool Allocator::AllowMemoryScope(const std::string& mem_scope) const { return mem_scope.empty() || mem_scope == "global"; } -Buffer Allocator::Alloc(Device dev, ShapeTuple shape, DLDataType type_hint, +Buffer Allocator::Alloc(Device dev, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope) { if (AllowMemoryScope(mem_scope)) { // by default, we can always redirect to the flat memory allocations @@ -264,7 +264,7 @@ void Allocator::Clear() { // Pooled allocator will override this method. } -TVM_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear); } // namespace memory } // namespace runtime diff --git a/src/runtime/memory/naive_allocator.h b/src/runtime/memory/naive_allocator.h index 6d8e90fed9f2..aed990d22c3b 100644 --- a/src/runtime/memory/naive_allocator.h +++ b/src/runtime/memory/naive_allocator.h @@ -48,7 +48,7 @@ class NaiveAllocator final : public Allocator { return buf; } - Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint, + Buffer Alloc(Device dev, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope) final { Buffer buf; size_t nbytes = 1; diff --git a/src/runtime/memory/pooled_allocator.h b/src/runtime/memory/pooled_allocator.h index 2c46d5df51a4..744c61987cdd 100644 --- a/src/runtime/memory/pooled_allocator.h +++ b/src/runtime/memory/pooled_allocator.h @@ -73,7 +73,7 @@ class PooledAllocator : public Allocator { return buf; } - Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint, + Buffer Alloc(Device dev, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope) override { if (AllowMemoryScope(mem_scope)) { return Allocator::Alloc(dev, shape, type_hint, mem_scope); diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index c415468088ed..51120c1f9efb 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -26,17 +26,15 @@ #include #include +#include #include #include -#include #include #include #include #include -#include "runtime_base.h" - namespace tvm { namespace runtime { diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index ab383732ea8c..138d312dd47c 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -30,10 +30,10 @@ #import #import #import -#include +#include +#include #include #include -#include #include #include diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 83f2c38a2bd5..46824b1600ee 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -21,8 +21,8 @@ * \file metal_device_api.mm */ #include +#include #include -#include #include "metal_common.h" namespace tvm { @@ -362,12 +362,12 @@ int GetWarpSize(id dev) { MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.metal").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.metal").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = MetalWorkspace::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { MetalWorkspace::Global()->ReinitializeDefaultStreams(); }); @@ -403,7 +403,7 @@ virtual void Stop() { TVM_REGISTER_OBJECT_TYPE(MetalTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.metal").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.metal").set_body_typed([](Device dev) { return Timer(make_object(dev)); }); diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index d01523b1faba..e2705a7a806b 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_ #define TVM_RUNTIME_METAL_METAL_MODULE_H_ -#include +#include #include #include diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index cc25fd8b0daf..f7c59156cb6a 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -22,8 +22,8 @@ */ #include "metal_module.h" #include +#include #include -#include #include #include #include @@ -260,23 +260,23 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) ffi::Function MetalModuleNode::GetFunction(const String& name, const ObjectPtr& sptr_to_self) { - ffi::Function f; + ffi::Function ret; AUTORELEASEPOOL { ICHECK_EQ(sptr_to_self.get(), this); ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) { - f = ffi::Function(); - return; + ret = ffi::Function(); + return ret; } const FunctionInfo& info = it->second; MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, info.launch_param_tags); - pf = PackFuncNonBufferArg(f, info.arg_types); + ret = PackFuncNonBufferArg(f, info.arg_types); }; - return pf; + return ret; } Module MetalModuleCreate(std::unordered_map smap, @@ -287,7 +287,7 @@ Module MetalModuleCreate(std::unordered_map smap, return Module(n); } -TVM_REGISTER_GLOBAL("runtime.module.create_metal_module") +TVM_FFI_REGISTER_GLOBAL("runtime.module.create_metal_module") .set_body_typed([](Map smap, std::string fmap_json, std::string fmt, std::string source) { std::istringstream stream(fmap_json); @@ -317,6 +317,6 @@ Module MetalModuleLoadBinary(void* strm) { return MetalModuleCreate(smap, fmap, fmt, ""); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/minrpc/minrpc_interfaces.h b/src/runtime/minrpc/minrpc_interfaces.h deleted file mode 100644 index a45dee9f2c35..000000000000 --- a/src/runtime/minrpc/minrpc_interfaces.h +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_ -#define TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_ - -#include - -#include "rpc_reference.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief Return interface used in ExecInterface to generate and send the responses. - */ -class MinRPCReturnInterface { - public: - virtual ~MinRPCReturnInterface() {} - /*! * \brief sends a response to the client with kTVMNullptr in payload. */ - virtual void ReturnVoid() = 0; - - /*! * \brief sends a response to the client with one kTVMOpaqueHandle in payload. */ - virtual void ReturnHandle(void* handle) = 0; - - /*! * \brief sends an exception response to the client with a kTVMStr in payload. */ - virtual void ReturnException(const char* msg) = 0; - - /*! * \brief sends a packed argument sequnce to the client. */ - virtual void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) = 0; - - /*! * \brief sends a copy of the requested remote data to the client. */ - virtual void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) = 0; - - /*! * \brief sends an exception response to the client with the last TVM erros as the message. */ - virtual void ReturnLastTVMError() = 0; - - /*! * \brief internal error. */ - virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) = 0; -}; - -/*! - * \brief Execute interface used in MinRPCServer to process different received commands - */ -class MinRPCExecInterface { - public: - virtual ~MinRPCExecInterface() {} - - /*! * \brief Execute an Initilize server command. */ - virtual void InitServer(int num_args) = 0; - - /*! * \brief calls a function specified by the call_handle. */ - virtual void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, - int num_args) = 0; - - /*! * \brief Execute a copy from remote command by sending the data described in arr to the client - */ - virtual void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) = 0; - - /*! * \brief Execute a copy to remote command by receiving the data described in arr from the - * client */ - virtual int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) = 0; - - /*! * \brief calls a system function specified by the code. */ - virtual void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) = 0; - - /*! * \brief internal error. */ - virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) = 0; - - /*! * \brief return the ReturnInterface pointer that is used to generate and send the responses. - */ - virtual MinRPCReturnInterface* GetReturnInterface() = 0; -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_ diff --git a/src/runtime/minrpc/minrpc_logger.cc b/src/runtime/minrpc/minrpc_logger.cc deleted file mode 100644 index 4f3b7e764c9b..000000000000 --- a/src/runtime/minrpc/minrpc_logger.cc +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "minrpc_logger.h" - -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "minrpc_interfaces.h" -#include "rpc_reference.h" - -namespace tvm { -namespace runtime { - -void Logger::LogTVMValue(int tcode, TVMValue value) { - switch (tcode) { - case kDLInt: { - LogValue("(int64)", value.v_int64); - break; - } - case kDLUInt: { - LogValue("(uint64)", value.v_int64); - break; - } - case kDLFloat: { - LogValue("(float)", value.v_float64); - break; - } - case kTVMDataType: { - LogDLData("DLDataType(code,bits,lane)", &value.v_type); - break; - } - case kDLDevice: { - LogDLDevice("DLDevice(type,id)", &value.v_device); - break; - } - case kTVMPackedFuncHandle: { - LogValue("(PackedFuncHandle)", value.v_handle); - break; - } - case kTVMModuleHandle: { - LogValue("(ModuleHandle)", value.v_handle); - break; - } - case kTVMOpaqueHandle: { - LogValue("(OpaqueHandle)", value.v_handle); - break; - } - case kTVMDLTensorHandle: { - LogValue("(TensorHandle)", value.v_handle); - break; - } - case kTVMNDArrayHandle: { - LogValue("kTVMNDArrayHandle", value.v_handle); - break; - } - case kTVMNullptr: { - Log("Nullptr"); - break; - } - case kTVMStr: { - Log("\""); - Log(value.v_str); - Log("\""); - break; - } - case kTVMBytes: { - TVMByteArray* bytes = static_cast(value.v_handle); - int len = bytes->size; - LogValue("(Bytes) [size]: ", len); - if (PRINT_BYTES) { - Log(", [Values]:"); - Log(" { "); - if (len > 0) { - LogValue("", (uint8_t)bytes->data[0]); - } - for (int j = 1; j < len; j++) LogValue(" - ", (uint8_t)bytes->data[j]); - Log(" } "); - } - break; - } - default: { - Log("ERROR-kUnknownTypeCode)"); - break; - } - } - Log("; "); -} - -void Logger::OutputLog() { - LOG(INFO) << os_.str(); - os_.str(std::string()); -} - -void MinRPCReturnsWithLog::ReturnVoid() { - next_->ReturnVoid(); - logger_->Log("-> ReturnVoid"); - logger_->OutputLog(); -} - -void MinRPCReturnsWithLog::ReturnHandle(void* handle) { - next_->ReturnHandle(handle); - if (code_ == RPCCode::kGetGlobalFunc) { - RegisterHandleName(handle); - } - logger_->LogValue("-> ReturnHandle: ", handle); - logger_->OutputLog(); -} - -void MinRPCReturnsWithLog::ReturnException(const char* msg) { - next_->ReturnException(msg); - logger_->Log("-> Exception: "); - logger_->Log(msg); - logger_->OutputLog(); -} - -void MinRPCReturnsWithLog::ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, - int num_args) { - next_->ReturnPackedSeq(arg_values, type_codes, num_args); - ProcessValues(arg_values, type_codes, num_args); - logger_->OutputLog(); -} - -void MinRPCReturnsWithLog::ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) { - next_->ReturnCopyFromRemote(data_ptr, num_bytes); - logger_->LogValue("-> CopyFromRemote: ", num_bytes); - logger_->LogValue(", ", static_cast(data_ptr)); - logger_->OutputLog(); -} - -void MinRPCReturnsWithLog::ReturnLastTVMError() { - const char* err = TVMGetLastError(); - ReturnException(err); -} - -void MinRPCReturnsWithLog::ThrowError(RPCServerStatus code, RPCCode info) { - next_->ThrowError(code, info); - logger_->Log("-> ERROR: "); - logger_->Log(RPCServerStatusToString(code)); - logger_->OutputLog(); -} - -void MinRPCReturnsWithLog::ProcessValues(const TVMValue* values, const int* tcodes, int num_args) { - if (tcodes != nullptr) { - logger_->Log("-> ["); - for (int i = 0; i < num_args; ++i) { - logger_->LogTVMValue(tcodes[i], values[i]); - - if (tcodes[i] == kTVMOpaqueHandle) { - RegisterHandleName(values[i].v_handle); - } - } - logger_->Log("]"); - } -} - -void MinRPCReturnsWithLog::ResetHandleName(RPCCode code) { - code_ = code; - handle_name_.clear(); -} - -void MinRPCReturnsWithLog::UpdateHandleName(const char* name) { - if (handle_name_.length() != 0) { - handle_name_.append("::"); - } - handle_name_.append(name); -} - -void MinRPCReturnsWithLog::GetHandleName(void* handle) { - if (handle_descriptions_.find(handle) != handle_descriptions_.end()) { - handle_name_.append(handle_descriptions_[handle]); - logger_->LogHandleName(handle_name_); - } -} - -void MinRPCReturnsWithLog::ReleaseHandleName(void* handle) { - if (handle_descriptions_.find(handle) != handle_descriptions_.end()) { - logger_->LogHandleName(handle_descriptions_[handle]); - handle_descriptions_.erase(handle); - } -} - -void MinRPCReturnsWithLog::RegisterHandleName(void* handle) { - handle_descriptions_[handle] = handle_name_; -} - -void MinRPCExecuteWithLog::InitServer(int num_args) { - SetRPCCode(RPCCode::kInitServer); - logger_->Log("Init Server"); - next_->InitServer(num_args); -} - -void MinRPCExecuteWithLog::NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, - int num_args) { - SetRPCCode(RPCCode::kCallFunc); - logger_->LogValue("call_handle: ", reinterpret_cast(call_handle)); - ret_handler_->GetHandleName(reinterpret_cast(call_handle)); - if (num_args > 0) { - logger_->Log(", "); - } - ProcessValues(values, tcodes, num_args); - next_->NormalCallFunc(call_handle, values, tcodes, num_args); -} - -void MinRPCExecuteWithLog::CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data) { - SetRPCCode(RPCCode::kCopyFromRemote); - logger_->LogValue("data_handle: ", static_cast(arr->data)); - logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device)); - logger_->LogValue(", ndim: ", arr->ndim); - logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype)); - logger_->LogValue(", num_bytes:", num_bytes); - next_->CopyFromRemote(arr, num_bytes, temp_data); -} - -int MinRPCExecuteWithLog::CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { - SetRPCCode(RPCCode::kCopyToRemote); - logger_->LogValue("data_handle: ", static_cast(arr->data)); - logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device)); - logger_->LogValue(", ndim: ", arr->ndim); - logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype)); - logger_->LogValue(", byte_offset: ", arr->byte_offset); - return next_->CopyToRemote(arr, num_bytes, data_ptr); -} - -void MinRPCExecuteWithLog::SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) { - SetRPCCode(code); - if ((code) == RPCCode::kFreeHandle) { - if ((num_args == 2) && (tcodes[0] == kTVMOpaqueHandle) && (tcodes[1] == kDLInt)) { - logger_->LogValue("handle: ", static_cast(values[0].v_handle)); - if (values[1].v_int64 == kTVMModuleHandle || values[1].v_int64 == kTVMPackedFuncHandle) { - ret_handler_->ReleaseHandleName(static_cast(values[0].v_handle)); - } - } - } else { - ProcessValues(values, tcodes, num_args); - } - next_->SysCallFunc(code, values, tcodes, num_args); -} - -void MinRPCExecuteWithLog::ThrowError(RPCServerStatus code, RPCCode info) { - logger_->Log("-> Error\n"); - next_->ThrowError(code, info); -} - -void MinRPCExecuteWithLog::ProcessValues(TVMValue* values, int* tcodes, int num_args) { - if (tcodes != nullptr) { - logger_->Log("["); - for (int i = 0; i < num_args; ++i) { - logger_->LogTVMValue(tcodes[i], values[i]); - - if (tcodes[i] == kTVMStr) { - if (strlen(values[i].v_str) > 0) { - ret_handler_->UpdateHandleName(values[i].v_str); - } - } - } - logger_->Log("]"); - } -} - -void MinRPCExecuteWithLog::SetRPCCode(RPCCode code) { - logger_->Log(RPCCodeToString(code)); - logger_->Log(", "); - ret_handler_->ResetHandleName(code); -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/minrpc/minrpc_logger.h b/src/runtime/minrpc/minrpc_logger.h deleted file mode 100644 index 13d44c3cba9b..000000000000 --- a/src/runtime/minrpc/minrpc_logger.h +++ /dev/null @@ -1,296 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_ -#define TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_ - -#include - -#include -#include -#include -#include - -#include "minrpc_interfaces.h" -#include "rpc_reference.h" - -namespace tvm { -namespace runtime { - -#define PRINT_BYTES false - -/*! - * \brief Generates a user readeable log on the console - */ -class Logger { - public: - Logger() {} - - /*! - * \brief this function logs a string - * - * \param s the string to be logged. - */ - void Log(const char* s) { os_ << s; } - void Log(std::string s) { os_ << s; } - - /*! - * \brief this function logs a numerical value - * - * \param desc adds any necessary description before the value. - * \param val is the value to be logged. - */ - template - void LogValue(const char* desc, T val) { - os_ << desc << val; - } - - /*! - * \brief this function logs the properties of a DLDevice - * - * \param desc adds any necessary description before the DLDevice. - * \param dev is the pointer to the DLDevice to be logged. - */ - void LogDLDevice(const char* desc, DLDevice* dev) { - os_ << desc << "(" << dev->device_type << "," << dev->device_id << ")"; - } - - /*! - * \brief this function logs the properties of a DLDataType - * - * \param desc adds any necessary description before the DLDataType. - * \param data is the pointer to the DLDataType to be logged. - */ - void LogDLData(const char* desc, DLDataType* data) { - os_ << desc << "(" << (uint16_t)data->code << "," << (uint16_t)data->bits << "," << data->lanes - << ")"; - } - - /*! - * \brief this function logs a handle name. - * - * \param name is the name to be logged. - */ - void LogHandleName(std::string name) { - if (name.length() > 0) { - os_ << " <" << name.c_str() << ">"; - } - } - - /*! - * \brief this function logs a TVMValue based on its type. - * - * \param tcode the type_code of the value stored in TVMValue. - * \param value is the TVMValue to be logged. - */ - void LogTVMValue(int tcode, TVMValue value); - - /*! - * \brief this function output the log to the console. - */ - void OutputLog(); - - private: - std::stringstream os_; -}; - -/*! - * \brief A wrapper for a MinRPCReturns object, that also logs the responses. - * - * \param next underlying MinRPCReturns that generates the responses. - */ -class MinRPCReturnsWithLog : public MinRPCReturnInterface { - public: - /*! - * \brief Constructor. - * \param io The IO handler. - */ - MinRPCReturnsWithLog(MinRPCReturnInterface* next, Logger* logger) - : next_(next), logger_(logger) {} - - ~MinRPCReturnsWithLog() {} - - void ReturnVoid(); - - void ReturnHandle(void* handle); - - void ReturnException(const char* msg); - - void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args); - - void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes); - - void ReturnLastTVMError(); - - void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone); - - /*! - * \brief this function logs a list of TVMValues, and registers handle_name when needed. - * - * \param values is the list of TVMValues. - * \param tcodes is the list type_code of the TVMValues. - * \param num_args is the number of items in the list. - */ - void ProcessValues(const TVMValue* values, const int* tcodes, int num_args); - - /*! - * \brief this function is called when a new command is executed. - * It clears the handle_name_ and records the command code. - * - * \param code the RPC command code. - */ - void ResetHandleName(RPCCode code); - - /*! - * \brief appends name to the handle_name_. - * - * \param name handle name. - */ - void UpdateHandleName(const char* name); - - /*! - * \brief get the stored handle description. - * - * \param handle the handle to get the description for. - */ - void GetHandleName(void* handle); - - /*! - * \brief remove the handle description from handle_descriptions_. - * - * \param handle the handle to remove the description for. - */ - void ReleaseHandleName(void* handle); - - private: - /*! - * \brief add the handle description to handle_descriptions_. - * - * \param handle the handle to add the description for. - */ - void RegisterHandleName(void* handle); - - MinRPCReturnInterface* next_; - std::string handle_name_; - std::unordered_map handle_descriptions_; - RPCCode code_; - Logger* logger_; -}; - -/*! - * \brief A wrapper for a MinRPCExecute object, that also logs the responses. - * - * \param next: underlying MinRPCExecute that processes the packets. - */ -class MinRPCExecuteWithLog : public MinRPCExecInterface { - public: - MinRPCExecuteWithLog(MinRPCExecInterface* next, Logger* logger) : next_(next), logger_(logger) { - ret_handler_ = reinterpret_cast(next_->GetReturnInterface()); - } - - ~MinRPCExecuteWithLog() {} - - void InitServer(int num_args); - - void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args); - - void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data); - - int CopyToRemote(DLTensor* arr, uint64_t _num_bytes, uint8_t* _data_ptr); - - void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args); - - void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone); - - MinRPCReturnInterface* GetReturnInterface() { return next_->GetReturnInterface(); } - - private: - /*! - * \brief this function logs a list of TVMValues, and updates handle_name when needed. - * - * \param values is the list of TVMValues. - * \param tcodes is the list type_code of the TVMValues. - * \param num_args is the number of items in the list. - */ - void ProcessValues(TVMValue* values, int* tcodes, int num_args); - - /*! - * \brief this function is called when a new command is executed. - * - * \param code the RPC command code. - */ - void SetRPCCode(RPCCode code); - - MinRPCExecInterface* next_; - MinRPCReturnsWithLog* ret_handler_; - Logger* logger_; -}; - -/*! - * \brief A No-operation MinRPCReturns used within the MinRPCSniffer - * - * \tparam TIOHandler* IO provider to provide io handling. - */ -template -class MinRPCReturnsNoOp : public MinRPCReturnInterface { - public: - /*! - * \brief Constructor. - * \param io The IO handler. - */ - explicit MinRPCReturnsNoOp(TIOHandler* io) : io_(io) {} - ~MinRPCReturnsNoOp() {} - void ReturnVoid() {} - void ReturnHandle(void* handle) {} - void ReturnException(const char* msg) {} - void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {} - void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {} - void ReturnLastTVMError() {} - void ThrowError(RPCServerStatus code, RPCCode info) {} - - private: - TIOHandler* io_; -}; - -/*! - * \brief A No-operation MinRPCExecute used within the MinRPCSniffer - * - * \tparam ReturnInterface* ReturnInterface pointer to generate and send the responses. - - */ -class MinRPCExecuteNoOp : public MinRPCExecInterface { - public: - explicit MinRPCExecuteNoOp(MinRPCReturnInterface* ret_handler) : ret_handler_(ret_handler) {} - ~MinRPCExecuteNoOp() {} - void InitServer(int _num_args) {} - void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) {} - void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data) {} - int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { return 1; } - void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) {} - void ThrowError(RPCServerStatus code, RPCCode info) {} - MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; } - - private: - MinRPCReturnInterface* ret_handler_; -}; - -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_" diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 2b14a8ae8398..ccfd3d079280 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -28,511 +28,24 @@ #ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ #define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ -#ifndef DMLC_LITTLE_ENDIAN -#define DMLC_LITTLE_ENDIAN 1 -#endif - -#include -#include +#include +#include +#include +#include +#include #include #include #include "../../support/generic_arena.h" -#include "minrpc_interfaces.h" #include "rpc_reference.h" -#ifndef MINRPC_CHECK -#define MINRPC_CHECK(cond) \ - if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); -#endif - namespace tvm { namespace runtime { - -namespace detail { +namespace details { template class PageAllocator; -} - -/*! - * \brief Responses to a minimum RPC command. - * - * \tparam TIOHandler IO provider to provide io handling. - */ -template -class MinRPCReturns : public MinRPCReturnInterface { - public: - /*! - * \brief Constructor. - * \param io The IO handler. - */ - explicit MinRPCReturns(TIOHandler* io) : io_(io) {} - - void ReturnVoid() { - int32_t num_args = 1; - int32_t tcode = kTVMNullptr; - RPCCode code = RPCCode::kReturn; - - uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); - - io_->MessageStart(packet_nbytes); - Write(packet_nbytes); - Write(code); - Write(num_args); - Write(tcode); - io_->MessageDone(); - } - - void ReturnHandle(void* handle) { - int32_t num_args = 1; - int32_t tcode = kTVMOpaqueHandle; - RPCCode code = RPCCode::kReturn; - uint64_t encode_handle = reinterpret_cast(handle); - uint64_t packet_nbytes = - sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle); - - io_->MessageStart(packet_nbytes); - Write(packet_nbytes); - Write(code); - Write(num_args); - Write(tcode); - Write(encode_handle); - io_->MessageDone(); - } - - void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); } - - void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) { - RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); - } - - void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) { - RPCCode code = RPCCode::kCopyAck; - uint64_t packet_nbytes = sizeof(code) + num_bytes; - - io_->MessageStart(packet_nbytes); - Write(packet_nbytes); - Write(code); - WriteArray(data_ptr, num_bytes); - io_->MessageDone(); - } - - void ReturnLastTVMError() { - const char* err = TVMGetLastError(); - ReturnException(err); - } - - void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); } - - void MessageDone() { io_->MessageDone(); } - - void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { - io_->Exit(static_cast(code)); - } - - void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); } - uint64_t GetObjectBytes(void* obj) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); - return 0; - } - - template - void Write(const T& data) { - static_assert(std::is_trivial::value && std::is_standard_layout::value, - "need to be trival"); - return WriteRawBytes(&data, sizeof(T)); - } - - template - void WriteArray(T* data, size_t count) { - static_assert(std::is_trivial::value && std::is_standard_layout::value, - "need to be trival"); - return WriteRawBytes(data, sizeof(T) * count); - } - - private: - void WriteRawBytes(const void* data, size_t size) { - const uint8_t* buf = static_cast(data); - size_t ndone = 0; - while (ndone < size) { - ssize_t ret = io_->PosixWrite(buf, size - ndone); - if (ret <= 0) { - this->ThrowError(RPCServerStatus::kWriteError); - } - buf += ret; - ndone += ret; - } - } - - TIOHandler* io_; -}; - -/*! - * \brief Executing a minimum RPC command. - * - * \tparam TIOHandler IO provider to provide io handling. - * \tparam MinRPCReturnInterface* handles response generatation and transmission. - */ -template -class MinRPCExecute : public MinRPCExecInterface { - public: - MinRPCExecute(TIOHandler* io, MinRPCReturnInterface* ret_handler) - : io_(io), ret_handler_(ret_handler) {} - - void InitServer(int num_args) { - MINRPC_CHECK(num_args == 0); - ret_handler_->ReturnVoid(); - } - - void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) { - TVMValue ret_value[3]; - int ret_tcode[3]; - - int call_ecode = TVMFuncCall(reinterpret_cast(call_handle), values, tcodes, num_args, - &(ret_value[1]), &(ret_tcode[1])); - - if (call_ecode == 0) { - // Return value encoding as in LocalSession - int rv_tcode = ret_tcode[1]; - ret_tcode[0] = kDLInt; - ret_value[0].v_int64 = rv_tcode; - if (rv_tcode == kTVMNDArrayHandle) { - ret_tcode[1] = kTVMDLTensorHandle; - ret_value[2].v_handle = ret_value[1].v_handle; - ret_tcode[2] = kTVMOpaqueHandle; - ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 3); - } else if (rv_tcode == kTVMBytes) { - ret_tcode[1] = kTVMBytes; - ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); - TVMByteArrayFree(reinterpret_cast(ret_value[1].v_handle)); // NOLINT(*) - } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle || - rv_tcode == kTVMObjectHandle) { - ret_tcode[1] = kTVMOpaqueHandle; - ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); - } else { - ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); - } - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { - int call_ecode = 0; - if (arr->device.device_type != kDLCPU) { - DLTensor temp; - temp.data = static_cast(data_ptr); - temp.device = DLDevice{kDLCPU, 0}; - temp.ndim = arr->ndim; - temp.dtype = arr->dtype; - temp.shape = arr->shape; - temp.strides = nullptr; - temp.byte_offset = 0; - call_ecode = TVMDeviceCopyDataFromTo(arr, &temp, nullptr); - // need sync to make sure that the copy is completed. - if (call_ecode == 0) { - call_ecode = TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); - } - } - - if (call_ecode == 0) { - ret_handler_->ReturnCopyFromRemote(data_ptr, num_bytes); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { - int call_ecode = 0; - - int ret = ReadArray(data_ptr, num_bytes); - if (ret <= 0) return ret; - - if (arr->device.device_type != kDLCPU) { - DLTensor temp; - temp.data = data_ptr; - temp.device = DLDevice{kDLCPU, 0}; - temp.ndim = arr->ndim; - temp.dtype = arr->dtype; - temp.shape = arr->shape; - temp.strides = nullptr; - temp.byte_offset = 0; - call_ecode = TVMDeviceCopyDataFromTo(&temp, arr, nullptr); - // need sync to make sure that the copy is completed. - if (call_ecode == 0) { - call_ecode = TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); - } - } - - if (call_ecode == 0) { - ret_handler_->ReturnVoid(); - } else { - ret_handler_->ReturnLastTVMError(); - } - - return 1; - } - - void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) { - switch (code) { - case RPCCode::kFreeHandle: { - SyscallFreeHandle(values, tcodes, num_args); - break; - } - case RPCCode::kGetGlobalFunc: { - SyscallGetGlobalFunc(values, tcodes, num_args); - break; - } - case RPCCode::kDevSetDevice: { - ret_handler_->ReturnException("SetDevice not supported"); - break; - } - case RPCCode::kDevGetAttr: { - ret_handler_->ReturnException("GetAttr not supported"); - break; - } - case RPCCode::kDevAllocData: { - SyscallDevAllocData(values, tcodes, num_args); - break; - } - case RPCCode::kDevAllocDataWithScope: { - SyscallDevAllocDataWithScope(values, tcodes, num_args); - break; - } - case RPCCode::kDevFreeData: { - SyscallDevFreeData(values, tcodes, num_args); - break; - } - case RPCCode::kDevCreateStream: { - SyscallDevCreateStream(values, tcodes, num_args); - break; - } - case RPCCode::kDevFreeStream: { - SyscallDevFreeStream(values, tcodes, num_args); - break; - } - case RPCCode::kDevStreamSync: { - SyscallDevStreamSync(values, tcodes, num_args); - break; - } - case RPCCode::kDevSetStream: { - SyscallDevSetStream(values, tcodes, num_args); - break; - } - case RPCCode::kCopyAmongRemote: { - SyscallCopyAmongRemote(values, tcodes, num_args); - break; - } - default: { - ret_handler_->ReturnException("Syscall not recognized"); - break; - } - } - } - - void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 1); - MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); - void* handle = values[0].v_handle; - int call_ecode = TVMObjectFree(handle); - - if (call_ecode == 0) { - ret_handler_->ReturnVoid(); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 1); - MINRPC_CHECK(tcodes[0] == kTVMStr); - void* handle; - int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle); - - if (call_ecode == 0) { - ret_handler_->ReturnHandle(handle); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 3); - // from dltensor - MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle); - // to dltensor - MINRPC_CHECK(tcodes[1] == kTVMDLTensorHandle); - // stream - MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle); - - void* from = values[0].v_handle; - void* to = values[1].v_handle; - TVMStreamHandle stream = values[2].v_handle; - - int call_ecode = TVMDeviceCopyDataFromTo(reinterpret_cast(from), - reinterpret_cast(to), stream); - - if (call_ecode == 0) { - ret_handler_->ReturnVoid(); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 4); - MINRPC_CHECK(tcodes[0] == kDLDevice); - MINRPC_CHECK(tcodes[1] == kDLInt); - MINRPC_CHECK(tcodes[2] == kDLInt); - MINRPC_CHECK(tcodes[3] == kTVMDataType); - - DLDevice dev = values[0].v_device; - int64_t nbytes = values[1].v_int64; - int64_t alignment = values[2].v_int64; - DLDataType type_hint = values[3].v_type; - - void* handle; - int call_ecode = TVMDeviceAllocDataSpace(dev, nbytes, alignment, type_hint, &handle); - - if (call_ecode == 0) { - ret_handler_->ReturnHandle(handle); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallDevAllocDataWithScope(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 2); - MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle); - MINRPC_CHECK(tcodes[1] == kTVMNullptr || tcodes[1] == kTVMStr); - - DLTensor* arr = static_cast(values[0].v_handle); - const char* mem_scope = (tcodes[1] == kTVMNullptr ? nullptr : values[1].v_str); - void* handle; - int call_ecode = TVMDeviceAllocDataSpaceWithScope(arr->device, arr->ndim, arr->shape, - arr->dtype, mem_scope, &handle); - if (call_ecode == 0) { - ret_handler_->ReturnHandle(handle); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 2); - MINRPC_CHECK(tcodes[0] == kDLDevice); - MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); - - DLDevice dev = values[0].v_device; - void* handle = values[1].v_handle; - - int call_ecode = TVMDeviceFreeDataSpace(dev, handle); - - if (call_ecode == 0) { - ret_handler_->ReturnVoid(); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallDevCreateStream(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 1); - MINRPC_CHECK(tcodes[0] == kDLDevice); - - DLDevice dev = values[0].v_device; - void* handle; - - int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle); - - if (call_ecode == 0) { - ret_handler_->ReturnHandle(handle); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallDevFreeStream(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 2); - MINRPC_CHECK(tcodes[0] == kDLDevice); - MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); - - DLDevice dev = values[0].v_device; - void* handle = values[1].v_handle; - - int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle); - - if (call_ecode == 0) { - ret_handler_->ReturnVoid(); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 2); - MINRPC_CHECK(tcodes[0] == kDLDevice); - MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); - - DLDevice dev = values[0].v_device; - void* handle = values[1].v_handle; - - int call_ecode = TVMSynchronize(dev.device_type, dev.device_id, handle); - - if (call_ecode == 0) { - ret_handler_->ReturnVoid(); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void SyscallDevSetStream(TVMValue* values, int* tcodes, int num_args) { - MINRPC_CHECK(num_args == 2); - MINRPC_CHECK(tcodes[0] == kDLDevice); - MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); - - DLDevice dev = values[0].v_device; - void* handle = values[1].v_handle; - - int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle); - - if (call_ecode == 0) { - ret_handler_->ReturnVoid(); - } else { - ret_handler_->ReturnLastTVMError(); - } - } - - void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { - ret_handler_->ThrowError(code, info); - } - - MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; } - - private: - template - int ReadArray(T* data, size_t count) { - static_assert(std::is_trivial::value && std::is_standard_layout::value, - "need to be trival"); - return ReadRawBytes(data, sizeof(T) * count); - } - - int ReadRawBytes(void* data, size_t size) { - uint8_t* buf = static_cast(data); - size_t ndone = 0; - while (ndone < size) { - ssize_t ret = io_->PosixRead(buf, size - ndone); - if (ret <= 0) return ret; - ndone += ret; - buf += ret; - } - return 1; - } - - TIOHandler* io_; - MinRPCReturnInterface* ret_handler_; -}; - +} // namespace details /*! * \brief A minimum RPC server that only depends on the tvm C runtime.. * @@ -544,180 +57,61 @@ class MinRPCExecute : public MinRPCExecInterface { * - MessageStart(num_bytes), MessageDone(): framing APIs. * - Exit: exit with status code. */ -template class Allocator = detail::PageAllocator> +template class Allocator = details::PageAllocator> class MinRPCServer { public: using PageAllocator = Allocator; - /*! - * \brief Constructor. - * \param io The IO handler. - */ - MinRPCServer(TIOHandler* io, std::unique_ptr&& exec_handler) - : io_(io), arena_(PageAllocator(io_)), exec_handler_(std::move(exec_handler)) {} + using FServerHandler = ffi::TypedFunction; - explicit MinRPCServer(TIOHandler* io) - : io_(io), - arena_(PageAllocator(io)), - ret_handler_(new MinRPCReturns(io_)), - exec_handler_(std::unique_ptr( - new MinRPCExecute(io_, ret_handler_))) {} - - ~MinRPCServer() { - if (ret_handler_ != nullptr) { - delete ret_handler_; - } + explicit MinRPCServer(TIOHandler* io) : io_(io), arena_(PageAllocator(io_)) { + auto fsend = ffi::Function::FromTyped([this](TVMFFIByteArray* bytes) { + return io_->PosixWrite(reinterpret_cast(bytes->data), bytes->size); + }); + auto fcreate = tvm::ffi::Function::GetGlobalRequired("rpc.CreateEventDrivenServer"); + ffi::Any value = fcreate(fsend, "MinRPCServer", ""); + fserver_handler_ = value.cast(); } - /*! \brief Process a single request. + /*! + * \brief Process a single request. * * \return true when the server should continue processing requests. false when it should be * shutdown. */ bool ProcessOnePacket() { - RPCCode code; uint64_t packet_len; - arena_.RecycleAll(); allow_clean_shutdown_ = true; - Read(&packet_len); if (packet_len == 0) return true; - Read(&code); - allow_clean_shutdown_ = false; - - if (code >= RPCCode::kSyscallCodeStart) { - HandleSyscallFunc(code); - } else { - switch (code) { - case RPCCode::kCallFunc: { - HandleNormalCallFunc(); - break; - } - case RPCCode::kInitServer: { - HandleInitServer(); - break; - } - case RPCCode::kCopyFromRemote: { - HandleCopyFromRemote(); - break; - } - case RPCCode::kCopyToRemote: { - HandleCopyToRemote(); - break; - } - case RPCCode::kShutdown: { - Shutdown(); - return false; - } - default: { - this->ThrowError(RPCServerStatus::kUnknownRPCCode); - break; - } + char* read_buffer = this->ArenaAlloc(sizeof(uint64_t) + packet_len); + // copy header into read buffer + std::memcpy(read_buffer, &packet_len, sizeof(uint64_t)); + // read the rest of the packet + ReadRawBytes(read_buffer + sizeof(uint64_t), packet_len); + // setup write flags + int write_flags = 3; + TVMFFIByteArray read_bytes{read_buffer, sizeof(uint64_t) + static_cast(packet_len)}; + int status = fserver_handler_(&read_bytes, write_flags); + + while (status == 2) { + TVMFFIByteArray write_bytes{nullptr, 0}; + // continue call handler until it have nothing to write + status = fserver_handler_(&write_bytes, write_flags); + if (status == 0) { + this->Shutdown(); + return false; } } - return true; } - void HandleInitServer() { - uint64_t len; - Read(&len); - char* proto_ver = ArenaAlloc(len + 1); - ReadArray(proto_ver, len); - TVMValue* values; - int* tcodes; - int num_args; - RecvPackedSeq(&values, &tcodes, &num_args); - exec_handler_->InitServer(num_args); - } - void Shutdown() { arena_.FreeAll(); io_->Close(); } - void HandleNormalCallFunc() { - uint64_t call_handle; - TVMValue* values; - int* tcodes; - int num_args; - - Read(&call_handle); - RecvPackedSeq(&values, &tcodes, &num_args); - exec_handler_->NormalCallFunc(call_handle, values, tcodes, num_args); - } - - void HandleCopyFromRemote() { - DLTensor* arr = ArenaAlloc(1); - uint64_t data_handle; - Read(&data_handle); - arr->data = reinterpret_cast(data_handle); - Read(&(arr->device)); - Read(&(arr->ndim)); - Read(&(arr->dtype)); - arr->shape = ArenaAlloc(arr->ndim); - ReadArray(arr->shape, arr->ndim); - arr->strides = nullptr; - Read(&(arr->byte_offset)); - - uint64_t num_bytes; - Read(&num_bytes); - - uint8_t* data_ptr; - if (arr->device.device_type == kDLCPU) { - data_ptr = reinterpret_cast(data_handle) + arr->byte_offset; - } else { - data_ptr = ArenaAlloc(num_bytes); - } - - exec_handler_->CopyFromRemote(arr, num_bytes, data_ptr); - } - - void HandleCopyToRemote() { - DLTensor* arr = ArenaAlloc(1); - uint64_t data_handle; - Read(&data_handle); - arr->data = reinterpret_cast(data_handle); - Read(&(arr->device)); - Read(&(arr->ndim)); - Read(&(arr->dtype)); - arr->shape = ArenaAlloc(arr->ndim); - ReadArray(arr->shape, arr->ndim); - arr->strides = nullptr; - Read(&(arr->byte_offset)); - uint64_t num_bytes; - Read(&num_bytes); - int ret; - if (arr->device.device_type == kDLCPU) { - uint8_t* dptr = reinterpret_cast(data_handle) + arr->byte_offset; - ret = exec_handler_->CopyToRemote(arr, num_bytes, dptr); - } else { - uint8_t* temp_data = ArenaAlloc(num_bytes); - ret = exec_handler_->CopyToRemote(arr, num_bytes, temp_data); - } - if (ret == 0) { - if (allow_clean_shutdown_) { - Shutdown(); - io_->Exit(0); - } else { - this->ThrowError(RPCServerStatus::kReadError); - } - } - if (ret == -1) { - this->ThrowError(RPCServerStatus::kReadError); - } - } - - void HandleSyscallFunc(RPCCode code) { - TVMValue* values; - int* tcodes; - int num_args; - RecvPackedSeq(&values, &tcodes, &num_args); - - exec_handler_->SysCallFunc(code, values, tcodes, num_args); - } - void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { io_->Exit(static_cast(code)); } @@ -736,32 +130,7 @@ class MinRPCServer { ReadRawBytes(data, sizeof(T)); } - template - void ReadArray(T* data, size_t count) { - static_assert(std::is_trivial::value && std::is_standard_layout::value, - "need to be trival"); - return ReadRawBytes(data, sizeof(T) * count); - } - - void ReadObject(int* tcode, TVMValue* value) { - // handles RPCObject in minRPC - // NOTE: object needs to be supported by C runtime - // because minrpc's restriction of C only - // we only handle RPCObjectRef - uint32_t type_index; - Read(&type_index); - MINRPC_CHECK(type_index == kRuntimeRPCObjectRefTypeIndex); - uint64_t object_handle; - Read(&object_handle); - tcode[0] = kTVMObjectHandle; - value[0].v_handle = reinterpret_cast(object_handle); - } - private: - void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { - RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this); - } - void ReadRawBytes(void* data, size_t size) { uint8_t* buf = static_cast(data); size_t ndone = 0; @@ -783,18 +152,17 @@ class MinRPCServer { } } + /*! \brief server handler. */ + FServerHandler fserver_handler_; /*! \brief IO handler. */ TIOHandler* io_; /*! \brief internal arena. */ support::GenericArena arena_; - MinRPCReturns* ret_handler_ = nullptr; - std::unique_ptr exec_handler_; /*! \brief Whether we are in a state that allows clean shutdown. */ bool allow_clean_shutdown_{true}; - static_assert(DMLC_LITTLE_ENDIAN == 1, "MinRPC only works on little endian."); }; -namespace detail { +namespace details { // Internal allocator that redirects alloc to TVM's C API. template class PageAllocator { @@ -805,10 +173,9 @@ class PageAllocator { ArenaPageHeader* allocate(size_t min_size) { size_t npages = ((min_size + kPageSize - 1) / kPageSize); - void* data; + void* data = malloc(npages * kPageSize); - if (TVMDeviceAllocDataSpace(DLDevice{kDLCPU, 0}, npages * kPageSize, kPageAlign, - DLDataType{kDLInt, 1, 1}, &data) != 0) { + if (data == nullptr) { io_->Exit(static_cast(RPCServerStatus::kAllocError)); } @@ -818,11 +185,7 @@ class PageAllocator { return header; } - void deallocate(ArenaPageHeader* page) { - if (TVMDeviceFreeDataSpace(DLDevice{kDLCPU, 0}, page) != 0) { - io_->Exit(static_cast(RPCServerStatus::kAllocError)); - } - } + void deallocate(ArenaPageHeader* page) { free(page); } static const constexpr int kPageSize = 2 << 10; static const constexpr int kPageAlign = 8; @@ -830,7 +193,7 @@ class PageAllocator { private: TIOHandler* io_; }; -} // namespace detail +} // namespace details } // namespace runtime } // namespace tvm diff --git a/src/runtime/minrpc/minrpc_server_logging.h b/src/runtime/minrpc/minrpc_server_logging.h deleted file mode 100644 index 89650efe9a1b..000000000000 --- a/src/runtime/minrpc/minrpc_server_logging.h +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ -#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ - -#include -#include - -#include "minrpc_logger.h" -#include "minrpc_server.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief A minimum RPC server that logs the received commands. - * - * \tparam TIOHandler IO provider to provide io handling. - */ -template -class MinRPCServerWithLog { - public: - explicit MinRPCServerWithLog(TIOHandler* io) - : ret_handler_(io), - ret_handler_wlog_(&ret_handler_, &logger_), - exec_handler_(io, &ret_handler_wlog_), - exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)), - next_(io, std::move(exec_handler_ptr_)) {} - - bool ProcessOnePacket() { return next_.ProcessOnePacket(); } - - private: - Logger logger_; - MinRPCReturns ret_handler_; - MinRPCExecute exec_handler_; - MinRPCReturnsWithLog ret_handler_wlog_; - std::unique_ptr exec_handler_ptr_; - MinRPCServer next_; -}; - -/*! - * \brief A minimum RPC server that only logs the outgoing commands and received responses. - * (Does not process the packets or respond to them.) - * - * \tparam TIOHandler IO provider to provide io handling. - */ -template class Allocator = detail::PageAllocator> -class MinRPCSniffer { - public: - using PageAllocator = Allocator; - explicit MinRPCSniffer(TIOHandler* io) - : io_(io), - arena_(PageAllocator(io_)), - ret_handler_(io_), - ret_handler_wlog_(&ret_handler_, &logger_), - exec_handler_(&ret_handler_wlog_), - exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)), - next_(io_, std::move(exec_handler_ptr_)) {} - - bool ProcessOnePacket() { return next_.ProcessOnePacket(); } - - void ProcessOneResponse() { - RPCCode code; - uint64_t packet_len = 0; - - if (!Read(&packet_len)) return; - if (packet_len == 0) { - OutputLog(); - return; - } - if (!Read(&code)) return; - switch (code) { - case RPCCode::kReturn: { - int32_t num_args; - int* type_codes; - TVMValue* values; - RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); - ret_handler_wlog_.ReturnPackedSeq(values, type_codes, num_args); - break; - } - case RPCCode::kException: { - ret_handler_wlog_.ReturnException(""); - break; - } - default: { - OutputLog(); - break; - } - } - } - - void OutputLog() { logger_.OutputLog(); } - - void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { - logger_.Log("-> "); - logger_.Log(RPCServerStatusToString(code)); - OutputLog(); - } - - template - T* ArenaAlloc(int count) { - static_assert(std::is_trivial::value && std::is_standard_layout::value, - "need to be trival"); - return arena_.template allocate_(count); - } - - template - bool Read(T* data) { - static_assert(std::is_trivial::value && std::is_standard_layout::value, - "need to be trival"); - return ReadRawBytes(data, sizeof(T)); - } - - template - bool ReadArray(T* data, size_t count) { - static_assert(std::is_trivial::value && std::is_standard_layout::value, - "need to be trival"); - return ReadRawBytes(data, sizeof(T) * count); - } - - void ReadObject(int* tcode, TVMValue* value) { - this->ThrowError(RPCServerStatus::kUnknownTypeCode); - } - - private: - bool ReadRawBytes(void* data, size_t size) { - uint8_t* buf = reinterpret_cast(data); - size_t ndone = 0; - while (ndone < size) { - ssize_t ret = io_->PosixRead(buf, size - ndone); - if (ret <= 0) { - this->ThrowError(RPCServerStatus::kReadError); - return false; - } - ndone += ret; - buf += ret; - } - return true; - } - - Logger logger_; - TIOHandler* io_; - support::GenericArena arena_; - MinRPCReturnsNoOp ret_handler_; - MinRPCReturnsWithLog ret_handler_wlog_; - MinRPCExecuteNoOp exec_handler_; - std::unique_ptr exec_handler_ptr_; - MinRPCServer next_; -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ diff --git a/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc b/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc index b513d4b7cc1b..014704e97077 100644 --- a/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc +++ b/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc @@ -17,9 +17,6 @@ * under the License. */ -// Disable constructor to bring minimum dep on c++ABI. -#define TVM_ARENA_HAS_DESTRUCTOR 0 - #include #include diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index ff3c9f22fdaa..41bb40b3f2ec 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -35,14 +35,6 @@ namespace runtime { /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; -/*! - * \brief type index of kRuntimeRPCObjectRefTypeIndex - * \note this needs to be kept consistent with runtime/object.h - * but we explicitly declare it here because minrpc needs to be minimum dep - * only c C API - */ -constexpr const int kRuntimeRPCObjectRefTypeIndex = 9; - // When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered. const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX; @@ -83,7 +75,7 @@ enum class RPCServerStatus : int { kInvalidTypeCodeNDArray, kInvalidDLTensorFieldStride, kInvalidDLTensorFieldByteOffset, - kUnknownTypeCode, + kUnknownTypeIndex, kUnknownRPCCode, kRPCCodeNotSupported, kUnknownRPCSyscall, @@ -159,8 +151,8 @@ inline const char* RPCServerStatusToString(RPCServerStatus status) { case RPCServerStatus::kInvalidDLTensorFieldByteOffset: { return "kInvalidDLTensorFieldByteOffset"; } - case RPCServerStatus::kUnknownTypeCode: - return "kUnknownTypeCode"; + case RPCServerStatus::kUnknownTypeIndex: + return "kUnknownTypeIndex"; case RPCServerStatus::kUnknownRPCCode: return "kUnknownRPCCode"; case RPCServerStatus::kRPCCodeNotSupported: @@ -242,10 +234,10 @@ struct RPCReference { * \return The total number of bytes. */ template - static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, - int num_args, bool client_mode, TChannel* channel) { + static uint64_t PackedSeqGetNumBytes(const TVMFFIAny* packed_args, int num_args, bool client_mode, + TChannel* channel) { PackedSeqNumBytesGetter getter(channel); - SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter); + SendPackedSeq(packed_args, num_args, client_mode, &getter); return getter.num_bytes(); } @@ -303,93 +295,89 @@ struct RPCReference { * Note that we cannot simply take these argument out(as the handle) * refers to a value on the remote(instead of local). * - * \param arg_values The values to be sent over. - * \param type_codes The type codes to be sent over. + * \param packed_args The values to be sent over. * \param num_args Number of argument. * \param client_mode Whether it is a client to server call. * \param channel The communication channel handler. * \tparam TChannel The type of the communication channel. */ template - static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, - bool client_mode, TChannel* channel) { + static void SendPackedSeq(const TVMFFIAny* packed_args, int num_args, bool client_mode, + TChannel* channel) { channel->Write(num_args); - channel->WriteArray(type_codes, num_args); // Argument packing. for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - TVMValue value = arg_values[i]; - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - channel->template Write(value.v_int64); + int32_t type_index = packed_args[i].type_index; + channel->template Write(type_index); + switch (type_index) { + case ffi::TypeIndex::kTVMFFINone: { + break; + } + case ffi::TypeIndex::kTVMFFIBool: + case ffi::TypeIndex::kTVMFFIInt: + case ffi::TypeIndex::kTVMFFIFloat: { + channel->template Write(packed_args[i].v_int64); break; } - case kTVMArgBool: { - channel->template Write(value.v_int64); + case ffi::TypeIndex::kTVMFFIOpaquePtr: { + // always send handle in 64 bit. + uint64_t handle = reinterpret_cast(packed_args[i].v_ptr); + channel->template Write(handle); break; } - case kTVMDataType: { - channel->Write(value.v_type); + case ffi::TypeIndex::kTVMFFIDataType: { + channel->Write(packed_args[i].v_dtype); // padding int32_t padding = 0; channel->template Write(padding); break; } - case kDLDevice: { - channel->Write(value.v_device); + case ffi::TypeIndex::kTVMFFIDevice: { + channel->Write(packed_args[i].v_device); break; } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: { + case ffi::TypeIndex::kTVMFFIFunction: + case ffi::TypeIndex::kTVMFFIModule: { if (!client_mode) { channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject); } // always send handle in 64 bit. - uint64_t handle = reinterpret_cast(value.v_handle); - channel->Write(handle); - break; - } - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle = reinterpret_cast(value.v_handle); + uint64_t handle = reinterpret_cast(packed_args[i].v_obj); channel->Write(handle); break; } - case kTVMNDArrayHandle: { + + case ffi::TypeIndex::kTVMFFINDArray: { channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray); break; } - case kTVMDLTensorHandle: { - DLTensor* arr = static_cast(value.v_handle); + case ffi::TypeIndex::kTVMFFIDLTensorPtr: { + DLTensor* arr = static_cast(packed_args[i].v_ptr); SendDLTensor(channel, arr); break; } - case kTVMNullptr: - break; - case kTVMStr: { - const char* s = value.v_str; + case ffi::TypeIndex::kTVMFFIRawStr: { + const char* s = packed_args[i].v_c_str; uint64_t len = StrLength(s); channel->Write(len); channel->WriteArray(s, len); break; } - case kTVMBytes: { - TVMByteArray* bytes = static_cast(arg_values[i].v_handle); + case ffi::TypeIndex::kTVMFFIByteArrayPtr: { + TVMFFIByteArray* bytes = static_cast(packed_args[i].v_ptr); uint64_t len = bytes->size; channel->Write(len); channel->WriteArray(bytes->data, len); break; } - case kTVMObjectHandle: { - channel->WriteObject(static_cast(value.v_handle)); - break; - } default: { - channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + channel->WriteObject(reinterpret_cast(packed_args[i].v_obj)); + } else { + channel->ThrowError(RPCServerStatus::kUnknownTypeIndex); + } break; } } @@ -399,102 +387,95 @@ struct RPCReference { /*! * \brief Receive packed seq from the channel. * - * \param out_arg_values The values to be received. - * \param out_tcodes The type codes to be received. + * \param out_packed_args The values to be received. * \param out_num_args Number of argument. * \param channel The communication channel handler. * \tparam TChannel The type of the communication channel. * \note The temporary space are populated via an arena inside channel. */ template - static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args, - TChannel* channel) { + static void RecvPackedSeq(TVMFFIAny** out_packed_args, int32_t* out_num_args, TChannel* channel) { // receive number of args - int num_args; + int32_t num_args; channel->Read(&num_args); *out_num_args = num_args; - if (num_args == 0) { - *out_values = nullptr; - *out_tcodes = nullptr; + *out_packed_args = nullptr; return; } - TVMValue* values = channel->template ArenaAlloc(num_args); - int* tcodes = channel->template ArenaAlloc(num_args); - *out_values = values; - *out_tcodes = tcodes; - - // receive type code. - channel->ReadArray(tcodes, num_args); + TVMFFIAny* packed_args = channel->template ArenaAlloc(num_args); + *out_packed_args = packed_args; // receive arguments - for (int i = 0; i < num_args; ++i) { - auto& value = values[i]; - switch (tcodes[i]) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - channel->template Read(&(value.v_int64)); + for (int32_t i = 0; i < num_args; ++i) { + int32_t type_index; + channel->Read(&type_index); + packed_args[i].type_index = type_index; + switch (type_index) { + case ffi::TypeIndex::kTVMFFINone: { + break; + } + case ffi::TypeIndex::kTVMFFIBool: + case ffi::TypeIndex::kTVMFFIInt: + case ffi::TypeIndex::kTVMFFIFloat: { + channel->template Read(&(packed_args[i].v_int64)); break; } - case kTVMArgBool: { - channel->template Read(&(value.v_int64)); + case ffi::TypeIndex::kTVMFFIOpaquePtr: { + uint64_t handle; + channel->Read(&handle); + packed_args[i].v_ptr = reinterpret_cast(handle); break; } - case kTVMDataType: { - channel->Read(&(value.v_type)); + case ffi::TypeIndex::kTVMFFIDataType: { + channel->Read(&(packed_args[i].v_dtype)); int32_t padding = 0; channel->template Read(&padding); break; } - case kDLDevice: { - channel->Read(&(value.v_device)); + case ffi::TypeIndex::kTVMFFIDevice: { + channel->Read(&(packed_args[i].v_device)); break; } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: - case kTVMOpaqueHandle: { + case ffi::TypeIndex::kTVMFFIFunction: + case ffi::TypeIndex::kTVMFFIModule: { // always send handle in 64 bit. uint64_t handle; channel->Read(&handle); - value.v_handle = reinterpret_cast(handle); - break; - } - case kTVMNullptr: { - value.v_handle = nullptr; + packed_args[i].v_obj = reinterpret_cast(handle); break; } - case kTVMStr: { + case ffi::TypeIndex::kTVMFFIRawStr: { uint64_t len; channel->Read(&len); char* str = channel->template ArenaAlloc(len + 1); str[len] = '\0'; channel->ReadArray(str, len); - value.v_str = str; + packed_args[i].v_c_str = str; break; } - case kTVMBytes: { + case ffi::TypeIndex::kTVMFFIByteArrayPtr: { uint64_t len; channel->Read(&len); - TVMByteArray* arr = channel->template ArenaAlloc(1); + TVMFFIByteArray* arr = channel->template ArenaAlloc(1); char* data = channel->template ArenaAlloc(len); arr->size = len; arr->data = data; channel->ReadArray(data, len); - value.v_handle = arr; + packed_args[i].v_ptr = arr; break; } - case kTVMDLTensorHandle: { - value.v_handle = ReceiveDLTensor(channel); - break; - } - case kTVMObjectHandle: { - channel->ReadObject(&tcodes[i], &value); + case ffi::TypeIndex::kTVMFFIDLTensorPtr: { + packed_args[i].v_ptr = ReceiveDLTensor(channel); break; } default: { - channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + channel->ReadObject(&(packed_args[i])); + } else { + channel->ThrowError(RPCServerStatus::kUnknownTypeIndex); + } break; } } @@ -512,16 +493,17 @@ struct RPCReference { static void ReturnException(const char* msg, TChannel* channel) { RPCCode code = RPCCode::kException; int32_t num_args = 1; - int32_t tcode = kTVMStr; + int32_t type_index = ffi::TypeIndex::kTVMFFIRawStr; uint64_t len = StrLength(msg); - uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len; + uint64_t packet_nbytes = + sizeof(code) + sizeof(num_args) + sizeof(type_index) + sizeof(len) + len; channel->MessageStart(packet_nbytes); channel->Write(packet_nbytes); channel->Write(code); channel->Write(num_args); - channel->Write(tcode); + channel->Write(type_index); channel->Write(len); channel->WriteArray(msg, len); channel->MessageDone(); @@ -535,17 +517,16 @@ struct RPCReference { * \tparam TChannel The type of the communication channel. */ template - static void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, - TChannel* channel) { + static void ReturnPackedSeq(const TVMFFIAny* packed_args, int num_args, TChannel* channel) { RPCCode code = RPCCode::kReturn; uint64_t packet_nbytes = - sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel); + sizeof(code) + PackedSeqGetNumBytes(packed_args, num_args, false, channel); channel->MessageStart(packet_nbytes); channel->Write(packet_nbytes); channel->Write(code); - SendPackedSeq(arg_values, type_codes, num_args, false, channel); + SendPackedSeq(packed_args, num_args, false, channel); channel->MessageDone(); } @@ -558,16 +539,16 @@ struct RPCReference { template static void ReturnVoid(TChannel* channel) { int32_t num_args = 1; - int32_t tcode = kTVMNullptr; + int32_t type_index = ffi::TypeIndex::kTVMFFINone; RPCCode code = RPCCode::kReturn; - uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(type_index); channel->MessageStart(packet_nbytes); channel->Write(packet_nbytes); channel->Write(code); channel->Write(num_args); - channel->Write(tcode); + channel->Write(type_index); channel->MessageDone(); } }; diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 5ba2248f7627..d2bc4b2c297a 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -21,9 +21,8 @@ * \file module.cc * \brief TVM module system */ +#include #include -#include -#include #include #include @@ -166,52 +165,52 @@ bool RuntimeEnabled(const String& target_str) { return tvm::ffi::Function::GetGlobal(f_name).has_value(); } -TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); +TVM_FFI_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); -TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { return mod->GetSource(fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { return static_cast(mod->imports().size()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { return mod->imports().at(index); }); -TVM_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleClearImports").set_body_typed([](Module mod) { mod->ClearImports(); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { return std::string(mod->type_key()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) { return mod->GetFormat(); }); -TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); -TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleSaveToFile") .set_body_typed([](Module mod, String name, String fmt) { mod->SaveToFile(name, fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetPropertyMask").set_body_typed([](Module mod) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetPropertyMask").set_body_typed([](Module mod) { return mod->GetPropertyMask(); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImplementsFunction") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImplementsFunction") .set_body_typed([](Module mod, String name, bool query_imports) { return mod->ImplementsFunction(std::move(name), query_imports); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetFunction") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleGetFunction") .set_body_typed([](Module mod, String name, bool query_imports) { return mod->GetFunction(name, query_imports); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImport").set_body_typed([](Module mod, Module other) { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImport").set_body_typed([](Module mod, Module other) { mod->Import(other); }); diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index c90c3aaa95f9..2bf56e876164 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -21,13 +21,12 @@ * \file ndarray.cc * \brief NDArray container infratructure. */ -#include +#include +#include #include #include #include -#include -#include "runtime_base.h" #include "tvm/runtime/data_type.h" namespace tvm { @@ -73,10 +72,11 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { DeviceAPI::Get(handle->device)->StreamSync(handle->device, nullptr); } -void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { +void NDArray::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, + TVMStreamHandle stream) { size_t arr_size = GetDataSize(*handle); ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; - ICHECK(IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; + ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; DLTensor to; to.data = const_cast(data); @@ -87,12 +87,12 @@ void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { to.strides = nullptr; to.byte_offset = 0; - DeviceAPI::Get(handle->device)->CopyDataFromTo(const_cast(handle), &to, nullptr); + DeviceAPI::Get(handle->device)->CopyDataFromTo(const_cast(handle), &to, stream); // Synchronize in case data become unavailable later. - DeviceAPI::Get(handle->device)->StreamSync(handle->device, nullptr); + DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } -NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope) { +NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, Optional mem_scope) { struct DeviceAPIAlloc { void AllocData(DLTensor* tensor, ffi::Optional mem_scope) { tensor->data = DeviceAPI::Get(tensor->device) @@ -106,18 +106,7 @@ NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional< return ffi::NDArray::FromNDAlloc(DeviceAPIAlloc(), shape, dtype, dev, mem_scope); } -struct NDArray::Internal { - // Implementation of API function - static DLTensor* MoveToFFIHandle(NDArray arr) { - DLTensor* handle = NDArray::FFIGetHandle(arr); - // move and discard as handle is already obtained in FFIGetHandle - ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(arr)); - return handle; - } - static void FFIDecRef(TVMArrayHandle tensor) { NDArray::FFIDecRef(tensor); } -}; - -NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype, +NDArray NDArray::CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_byte_offset) const { ICHECK(data_ != nullptr); @@ -152,7 +141,7 @@ NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype, << "This would occupy bytes " << relative_byte_offset << " <= i_byte < " << (relative_byte_offset + view_size) << " within the backing array. " << "However, the NDArray being viewed only contains " << curr_size << " bytes (shape = " - << ShapeTuple(curr_dl_tensor.shape, curr_dl_tensor.shape + curr_dl_tensor.ndim) + << ffi::Shape(curr_dl_tensor.shape, curr_dl_tensor.shape + curr_dl_tensor.ndim) << ", dtype= " << curr_dl_tensor.dtype << ")."; // helper allocator class that retains ref count of original NDArray @@ -178,7 +167,7 @@ NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype, void NDArray::CopyToBytes(void* data, size_t nbytes) const { ICHECK(data != nullptr); ICHECK(data_ != nullptr); - ArrayCopyToBytes(get_mutable(), data, nbytes); + NDArray::CopyToBytes(get_mutable(), data, nbytes); } void NDArray::CopyFromBytes(const void* data, size_t nbytes) { @@ -191,7 +180,7 @@ NDArray NDArray::CopyTo(const Device& dev, Optional mem_scope) const { ICHECK(data_ != nullptr); const DLTensor* dptr = operator->(); NDArray ret = - Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev, mem_scope); + Empty(ffi::Shape(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev, mem_scope); this->CopyTo(ret); Device copy_gpu_dev = dptr->device.device_type != kDLCPU ? dptr->device : dev; DeviceAPI::Get(copy_gpu_dev)->StreamSync(copy_gpu_dev, nullptr); @@ -222,79 +211,19 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str using namespace tvm::runtime; -int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { - API_BEGIN(); - *out_tindex = - tvm::ffi::details::ObjectUnsafe::GetHeader(TVMArrayHandleToObjectHandle(handle))->type_index; - API_END(); -} - -int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, - int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { - API_BEGIN(); - DLDataType dtype; - dtype.code = static_cast(dtype_code); - dtype.bits = static_cast(dtype_bits); - dtype.lanes = static_cast(dtype_lanes); - tvm::Device dev; - dev.device_type = static_cast(device_type); - dev.device_id = device_id; - auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev); - - *out = NDArray::Internal::MoveToFFIHandle(ndarray); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body_typed(NDArray::Empty); +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body_typed(NDArray::Empty); -TVM_REGISTER_GLOBAL("runtime.TVMArrayCreateView").set_body_method(&NDArray::CreateView); +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCreateView").set_body_method(&NDArray::CreateView); -int TVMArrayFree(TVMArrayHandle handle) { - API_BEGIN(); - NDArray::Internal::FFIDecRef(handle); - API_END(); -} - -int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream) { - API_BEGIN(); - NDArray::CopyFromTo(from, to, stream); - API_END(); -} - -int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) { - API_BEGIN(); - *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from)); - API_END(); -} - -int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) { - API_BEGIN(); - *out = static_cast(TVMArrayHandleToObjectHandle(from))->ToDLPack(); - API_END(); -} - -int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { - API_BEGIN(); - ArrayCopyFromBytes(handle, data, nbytes); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.TVMArrayCopyFromBytes") +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyFromBytes") .set_body_typed([](DLTensor* arr, void* data, size_t nbytes) { ArrayCopyFromBytes(arr, data, nbytes); }); -int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes) { - API_BEGIN(); - ArrayCopyToBytes(handle, data, nbytes); - API_END(); -} - -TVM_REGISTER_GLOBAL("runtime.TVMArrayCopyToBytes") +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyToBytes") .set_body_typed([](DLTensor* arr, void* data, size_t nbytes) { - ArrayCopyToBytes(arr, data, nbytes); + NDArray::CopyToBytes(arr, data, nbytes); }); -TVM_REGISTER_GLOBAL("runtime.TVMArrayCopyFromTo").set_body_typed([](DLTensor* from, DLTensor* to) { - NDArray::CopyFromTo(from, to); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.TVMArrayCopyFromTo") + .set_body_typed([](DLTensor* from, DLTensor* to) { NDArray::CopyFromTo(from, to); }); diff --git a/src/runtime/object.cc b/src/runtime/object.cc deleted file mode 100644 index 095eee5f5e6b..000000000000 --- a/src/runtime/object.cc +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/runtime/object.cc - * \brief Object type management system. - */ -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "object_internal.h" -#include "runtime_base.h" - -namespace tvm { -namespace runtime { - -TVM_REGISTER_GLOBAL("runtime.ObjectPtrHash").set_body_typed([](ObjectRef obj) { - return static_cast(ObjectPtrHash()(obj)); -}); - -} // namespace runtime -} // namespace tvm - -int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { - API_BEGIN(); - ICHECK(obj != nullptr); - out_tindex[0] = static_cast(obj)->type_index(); - API_END(); -} - -int TVMObjectRetain(TVMObjectHandle obj) { - API_BEGIN(); - tvm::runtime::ObjectInternal::ObjectRetain(obj); - API_END(); -} - -int TVMObjectFree(TVMObjectHandle obj) { - API_BEGIN(); - tvm::runtime::ObjectInternal::ObjectFree(obj); - API_END(); -} - -int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) { - API_BEGIN(); - *is_derived = [&]() { - if (child_type_index == parent_type_index) return true; - if (child_type_index < parent_type_index) return false; - const TVMFFITypeInfo* child_type_info = TVMFFIGetTypeInfo(child_type_index); - const TVMFFITypeInfo* parent_type_info = TVMFFIGetTypeInfo(parent_type_index); - return (child_type_info->type_depth > parent_type_info->type_depth && - child_type_info->type_acenstors[parent_type_info->type_depth] == - static_cast(parent_type_index)); - }(); - API_END(); -} - -int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { - API_BEGIN(); - out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); - API_END(); -} - -int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) { - API_BEGIN(); - auto key = tvm::runtime::Object::TypeIndex2Key(tindex); - *out_type_key = static_cast(malloc(key.size() + 1)); - strncpy(*out_type_key, key.c_str(), key.size() + 1); - API_END(); -} diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h deleted file mode 100644 index 40e4e2fb4855..000000000000 --- a/src/runtime/object_internal.h +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/runtime/object_internal.h - * \brief Expose a few functions for CFFI purposes. - * This file is not intended to be used - */ -#ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_ -#define TVM_RUNTIME_OBJECT_INTERNAL_H_ - -#include -#include - -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief Internal object namespace to expose - * certain util functions for FFI. - */ -class ObjectInternal { - public: - /*! - * \brief Retain an object handle. - */ - static void ObjectRetain(TVMObjectHandle obj) { - if (obj != nullptr) { - // static_cast(obj)->IncRef(); - tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(obj); - } - } - - /*! - * \brief Free an object handle. - */ - static void ObjectFree(TVMObjectHandle obj) { - if (obj != nullptr) { - // static_cast(obj)->DecRef(); - tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(obj); - } - } - /*! - * \brief Check of obj derives from the type indicated by type index. - * \param obj The original object. - * \param type_index The type index of interest. - * \return The derivation checking result. - */ - // static bool DerivedFrom(const Object* obj, uint32_t type_index) { - // return obj->DerivedFrom(type_index); - // } - /*! - * \brief Expose TypeKey2Index - * \param type_key The original type key. - * \return the corresponding index. - */ - static uint32_t ObjectTypeKey2Index(const std::string& type_key) { - int32_t type_index; - TVMFFIByteArray type_key_arr{type_key.data(), type_key.length()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); - return static_cast(type_index); - } - /*! - * \brief Convert ModuleHandle to module node pointer. - * \param handle The module handle. - * \return the corresponding module node pointer. - */ - static ModuleNode* GetModuleNode(TVMModuleHandle handle) { - // NOTE: we will need to convert to Object - // then to ModuleNode in order to get the correct - // address translation - return static_cast(static_cast(handle)); - } -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 6f2a9e610363..dbef2f518f5a 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -24,12 +24,12 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ #define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ -#include +#include +#include #include #include #include #include -#include #include /* There are many OpenCL platforms that do not yet support OpenCL 2.0, @@ -340,8 +340,8 @@ class OpenCLWorkspace : public DeviceAPI { return device_info[GetCLDeviceID(device_id)].image_from_buffer_support; } - void* AllocDataSpaceView(Device dev, void* data, ShapeTuple shape, DLDataType dtype, - Optional mem_scope = NullOpt); + void* AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, DLDataType dtype, + Optional mem_scope = std::nullopt); void FreeDataSpaceView(Device dev, void* ptr); cl_device_id GetCLDeviceID(int device_id); @@ -350,9 +350,9 @@ class OpenCLWorkspace : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; void* AllocDataSpace(Device dev, size_t size, size_t alignment, DLDataType type_hint) final; void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, - Optional mem_scope = NullOpt) final; + Optional mem_scope = std::nullopt) final; void* AllocDataSpace(Device dev, size_t width, size_t height, DLDataType type_hint, - Optional mem_scope = NullOpt); + Optional mem_scope = std::nullopt); void* GetNativePtr(const tvm::runtime::NDArray& narr); void SetNativePtr(const tvm::runtime::NDArray& narr, void* host_ptr, size_t buf_size); void SetPerfHint(Device dev, cl_uint perf_hint); @@ -360,7 +360,7 @@ class OpenCLWorkspace : public DeviceAPI { void StreamSync(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; - size_t GetDataSize(const DLTensor& arr, Optional mem_scope = NullOpt) final; + size_t GetDataSize(const DLTensor& arr, Optional mem_scope = std::nullopt) final; // cl_mem alloc utils void* AllocCLBuffer(Device dev, size_t size, size_t alignment, DLDataType type_hint); diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index a436ede61bc9..000e9a94599e 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -22,8 +22,8 @@ */ #include #include +#include #include -#include #include @@ -358,7 +358,7 @@ size_t OpenCLWorkspace::GetDataSize(const DLTensor& arr, Optional mem_sc mem_scope.value(), row_align); } -void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, ShapeTuple shape, +void* OpenCLWorkspace::AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, DLDataType dtype, Optional mem_scope) { cl::BufferDescriptor* desc = static_cast(data); @@ -760,7 +760,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic initialized_ = true; } -TVM_REGISTER_GLOBAL("device_api.opencl.alloc_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.opencl.alloc_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -788,7 +788,7 @@ TVM_REGISTER_GLOBAL("device_api.opencl.alloc_nd") String("global.texture")); }); -TVM_REGISTER_GLOBAL("device_api.opencl.free_nd") +TVM_FFI_REGISTER_GLOBAL("device_api.opencl.free_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); @@ -803,14 +803,15 @@ TVM_REGISTER_GLOBAL("device_api.opencl.free_nd") *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.opencl").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = OpenCLWorkspace::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.opencl") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = OpenCLWorkspace::Global(); + *rv = static_cast(ptr); + }); TVM_REGISTER_OBJECT_TYPE(OpenCLTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.opencl").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.opencl").set_body_typed([](Device dev) { return Timer(make_object(dev)); }); @@ -851,7 +852,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { return buf; } - Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint, + Buffer Alloc(Device dev, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope) override { if (AllowMemoryScope(mem_scope)) { size_t size = ffi::GetDataSize(shape.Product(), type_hint); @@ -881,7 +882,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { VLOG(1) << "reclaim buffer " << buffer.size; } - void* CreateView(const Buffer& buffer, ShapeTuple shape, DLDataType type_hint, + void* CreateView(const Buffer& buffer, ffi::Shape shape, DLDataType type_hint, const std::string& mem_scope) final { OpenCLWorkspace* ws_ = OpenCLWorkspace::Global(); return ws_->AllocDataSpaceView(buffer.device, buffer.data, shape, type_hint, String(mem_scope)); @@ -893,7 +894,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { } }; -TVM_REGISTER_GLOBAL("DeviceAllocator.opencl") +TVM_FFI_REGISTER_GLOBAL("DeviceAllocator.opencl") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Allocator* alloc = new OpenCLPooledAllocator(); *rv = static_cast(alloc); diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 90cdcb48bf96..8e8ee5a43b78 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -23,7 +23,7 @@ #include "opencl_module.h" #include -#include +#include #include #include @@ -146,7 +146,7 @@ ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, for (size_t i = 0; i < info.arg_types.size(); ++i) { DLDataType t = info.arg_types[i]; ICHECK_EQ(t.lanes, 1U); - if (t.code == kTVMOpaqueHandle) { + if (t.code == kDLOpaqueHandle) { // specially store pointer type size in OpenCL driver arg_size[i] = sizeof(void*); } else { @@ -389,10 +389,10 @@ Module OpenCLModuleLoadBinary(void* strm) { return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 22fc119e0318..198adc6cb216 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ #define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ -#include +#include #include #include diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 28e02a4e3749..7d281694decb 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include #include diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index ec000524fa00..0068db51d522 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -31,8 +31,8 @@ #ifndef TVM_RUNTIME_PACK_ARGS_H_ #define TVM_RUNTIME_PACK_ARGS_H_ -#include -#include +#include +#include #include #include @@ -134,7 +134,7 @@ enum ArgConvertCode { }; inline ArgConvertCode GetArgConvertCode(DLDataType t) { - ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now"; + ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now"; if (t.code == kDLInt) { if (t.bits == 64U) return INT64_TO_INT64; if (t.bits == 32U) return INT64_TO_INT32; @@ -143,7 +143,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { } else if (t.code == kDLFloat) { if (t.bits == 64U) return FLOAT64_TO_FLOAT64; if (t.bits == 32U) return FLOAT64_TO_FLOAT32; - } else if (t.code == kTVMOpaqueHandle) { + } else if (t.code == kDLOpaqueHandle) { return HANDLE_TO_HANDLE; } LOG(FATAL) << "Cannot handle " << t << " as device function argument"; @@ -240,7 +240,6 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector pack_(num_args); int32_t* pack = reinterpret_cast(pack_.data()); int32_t* ptr = pack; - static_assert(sizeof(TVMValue) == 8, "invariant"); static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant"); const TVMFFIAny* raw_args = reinterpret_cast(args.data()); @@ -317,13 +316,13 @@ inline ffi::Function PackFuncVoidAddr(F f, const std::vector& arg_ty inline size_t NumBufferArgs(const std::vector& arg_types) { size_t base = arg_types.size(); for (size_t i = 0; i < arg_types.size(); ++i) { - if (arg_types[i].code != kTVMOpaqueHandle) { + if (arg_types[i].code != kDLOpaqueHandle) { base = i; break; } } for (size_t i = base; i < arg_types.size(); ++i) { - ICHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function need to be organized"; + ICHECK(arg_types[i].code != kDLOpaqueHandle) << "Device function need to be organized"; } return base; } diff --git a/src/runtime/packed_func.cc b/src/runtime/packed_func.cc deleted file mode 100644 index 63ec7bbc7d47..000000000000 --- a/src/runtime/packed_func.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/runtime/packed_func.cc - * \brief Implementation of non-inlinable ffi::Function pieces. - */ -#include -#include - -namespace tvm { -namespace runtime { - -TVM_REGISTER_OBJECT_TYPE(ffi::FunctionObj); - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index c073056fb320..bab1d50db6a9 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -23,9 +23,9 @@ */ #include +#include #include #include -#include #include #include @@ -43,11 +43,11 @@ namespace runtime { class DefaultTimerNode : public TimerNode { public: virtual void Start() { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); + DeviceAPI::Get(device_)->StreamSync(device_, nullptr); start_ = std::chrono::high_resolution_clock::now(); } virtual void Stop() { - TVMSynchronize(device_.device_type, device_.device_id, nullptr); + DeviceAPI::Get(device_)->StreamSync(device_, nullptr); duration_ = std::chrono::high_resolution_clock::now() - start_; } virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); } @@ -84,7 +84,7 @@ class CPUTimerNode : public TimerNode { }; TVM_REGISTER_OBJECT_TYPE(CPUTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.cpu").set_body_typed([](Device dev) { return Timer(make_object()); }); @@ -115,7 +115,7 @@ Timer Timer::Start(Device dev) { } } -TVM_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); +TVM_FFI_REGISTER_GLOBAL("profiling.start_timer").set_body_typed(Timer::Start); namespace profiling { @@ -283,7 +283,7 @@ String ReportNode::AsCSV() const { s << (*it).second.as()->percent; } else if ((*it).second.as()) { s << (*it).second.as()->ratio; - } else if ((*it).second.as()) { + } else if ((*it).second.as()) { s << "\"" << Downcast((*it).second) << "\""; } } @@ -298,7 +298,7 @@ String ReportNode::AsCSV() const { namespace { void metric_as_json(std::ostream& os, ObjectRef o) { - if (o.as()) { + if (o.as()) { os << "{\"string\":" << "\"" << Downcast(o) << "\"" << "}"; @@ -412,7 +412,7 @@ ObjectRef AggregateMetric(const std::vector& metrics) { sum += metric.as()->ratio; } return ObjectRef(make_object(sum / metrics.size())); - } else if (metrics[0].as()) { + } else if (metrics[0].as()) { for (auto& m : metrics) { if (Downcast(metrics[0]) != Downcast(m)) { return ObjectRef(String("")); @@ -461,7 +461,7 @@ static String print_metric(ObjectRef metric) { set_locale_for_separators(s); s << std::setprecision(2) << metric.as()->ratio; val = s.str(); - } else if (metric.as()) { + } else if (metric.as()) { val = Downcast(metric); } else { LOG(FATAL) << "Cannot print metric of type " << metric->GetTypeKey(); @@ -495,7 +495,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con } } for (const auto& p : aggregates) { - std::unordered_map aggregated; + std::unordered_map aggregated; std::unordered_set metrics; for (auto& call : calls) { for (auto& metric : call) { @@ -788,62 +788,65 @@ TVM_REGISTER_OBJECT_TYPE(ReportNode); TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); -TVM_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); -TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); }); -TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { + return n->AsCSV(); +}); +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { return n->AsJSON(); }); -TVM_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON); -TVM_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.FromJSON").set_body_typed(Report::FromJSON); +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.DeviceWrapper").set_body_typed([](Device dev) { return DeviceWrapper(dev); }); ffi::Function ProfileFunction(Module mod, std::string func_name, int device_type, int device_id, int warmup_iters, Array collectors) { // Module::GetFunction is not const, so this lambda has to be mutable - return ffi::Function::FromPacked([=](const AnyView* args, int32_t num_args, Any* ret) mutable { - ffi::Function f = mod.GetFunction(func_name); - CHECK(f.defined()) << "There is no function called \"" << func_name << "\" in the module"; - Device dev{static_cast(device_type), device_id}; - - // warmup - for (int i = 0; i < warmup_iters; i++) { - f.CallPacked(args, num_args, ret); - } - - for (auto& collector : collectors) { - collector->Init({DeviceWrapper(dev)}); - } - std::vector> results; - results.reserve(collectors.size()); - std::vector> collector_data; - collector_data.reserve(collectors.size()); - for (auto& collector : collectors) { - ObjectRef o = collector->Start(dev); - // If not defined, then the collector cannot time this device. - if (o.defined()) { - collector_data.push_back({collector, o}); - } - } + return ffi::Function::FromPacked( + [=](const ffi::AnyView* args, int32_t num_args, ffi::Any* ret) mutable { + ffi::Function f = mod.GetFunction(func_name); + CHECK(f.defined()) << "There is no function called \"" << func_name << "\" in the module"; + Device dev{static_cast(device_type), device_id}; + + // warmup + for (int i = 0; i < warmup_iters; i++) { + f.CallPacked(args, num_args, ret); + } + + for (auto& collector : collectors) { + collector->Init({DeviceWrapper(dev)}); + } + std::vector> results; + results.reserve(collectors.size()); + std::vector> collector_data; + collector_data.reserve(collectors.size()); + for (auto& collector : collectors) { + ObjectRef o = collector->Start(dev); + // If not defined, then the collector cannot time this device. + if (o.defined()) { + collector_data.push_back({collector, o}); + } + } - // TODO(tkonolige): repeated calls if the runtime is small? - f.CallPacked(args, num_args, ret); + // TODO(tkonolige): repeated calls if the runtime is small? + f.CallPacked(args, num_args, ret); - for (auto& kv : collector_data) { - results.push_back(kv.first->Stop(kv.second)); - } - Map combined_results; - for (auto m : results) { - for (auto p : m) { - // assume that there is no shared metric name between collectors - combined_results.Set(p.first, p.second); - } - } - *ret = combined_results; - }); + for (auto& kv : collector_data) { + results.push_back(kv.first->Stop(kv.second)); + } + Map combined_results; + for (auto m : results) { + for (auto p : m) { + // assume that there is no shared metric name between collectors + combined_results.Set(p.first, p.second); + } + } + *ret = combined_results; + }); } -TVM_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.ProfileFunction") .set_body_typed)>([](Module mod, String func_name, int device_type, int device_id, @@ -867,8 +870,8 @@ ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int re auto ftimer = [pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, - f_preproc](const AnyView* args, int num_args, Any* rv) mutable { - Any temp; + f_preproc](const ffi::AnyView* args, int num_args, ffi::Any* rv) mutable { + ffi::Any temp; std::ostringstream os; // skip first time call, to activate lazy compilation components. pf.CallPacked(args, num_args, &temp); @@ -924,26 +927,26 @@ ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int re return ffi::Function::FromPacked(ftimer); } -TVM_REGISTER_GLOBAL("runtime.profiling.Report") +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Report") .set_body_typed([](Array> calls, Map> device_metrics, Map configuration) { return Report(calls, device_metrics, configuration); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Count").set_body_typed([](int64_t count) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Count").set_body_typed([](int64_t count) { return ObjectRef(make_object(count)); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Percent").set_body_typed([](double percent) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Percent").set_body_typed([](double percent) { return ObjectRef(make_object(percent)); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Duration").set_body_typed([](double duration) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Duration").set_body_typed([](double duration) { return ObjectRef(make_object(duration)); }); -TVM_REGISTER_GLOBAL("runtime.profiling.Ratio").set_body_typed([](double ratio) { +TVM_FFI_REGISTER_GLOBAL("runtime.profiling.Ratio").set_body_typed([](double ratio) { return ObjectRef(make_object(ratio)); }); diff --git a/src/runtime/regex.cc b/src/runtime/regex.cc index 8b4df9e69395..a91bf479ce4b 100644 --- a/src/runtime/regex.cc +++ b/src/runtime/regex.cc @@ -24,17 +24,18 @@ #include "./regex.h" -#include +#include namespace tvm { namespace runtime { bool regex_match(const std::string& match_against, const std::string& regex_pattern) { const auto regex_match_func = tvm::ffi::Function::GetGlobal("tvm.runtime.regex_match"); - CHECK(regex_match_func.has_value()) - << "RuntimeError: " - << "The ffi::Function 'tvm.runtime.regex_match' has not been registered. " - << "This can occur if the TVM Python library has not yet been imported."; + if (!regex_match_func.has_value()) { + TVM_FFI_THROW(RuntimeError) + << "The ffi::Function 'tvm.runtime.regex_match' has not been registered. " + << "This can occur if the TVM Python library has not yet been imported."; + } return (*regex_match_func)(regex_pattern, match_against).cast(); } diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc deleted file mode 100644 index 1045dc20ef0d..000000000000 --- a/src/runtime/registry.cc +++ /dev/null @@ -1,266 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file registry.cc - * \brief The global registry of packed function. - */ -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "runtime_base.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief Execution environment specific API registry. - * - * This registry stores C API function pointers about - * execution environment(e.g. python) specific API function that - * we need for specific low-level handling(e.g. signal checking). - * - * We only stores the C API function when absolutely necessary (e.g. when signal handler - * cannot trap back into python). Always consider use the ffi::Function FFI when possible - * in other cases. - */ -class EnvCAPIRegistry { - public: - /*! - * \brief Callback to check if signals have been sent to the process and - * if so invoke the registered signal handler in the frontend environment. - * - * When running TVM in another language (Python), the signal handler - * may not be immediately executed, but instead the signal is marked - * in the interpreter state (to ensure non-blocking of the signal handler). - * - * \return 0 if no error happens, -1 if error happens. - */ - typedef int (*F_PyErr_CheckSignals)(); - - /*! \brief Callback to increment/decrement the python ref count */ - typedef void (*F_Py_IncDefRef)(void*); - - // NOTE: the following functions are only registered in a python - // environment. - /*! - * \brief PyErr_CheckSignal function - */ - F_PyErr_CheckSignals pyerr_check_signals = nullptr; - - /*! - * \brief Py_IncRef function - */ - F_Py_IncDefRef py_inc_ref = nullptr; - - /*! - * \brief Py_IncRef function - */ - F_Py_IncDefRef py_dec_ref = nullptr; - - /*! - \brief PyGILState_Ensure function - */ - void* (*py_gil_state_ensure)() = nullptr; - - /*! - \brief PyGILState_Release function - */ - void (*py_gil_state_release)(void*) = nullptr; - - static EnvCAPIRegistry* Global() { - static EnvCAPIRegistry* inst = new EnvCAPIRegistry(); - return inst; - } - - // register environment(e.g. python) specific api functions - void Register(const String& symbol_name, void* fptr) { - if (symbol_name == "PyErr_CheckSignals") { - Update(symbol_name, &pyerr_check_signals, fptr); - } else if (symbol_name == "Py_IncRef") { - Update(symbol_name, &py_inc_ref, fptr); - } else if (symbol_name == "Py_DecRef") { - Update(symbol_name, &py_dec_ref, fptr); - } else if (symbol_name == "PyGILState_Ensure") { - Update(symbol_name, &py_gil_state_ensure, fptr); - } else if (symbol_name == "PyGILState_Release") { - Update(symbol_name, &py_gil_state_release, fptr); - } else { - LOG(FATAL) << "Unknown env API " << symbol_name; - } - } - - // implementation of tvm::runtime::EnvCheckSignals - void CheckSignals() { - // check python signal to see if there are exception raised - if (pyerr_check_signals != nullptr) { - // The C++ env comes without gil, so we need to grab gil here - WithGIL context(this); - if ((*pyerr_check_signals)() != 0) { - // The error will let FFI know that the frontend environment - // already set an error. - throw EnvErrorAlreadySet(); - } - } - } - - void IncRef(void* python_obj) { - WithGIL context(this); - ICHECK(py_inc_ref) << "Attempted to call Py_IncRef through EnvCAPIRegistry, " - << "but Py_IncRef wasn't registered"; - (*py_inc_ref)(python_obj); - } - - void DecRef(void* python_obj) { - WithGIL context(this); - ICHECK(py_dec_ref) << "Attempted to call Py_DefRef through EnvCAPIRegistry, " - << "but Py_DefRef wasn't registered"; - (*py_dec_ref)(python_obj); - } - - private: - // update the internal API table - template - void Update(const String& symbol_name, FType* target, void* ptr) { - FType ptr_casted = reinterpret_cast(ptr); - if (target[0] != nullptr && target[0] != ptr_casted) { - LOG(WARNING) << "tvm.runtime.RegisterEnvCAPI overrides an existing function " << symbol_name; - } - target[0] = ptr_casted; - } - - struct WithGIL { - explicit WithGIL(EnvCAPIRegistry* self) : self(self) { - ICHECK(self->py_gil_state_ensure) << "Attempted to acquire GIL through EnvCAPIRegistry, " - << "but PyGILState_Ensure wasn't registered"; - ICHECK(self->py_gil_state_release) << "Attempted to acquire GIL through EnvCAPIRegistry, " - << "but PyGILState_Release wasn't registered"; - gil_state = self->py_gil_state_ensure(); - } - ~WithGIL() { - if (self && gil_state) { - self->py_gil_state_release(gil_state); - } - } - WithGIL(const WithGIL&) = delete; - WithGIL(WithGIL&&) = delete; - WithGIL& operator=(const WithGIL&) = delete; - WithGIL& operator=(WithGIL&&) = delete; - - EnvCAPIRegistry* self = nullptr; - void* gil_state = nullptr; - }; -}; - -void EnvCheckSignals() { EnvCAPIRegistry::Global()->CheckSignals(); } - -WrappedPythonObject::WrappedPythonObject(void* python_obj) : python_obj_(python_obj) { - if (python_obj_) { - EnvCAPIRegistry::Global()->IncRef(python_obj_); - } -} - -WrappedPythonObject::~WrappedPythonObject() { - if (python_obj_) { - EnvCAPIRegistry::Global()->DecRef(python_obj_); - } -} - -WrappedPythonObject::WrappedPythonObject(WrappedPythonObject&& other) : python_obj_(nullptr) { - std::swap(python_obj_, other.python_obj_); -} -WrappedPythonObject& WrappedPythonObject::operator=(WrappedPythonObject&& other) { - std::swap(python_obj_, other.python_obj_); - return *this; -} - -WrappedPythonObject::WrappedPythonObject(const WrappedPythonObject& other) - : WrappedPythonObject(other.python_obj_) {} -WrappedPythonObject& WrappedPythonObject::operator=(const WrappedPythonObject& other) { - return *this = WrappedPythonObject(other); -} -WrappedPythonObject& WrappedPythonObject::operator=(std::nullptr_t) { - return *this = WrappedPythonObject(nullptr); -} - -} // namespace runtime -} // namespace tvm - -/*! \brief entry to easily hold returning information */ -struct TVMFuncThreadLocalEntry { - /*! \brief result holder for returning strings */ - std::vector ret_vec_str; - /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; -}; - -/*! \brief Thread local store that can be used to hold return values. */ -typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; - -int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { - API_BEGIN(); - using tvm::runtime::GetRef; - tvm::ffi::Function::SetGlobal( - name, GetRef(static_cast(f)), override != 0); - API_END(); -} - -int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { - API_BEGIN(); - const auto fp = tvm::ffi::Function::GetGlobal(name); - if (fp.has_value()) { - TVMFFIAny val = tvm::ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(tvm::ffi::Any(*fp)); - *out = val.v_obj; - } else { - *out = nullptr; - } - API_END(); -} - -int TVMFuncListGlobalNames(int* out_size, const char*** out_array) { - API_BEGIN(); - TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get(); - ret->ret_vec_str = tvm::ffi::Function::ListGlobalNames(); - ret->ret_vec_charp.clear(); - for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { - ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); - } - *out_array = dmlc::BeginPtr(ret->ret_vec_charp); - *out_size = static_cast(ret->ret_vec_str.size()); - API_END(); -} - -int TVMFuncRemoveGlobal(const char* name) { - API_BEGIN(); - tvm::ffi::Function::RemoveGlobal(name); - API_END(); -} - -int TVMBackendRegisterEnvCAPI(const char* name, void* ptr) { - API_BEGIN(); - tvm::runtime::EnvCAPIRegistry::Global()->Register(name, ptr); - API_END(); -} diff --git a/src/runtime/relax_vm/attn_backend.h b/src/runtime/relax_vm/attn_backend.h index 4d3a3ce9d832..2eb9cf3d6677 100644 --- a/src/runtime/relax_vm/attn_backend.h +++ b/src/runtime/relax_vm/attn_backend.h @@ -25,9 +25,10 @@ #ifndef TVM_RUNTIME_RELAX_VM_ATTN_BACKEND_H_ #define TVM_RUNTIME_RELAX_VM_ATTN_BACKEND_H_ -#include +#include +#include +#include #include -#include #include #include diff --git a/src/runtime/relax_vm/attn_utils.h b/src/runtime/relax_vm/attn_utils.h index 3e53d6bbc215..547d32c1208e 100644 --- a/src/runtime/relax_vm/attn_utils.h +++ b/src/runtime/relax_vm/attn_utils.h @@ -65,7 +65,7 @@ enum class AttnKind : int { }; /*! \brief Given the attention kind and other metadata, return the one-layer KV cache shape. */ -inline ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, +inline ffi::Shape GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int num_sequence, int64_t num_kv_heads, int64_t page_size, int64_t qk_head_dim, int64_t v_head_dim) { if (attn_kind == AttnKind::kMHA) { @@ -77,7 +77,7 @@ inline ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, i return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim}; } ICHECK(false); - return ShapeTuple(); + return ffi::Shape(); } /*! @@ -662,7 +662,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { int n_elem = last_page_len->size(); ICHECK_GT(n_elem, 0); NDArray view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_); - ShapeTuple copy_shape{n_elem}; + ffi::Shape copy_shape{n_elem}; CopyVecDataToArray(view, last_page_len->data(), copy_shape); CopyVecDataToArray(view, sliding_window_offset->data(), copy_shape, /*dst_elem_offset=*/n_elem); @@ -689,7 +689,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { ICHECK_GT(n_elem, 0); NDArray view = commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_); - ShapeTuple copy_shape{n_elem}; + ffi::Shape copy_shape{n_elem}; CopyVecDataToArray(view, src_data->data(), copy_shape); CopyVecDataToArray(view, dst_data->data(), copy_shape, /*dst_elem_offset=*/n_elem); @@ -705,8 +705,8 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { * It optionally supports specifying the shape of copy and the element * offset to the destination NDArray. */ - void CopyVecDataToArray(NDArray array, int32_t* vec_data, Optional shape = NullOpt, - int dst_elem_offset = 0) { + void CopyVecDataToArray(NDArray array, int32_t* vec_data, + Optional shape = std::nullopt, int dst_elem_offset = 0) { if (array->shape[0] == 0) { return; } diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 29adde567ad7..3d7904bd8f48 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -20,22 +20,19 @@ * \file src/runtime/relax_vm/builtin.cc */ #include -#include -#include +#include +#include +#include +#include #include #include #include -#include #include #include -#include -#include #include #include #include -#include "../runtime_base.h" - namespace tvm { namespace runtime { namespace relax_vm { @@ -66,7 +63,7 @@ NDArray AllocShapeHeap(void* ctx_ptr, int64_t size) { return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); } -TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap").set_body_typed(AllocShapeHeap); /*! * \brief Builtin match R.Prim function. @@ -106,7 +103,7 @@ void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t } } -TVM_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue); /*! * \brief Builtin match shape function. @@ -115,15 +112,15 @@ TVM_REGISTER_GLOBAL("vm.builtin.match_prim_value").set_body_typed(MatchPrimValue * * \sa MatchShapeCode */ -void MatchShape(ffi::PackedArgs args, Any* rv) { +void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { // input shape the first argument can take in tensor or shape. - ShapeTuple input_shape; + ffi::Shape input_shape; if (auto opt_nd = args[0].as()) { input_shape = opt_nd.value().Shape(); } else { - input_shape = args[0].cast(); + input_shape = args[0].cast(); } - auto heap = args[1].as(); + auto heap = args[1].try_cast(); int64_t* heap_data = heap.has_value() ? static_cast((*heap)->data) : nullptr; int64_t size = args[2].cast(); const int64_t kBeginCode = 3; @@ -157,7 +154,7 @@ void MatchShape(ffi::PackedArgs args, Any* rv) { } } -TVM_REGISTER_GLOBAL("vm.builtin.match_shape").set_body_packed(MatchShape); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.match_shape").set_body_packed(MatchShape); /*! * \brief Builtin make prim value function. @@ -181,7 +178,7 @@ int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) { } } -TVM_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue); /*! * \brief Builtin make shape function. @@ -190,9 +187,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_prim_value").set_body_typed(MakePrimValue); * * \sa MakeShapeCode */ -void MakeShape(ffi::PackedArgs args, Any* rv) { +void MakeShape(ffi::PackedArgs args, ffi::Any* rv) { // NOTE: heap can be nullptr - auto heap = args[0].as(); + auto heap = args[0].try_cast(); int64_t* heap_data = heap.has_value() ? static_cast((*heap)->data) : nullptr; int64_t size = args[1].cast(); const int64_t kBeginCode = 2; @@ -209,10 +206,10 @@ void MakeShape(ffi::PackedArgs args, Any* rv) { shape[i] = heap_data[reg]; } } - *rv = ShapeTuple(std::move(shape)); + *rv = ffi::Shape(std::move(shape)); } -TVM_REGISTER_GLOBAL("vm.builtin.make_shape").set_body_packed(MakeShape); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_shape").set_body_packed(MakeShape); /*! * \brief Builtin function to check if arg is Tensor(dtype, ndim) @@ -221,8 +218,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_shape").set_body_packed(MakeShape); * \param dtype The expected content data type. * \param err_ctx Additional context if error occurs. */ -void CheckTensorInfo(ffi::PackedArgs args, Any* rv) { - AnyView arg = args[0]; +void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { + ffi::AnyView arg = args[0]; int ndim = args[1].cast(); DataType dtype; Optional err_ctx; @@ -235,7 +232,7 @@ void CheckTensorInfo(ffi::PackedArgs args, Any* rv) { err_ctx = args[3].cast>(); } - auto opt_ptr = arg.as(); + auto opt_ptr = arg.try_cast(); CHECK(opt_ptr.has_value()) << "TypeError: " << err_ctx.value_or("") << " expect a Tensor but get " << arg.GetTypeKey(); @@ -252,7 +249,7 @@ void CheckTensorInfo(ffi::PackedArgs args, Any* rv) { } } -TVM_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body_packed(CheckTensorInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body_packed(CheckTensorInfo); /*! * \brief Builtin function to check if arg is Shape(ndim) @@ -262,7 +259,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_tensor_info").set_body_packed(CheckTensorI */ void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { // a function that lazily get context for error reporting - auto* ptr = arg.as(); + auto* ptr = arg.as(); CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Shape but get " << arg->GetTypeKey(); if (ndim != -1) { @@ -272,7 +269,7 @@ void CheckShapeInfo(ObjectRef arg, int ndim, Optional err_ctx) { } } -TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo); /*! * \brief Builtin function to check if arg is PrimValue(dtype) @@ -280,7 +277,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo * \param dtype Expected dtype of the PrimValue. Can be DataType::Void() for unknown dtype. * \param err_ctx Additional context if error occurs. */ -void CheckPrimValueInfo(AnyView arg, DataType dtype, Optional err_ctx) { +void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, Optional err_ctx) { if (auto opt_obj = arg.as()) { LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype << ", but received ObjectRef of type " << opt_obj.value()->GetTypeKey(); @@ -299,7 +296,7 @@ void CheckPrimValueInfo(AnyView arg, DataType dtype, Optional err_ctx) { } } -TVM_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_prim_value_info").set_body_typed(CheckPrimValueInfo); /*! * \brief Builtin function to check if arg is Tuple with size elements. @@ -317,7 +314,7 @@ void CheckTupleInfo(ObjectRef arg, int64_t size, Optional err_ctx) { << " but get a Tuple with " << ptr->size() << " elements."; } -TVM_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_tuple_info").set_body_typed(CheckTupleInfo); /*! * \brief Builtin function to check if arg is a callable function. @@ -331,12 +328,12 @@ void CheckFuncInfo(ObjectRef arg, Optional err_ctx) { << arg->GetTypeKey(); } -TVM_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.check_func_info").set_body_typed(CheckFuncInfo); //------------------------------------------------- // Storage management. //------------------------------------------------- -Storage VMAllocStorage(void* ctx_ptr, ShapeTuple buffer_shape, Index device_index, +Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_index, DLDataType dtype_hint, String mem_scope) { VirtualMachine* vm = static_cast(ctx_ptr); @@ -356,61 +353,65 @@ Storage VMAllocStorage(void* ctx_ptr, ShapeTuple buffer_shape, Index device_inde return Storage(buffer, alloc); } -TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage").set_body_typed(VMAllocStorage); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_storage").set_body_typed(VMAllocStorage); -TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.alloc_tensor").set_body_method(&StorageObj::AllocNDArray); //------------------------------------------------- // Closure function handling, calling convention //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.make_closure").set_body_packed([](ffi::PackedArgs args, Any* rv) { - VMClosure clo = args[0].cast(); - std::vector saved_args; - saved_args.resize(args.size() - 1); - for (size_t i = 0; i < saved_args.size(); ++i) { - saved_args[i] = args[i + 1]; - } - auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); - *rv = VMClosure(clo->func_name, impl); -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_closure") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + VMClosure clo = args[0].cast(); + std::vector saved_args; + saved_args.resize(args.size() - 1); + for (size_t i = 0; i < saved_args.size(); ++i) { + saved_args[i] = args[i + 1]; + } + auto impl = VMClosure::BindLastArgs(clo->impl, saved_args); + *rv = VMClosure(clo->func_name, impl); + }); -TVM_REGISTER_GLOBAL("vm.builtin.invoke_closure").set_body_packed([](ffi::PackedArgs args, Any* rv) { - // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments - VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); - ObjectRef vm_closure = args[1].cast(); - vm->InvokeClosurePacked(vm_closure, args.Slice(2), rv); -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.invoke_closure") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + // args[0]: vm; args[1]: closure; args[2, 3, ...]: function arguments + VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + ObjectRef vm_closure = args[1].cast(); + vm->InvokeClosurePacked(vm_closure, args.Slice(2), rv); + }); -TVM_REGISTER_GLOBAL("vm.builtin.call_tir_dyn").set_body_packed([](ffi::PackedArgs args, Any* rv) { - ffi::Function func = args[0].cast(); - ShapeTuple to_unpack = args[args.size() - 1].cast(); - size_t num_tensor_args = args.size() - 2; +TVM_FFI_REGISTER_GLOBAL("vm.builtin.call_tir_dyn") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + ffi::Function func = args[0].cast(); + ffi::Shape to_unpack = args[args.size() - 1].cast(); + size_t num_tensor_args = args.size() - 2; - std::vector packed_args(num_tensor_args + to_unpack.size()); - std::copy(args.data() + 1, args.data() + args.size() - 1, packed_args.data()); + std::vector packed_args(num_tensor_args + to_unpack.size()); + std::copy(args.data() + 1, args.data() + args.size() - 1, packed_args.data()); - for (size_t i = 0; i < to_unpack.size(); ++i) { - packed_args[i + num_tensor_args] = to_unpack[i]; - } - func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); -}); + for (size_t i = 0; i < to_unpack.size(); ++i) { + packed_args[i + num_tensor_args] = to_unpack[i]; + } + func.CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), rv); + }); //------------------------------------- // Builtin runtime operators. //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.shape_of").set_body_method(&NDArray::Shape); -TVM_REGISTER_GLOBAL("vm.builtin.copy").set_body_typed([](Any a) -> Any { return a; }); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.copy").set_body_typed([](ffi::Any a) -> ffi::Any { return a; }); -TVM_REGISTER_GLOBAL("vm.builtin.reshape").set_body_typed([](NDArray data, ShapeTuple new_shape) { - return data.CreateView(new_shape, data->dtype); -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.reshape") + .set_body_typed([](NDArray data, ffi::Shape new_shape) { + return data.CreateView(new_shape, data->dtype); + }); -TVM_REGISTER_GLOBAL("vm.builtin.null_value").set_body_typed([]() -> std::nullptr_t { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.null_value").set_body_typed([]() -> std::nullptr_t { return nullptr; }); -TVM_REGISTER_GLOBAL("vm.builtin.to_device") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.to_device") .set_body_typed([](NDArray data, int dev_type, int dev_id) { Device dst_device = {(DLDeviceType)dev_type, dev_id}; return data.CopyTo(dst_device); @@ -421,8 +422,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.to_device") * \param cond The condition * \return Bool */ -bool ReadIfCond(AnyView cond) { - if (auto opt_int = cond.as()) { +bool ReadIfCond(ffi::AnyView cond) { + if (auto opt_int = cond.try_cast()) { return opt_int.value(); } NDArray arr = cond.cast(); @@ -459,14 +460,14 @@ bool ReadIfCond(AnyView cond) { return result != 0; } -TVM_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.read_if_cond").set_body_typed(ReadIfCond); //------------------------------------- // Debugging API //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") - .set_body_packed([](ffi::PackedArgs args, Any* rv) -> void { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) -> void { ICHECK_GE(args.size(), 3); int num_args = args.size() - 3; ObjectRef io_effect = args[0].cast(); @@ -477,7 +478,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") << "Use the decorator `@tvm.register_func(\"" << debug_func_name << "\")` to register it."; String line_info = args[2].cast(); - std::vector call_args(num_args + 1); + std::vector call_args(num_args + 1); { call_args[0] = line_info; for (int i = 0; i < num_args; ++i) { @@ -491,23 +492,24 @@ TVM_REGISTER_GLOBAL("vm.builtin.invoke_debug_func") //------------------------------------- // Data structure API //------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.tuple_getitem") - .set_body_typed([](runtime::Array arr, int64_t index) { return arr[index]; }); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.tuple_getitem") + .set_body_typed([](Array arr, int64_t index) { return arr[index]; }); -TVM_REGISTER_GLOBAL("vm.builtin.tuple_reset_item") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.tuple_reset_item") .set_body_typed([](const ffi::ArrayObj* arr, int64_t index) { const_cast(arr)->SetItem(index, nullptr); }); -TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body_packed([](ffi::PackedArgs args, Any* rv) { - runtime::Array arr; - for (int i = 0; i < args.size(); ++i) { - arr.push_back(args[i]); - } - *rv = arr; -}); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.make_tuple") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + Array arr; + for (int i = 0; i < args.size(); ++i) { + arr.push_back(args[i]); + } + *rv = arr; + }); -TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) { NDArray arr = data; if (data->device.device_type != kDLCPU) { arr = data.CopyTo(DLDevice{kDLCPU, 0}); @@ -538,10 +540,10 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data } out_shape.push_back(result); } - return ShapeTuple(out_shape); + return ffi::Shape(out_shape); }); -TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) { if (data->byte_offset == 0) { return data; } @@ -602,26 +604,26 @@ TVM_DLL int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMF int TVMBackendAnyListSetPackedArg(void* anylist, int index, TVMFFIAny* args, int arg_offset) { using namespace tvm::runtime; - API_BEGIN(); + TVM_FFI_SAFE_CALL_BEGIN(); auto* list = static_cast(anylist); args[arg_offset] = list[index]; - API_END(); + TVM_FFI_SAFE_CALL_END(); } int TVMBackendAnyListResetItem(void* anylist, int index) { using namespace tvm::runtime; - API_BEGIN(); - auto* list = static_cast(anylist); + TVM_FFI_SAFE_CALL_BEGIN(); + auto* list = static_cast(anylist); list[index] = nullptr; - API_END(); + TVM_FFI_SAFE_CALL_END(); } int TVMBackendAnyListMoveFromPackedReturn(void* anylist, int index, TVMFFIAny* args, int ret_offset) { using namespace tvm::runtime; - API_BEGIN(); - auto* list = static_cast(anylist); + TVM_FFI_SAFE_CALL_BEGIN(); + auto* list = static_cast(anylist); list[index] = tvm::ffi::details::AnyUnsafe::MoveTVMFFIAnyToAny(std::move(args[ret_offset])); - API_END(); + TVM_FFI_SAFE_CALL_END(); } } // extern "C" diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc index 0c0c8eda493c..d3484cbc7b3e 100644 --- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc @@ -22,9 +22,8 @@ * \brief The CUDA graph related builtin functions for Relax virtual machine. */ -#include -#include -#include +#include +#include #include #include "../../../support/utils.h" @@ -41,9 +40,9 @@ struct CUDAGraphCaptureKey { // The symbolic variables the capture function depends on. When the capture function is ran with // different symbolic variable values, the CUDA graph will be re-captured as a different version, // identified by this shape tuple. This is default constructed as an empty tuple. - ShapeTuple shape_expr; + ffi::Shape shape_expr; - CUDAGraphCaptureKey(int64_t index, const Optional& shape_expr) : index(index) { + CUDAGraphCaptureKey(int64_t index, const Optional& shape_expr) : index(index) { if (shape_expr) { this->shape_expr = shape_expr.value(); } @@ -150,7 +149,7 @@ class CUDAGraphExtensionNode : public VMExtensionNode { * \return The return value of the capture function. */ ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args, - int64_t entry_index, Optional shape_expr) { + int64_t entry_index, Optional shape_expr) { CUDAGraphCaptureKey entry_key{entry_index, shape_expr}; if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) { // Launch CUDA graph @@ -241,7 +240,7 @@ class CUDAGraphExtension : public VMExtension { } }; -TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK(args.size() == 5 || args.size() == 4); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); @@ -249,14 +248,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") auto capture_func = args[1].cast(); auto func_args = args[2].cast(); int64_t entry_index = args[3].cast(); - Optional shape_expr = NullOpt; + Optional shape_expr = std::nullopt; if (args.size() == 5) { - shape_expr = args[4].cast(); + shape_expr = args[4].cast(); } *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr); }); -TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { ICHECK_EQ(args.size(), 3); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index 07d37bb4d3ea..52a0588be35c 100644 --- a/src/runtime/relax_vm/executable.cc +++ b/src/runtime/relax_vm/executable.cc @@ -38,16 +38,6 @@ namespace relax_vm { /*! \brief The magic number for the serialized VM bytecode file */ constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; -/*! \brief Possible types in the constant pool */ -enum ConstantType : int { - kNDArray = 0, - kDLDataType = 1, - kShapeTuple = 2, - kString = 3, - kInt = 4, - kFloat = 5, -}; - #define STREAM_CHECK(val, section) \ ICHECK(val) << "Invalid VM file format in the " << section << " section." \ << "\n"; @@ -75,8 +65,8 @@ std::string VMExecutable::Stats() const { } oss.seekp(-2, oss.cur); oss << "], "; - } else if (auto opt_shape = it.as()) { - ShapeTuple shape = opt_shape.value(); + } else if (auto opt_shape = it.as()) { + ffi::Shape shape = opt_shape.value(); oss << "shapetuple["; for (size_t i = 0; i < shape.size(); ++i) { oss << shape.at(i) << ", "; @@ -220,7 +210,7 @@ Module VMExecutable::LoadFromBinary(void* stream) { return Module(exec); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.VMExecutable") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_relax.VMExecutable") .set_body_typed(VMExecutable::LoadFromBinary); Module VMExecutable::LoadFromFile(const String& file_name) { @@ -231,7 +221,7 @@ Module VMExecutable::LoadFromFile(const String& file_name) { return VMExecutable::LoadFromBinary(reinterpret_cast(strm)); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.VMExecutable") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_relax.VMExecutable") .set_body_typed(VMExecutable::LoadFromFile); void VMFuncInfo::Save(dmlc::Stream* strm) const { @@ -264,30 +254,30 @@ void VMExecutable::SaveConstantSection(dmlc::Stream* strm) { strm->Write(static_cast(this->constants.size())); for (const auto& it : this->constants) { if (auto opt_nd = it.as()) { - strm->Write(ConstantType::kNDArray); + strm->Write(ffi::TypeIndex::kTVMFFINDArray); runtime::SaveDLTensor(strm, opt_nd.value().operator->()); - } else if (auto opt_shape = it.as()) { - ShapeTuple shape = opt_shape.value(); - strm->Write(ConstantType::kShapeTuple); + } else if (auto opt_shape = it.as()) { + ffi::Shape shape = opt_shape.value(); + strm->Write(ffi::TypeIndex::kTVMFFIShape); strm->Write(shape.size()); for (size_t i = 0; i < shape.size(); ++i) { strm->Write(shape.at(i)); } } else if (auto opt_str = it.as()) { String str = opt_str.value(); - strm->Write(ConstantType::kString); + strm->Write(ffi::TypeIndex::kTVMFFIStr); strm->Write(str.size()); for (size_t i = 0; i < str.size(); ++i) { strm->Write(str.at(i)); } } else if (auto opt_int = it.as()) { - strm->Write(ConstantType::kInt); + strm->Write(ffi::TypeIndex::kTVMFFIInt); strm->Write(opt_int.value()); } else if (auto opt_float = it.as()) { - strm->Write(ConstantType::kFloat); + strm->Write(ffi::TypeIndex::kTVMFFIFloat); strm->Write(opt_float.value()); } else if (auto opt_dtype = it.as()) { - strm->Write(ConstantType::kDLDataType); + strm->Write(ffi::TypeIndex::kTVMFFIDataType); strm->Write(opt_dtype.value()); } else { LOG(FATAL) << "Unsupported constant pool type " << it.GetTypeKey(); @@ -320,27 +310,27 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { for (size_t i = 0; i < size; i++) { int constant_type; STREAM_CHECK(strm->Read(&constant_type, sizeof(constant_type)), "constant"); - if (constant_type == ConstantType::kNDArray) { + if (constant_type == ffi::TypeIndex::kTVMFFINDArray) { ndarray.Load(strm); ffi::Any cell; cell = ndarray; this->constants.push_back(cell); - } else if (constant_type == ConstantType::kShapeTuple) { + } else if (constant_type == ffi::TypeIndex::kTVMFFIShape) { uint64_t size; strm->Read(&size); - std::vector data(size); + std::vector data(size); for (size_t i = 0; i < size; ++i) { strm->Read(&(data[i])); } ffi::Any cell; - cell = ShapeTuple(data); + cell = ffi::Shape(data); this->constants.push_back(cell); - } else if (constant_type == ConstantType::kDLDataType) { + } else if (constant_type == ffi::TypeIndex::kTVMFFIDataType) { strm->Read(&dtype); ffi::Any cell; cell = dtype; this->constants.push_back(cell); - } else if (constant_type == ConstantType::kString) { + } else if (constant_type == ffi::TypeIndex::kTVMFFIStr) { uint64_t size; strm->Read(&size); std::vector data(size); @@ -350,13 +340,13 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { ffi::Any cell; cell = String(std::string(data.begin(), data.end())); this->constants.push_back(cell); - } else if (constant_type == ConstantType::kInt) { + } else if (constant_type == ffi::TypeIndex::kTVMFFIInt) { int64_t value; strm->Read(&value); ffi::Any cell; cell = value; this->constants.push_back(cell); - } else if (constant_type == ConstantType::kFloat) { + } else if (constant_type == ffi::TypeIndex::kTVMFFIFloat) { double value; strm->Read(&value); ffi::Any cell; @@ -364,7 +354,7 @@ void VMExecutable::LoadConstantSection(dmlc::Stream* strm) { this->constants.push_back(cell); } else { LOG(FATAL) << "Constant pool can only contain NDArray and DLDataType, but got " - << ArgTypeCode2Str(constant_type) << " when loading the VM constant pool."; + << ffi::TypeIndexToTypeKey(constant_type) << " when loading the VM constant pool."; } } } @@ -567,7 +557,7 @@ String VMExecutable::AsPython() const { return String(os.str()); } -TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(VMExecutable::LoadFromFile); +TVM_FFI_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(VMExecutable::LoadFromFile); } // namespace relax_vm } // namespace runtime diff --git a/src/runtime/relax_vm/hexagon/builtin.cc b/src/runtime/relax_vm/hexagon/builtin.cc index 3cfa4db71744..89f1708b28ed 100644 --- a/src/runtime/relax_vm/hexagon/builtin.cc +++ b/src/runtime/relax_vm/hexagon/builtin.cc @@ -22,9 +22,8 @@ * \brief The hexagon graph related builtin functions for Relax virtual machine. */ +#include #include -#include -#include #include #include "../../hexagon/hexagon_device_api.h" @@ -32,7 +31,7 @@ namespace tvm { namespace runtime { namespace relax_vm { -TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") .set_body_typed([](ffi::AnyView vm_ptr, NDArray src_arr, NDArray dst_arr, int queue_id, bool bypass_cache) { const DLTensor* dptr = dst_arr.operator->(); @@ -54,7 +53,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") CHECK(ret == DMA_SUCCESS); }); -TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_wait") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.hexagon.dma_wait") .set_body_typed([](ffi::AnyView vm_ptr, int queue_id, int inflight_dma, bool bypass_cache, [[maybe_unused]] NDArray src_arr, [[maybe_unused]] NDArray dst_arr) { ICHECK(inflight_dma >= 0); diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index 1af7cf78c944..12f52c0794e9 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -31,72 +31,74 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj); TVM_REGISTER_OBJECT_TYPE(RNNStateObj); // KV State base methods -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method(&KVStateObj::Clear); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence").set_body_method(&KVStateObj::AddSequence); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method(&KVStateObj::Clear); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence") + .set_body_method(&KVStateObj::AddSequence); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence") .set_body_method(&KVStateObj::RemoveSequence); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence").set_body_method(&KVStateObj::ForkSequence); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method(&KVStateObj::PopN); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence") + .set_body_method(&KVStateObj::ForkSequence); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method(&KVStateObj::PopN); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 3 || args.size() == 4) << "KVState BeginForward only accepts 3 or 4 arguments"; KVState kv_state = args[0].cast(); - IntTuple seq_ids = args[1].cast(); - IntTuple append_lengths = args[2].cast(); - Optional token_tree_parent_ptr; + ffi::Shape seq_ids = args[1].cast(); + ffi::Shape append_lengths = args[2].cast(); + Optional token_tree_parent_ptr; if (args.size() == 4) { - token_tree_parent_ptr = args[3].cast>(); + token_tree_parent_ptr = args[3].cast>(); } kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr); }); -TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward").set_body_method(&KVStateObj::EndForward); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward").set_body_method(&KVStateObj::EndForward); // Attention KV Cache methods -TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv") .set_body_method(&AttentionKVCacheObj::DisaggPrepareRecv); -TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send") .set_body_method(&AttentionKVCacheObj::DisaggMarkSend); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") .set_body_method(&AttentionKVCacheObj::EnableSlidingWindowForSeq); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes") .set_body_method(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty") .set_body_method(&AttentionKVCacheObj::Empty); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages") .set_body_method(&AttentionKVCacheObj::GetNumAvailablePages); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length") .set_body_method(&AttentionKVCacheObj::GetTotalSequenceLength); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions") .set_body_method(&AttentionKVCacheObj::GetQueryPositions); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv") .set_body_method(&AttentionKVCacheObj::DebugGetKV); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla") .set_body_method(&AttentionKVCacheObj::DebugGetKVMLA); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray qkv_data, NDArray o_data) { - kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data), - sm_scale); + kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), std::nullopt, + std::move(o_data), sm_scale); }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_self_attention") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_self_attention") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, NDArray k_data, NDArray v_data, NDArray o_data, NDArray lse_data) { kv_cache->SelfAttention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), std::move(o_data), std::move(lse_data), sm_scale); }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_cross_attention") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_cross_attention") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, double sm_scale, NDArray q_data, NDArray o_data, NDArray lse_data) { kv_cache->CrossAttention(layer_id, std::move(q_data), std::move(o_data), std::move(lse_data), sm_scale); }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append_mla_kv") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append_mla_kv") .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id, NDArray kv_data) { kv_cache->AppendMLAKV(layer_id, std::move(kv_data)); return kv_cache; }); -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") .set_body_typed([](AttentionKVCache kv_cache, NDArray o_self_attn, NDArray lse_self_attn, NDArray o_cross_attn, NDArray lse_cross_attn) { return kv_cache->MergeAttnOutputInplace(std::move(o_self_attn), std::move(lse_self_attn), @@ -104,13 +106,13 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_merge_attn_output_inplace") }); // RNN State methods -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method(&RNNStateObj::Get); -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method(&RNNStateObj::Get); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_set") .set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) { state->Set(layer_id, state_id, data); return state; }); -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get").set_body_method(&RNNStateObj::DebugGet); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get").set_body_method(&RNNStateObj::DebugGet); } // namespace relax_vm } // namespace runtime diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 493fceaae9a2..5800c4e2db93 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -18,13 +18,13 @@ */ #ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_ #define TVM_RUNTIME_RELAX_VM_KV_STATE_H_ -#include +#include +#include #include +#include #include #include -#include - -#include "tvm/runtime/object.h" +#include namespace tvm { namespace runtime { @@ -95,7 +95,7 @@ class KVStateObj : public Object { * is a chain. */ virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, - const Optional& token_tree_parent_ptr = NullOpt) = 0; + const Optional& token_tree_parent_ptr = std::nullopt) = 0; /*! * \brief Mark the start of the forward function. diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc index 44079b48d1c5..8abeddcf18dc 100644 --- a/src/runtime/relax_vm/lm_support.cc +++ b/src/runtime/relax_vm/lm_support.cc @@ -35,11 +35,11 @@ * * We can evolve this implementation as we build more LM verticals. */ -#include -#include +#include +#include +#include #include #include -#include #include #include #include @@ -81,7 +81,7 @@ class AttentionKVCacheLegacyObj : public Object { * \brief View all current cached values as one array. * \param shape The cached values. */ - NDArray View(const ShapeTuple& shape) { + NDArray View(const ffi::Shape& shape) { CHECK_EQ(shape[0], fill_count) << "Requested shape do not match the filled count"; for (int i = 1; i < this->data->ndim; ++i) { CHECK_EQ(shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; @@ -237,7 +237,7 @@ class AttentionKVCacheLegacy : public ObjectRef { * \brief Create the attention kv cache. * \param init_data The initial reserved. */ - static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple reserve_shape, + static AttentionKVCacheLegacy Create(NDArray init_data, ffi::Shape reserve_shape, int init_fill_count) { auto n = make_object(); n->data = NDArray::Empty(reserve_shape, init_data->dtype, init_data->device); @@ -259,7 +259,7 @@ TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj); //------------------------------------------------- // Register runtime functions //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create") .set_body_typed(AttentionKVCacheLegacy::Create); AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, NDArray value) { @@ -267,14 +267,16 @@ AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, NDAr return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update").set_body_typed(AttentionKVCacheUpdate); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_update") + .set_body_typed(AttentionKVCacheUpdate); AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache, NDArray value) { cache->Append(value); return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append").set_body_typed(AttentionKVCacheAppend); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append") + .set_body_typed(AttentionKVCacheAppend); AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cache, NDArray value, int64_t max_cache_size) { @@ -282,7 +284,7 @@ AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cac return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override") .set_body_typed(AttentionKVCacheWindowOverride); AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheLegacy cache, @@ -293,29 +295,29 @@ AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheL return cache; } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override_with_sinks") .set_body_typed(AttentionKVCacheWindowOverrideWithSinks); -NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ShapeTuple shape) { +NDArray AttentionKVCacheView(AttentionKVCacheLegacy cache, ffi::Shape shape) { return cache->View(shape); } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_view") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 1 || args.size() == 2) << "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got " << args.size() << "."; AttentionKVCacheLegacy cache = args[0].cast(); if (args.size() == 2) { - ShapeTuple shape = args[1].cast(); + ffi::Shape shape = args[1].cast(); *rv = cache->View(shape); } else { - std::vector shape; + std::vector shape; shape.push_back(cache->fill_count); for (int i = 1; i < cache->data->ndim; ++i) { shape.push_back(cache->data->shape[i]); } - *rv = cache->View(ShapeTuple(shape)); + *rv = cache->View(ffi::Shape(shape)); } }); @@ -325,7 +327,7 @@ void AttentionKVCacheArrayPopN(Array caches, int64_t n) } } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_popn") .set_body_typed(AttentionKVCacheArrayPopN); void AttentionKVCacheArrayClear(Array caches) { @@ -334,7 +336,7 @@ void AttentionKVCacheArrayClear(Array caches) { } } -TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_array_clear") .set_body_typed(AttentionKVCacheArrayClear); // NOTE this is a built-in highly related to LM so we put it here. @@ -399,7 +401,7 @@ int SampleTopPFromLogits(NDArray logits, double temperature, double top_p, doubl return data[0].second; } -TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_logits").set_body_typed(SampleTopPFromLogits); int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { ICHECK(prob.IsContiguous()); @@ -494,7 +496,7 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) { return sampled_index; } -TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb); NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { ICHECK(prob.IsContiguous()); @@ -531,7 +533,8 @@ NDArray MultinomialFromUniform(NDArray prob, NDArray uniform_sample) { return new_array; } -TVM_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform").set_body_typed(MultinomialFromUniform); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.multinomial_from_uniform") + .set_body_typed(MultinomialFromUniform); // This is an inplace operation. void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { @@ -554,7 +557,8 @@ void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) { } } -TVM_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty").set_body_typed(ApplyRepetitionPenalty); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty") + .set_body_typed(ApplyRepetitionPenalty); /*! * \brief Apply presence and frequency penalty. This is an inplace operation. @@ -589,7 +593,7 @@ void ApplyPresenceAndFrequencyPenalty(NDArray logits, NDArray token_ids, NDArray } } -TVM_REGISTER_GLOBAL("vm.builtin.apply_presence_and_frequency_penalty") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_presence_and_frequency_penalty") .set_body_typed(ApplyPresenceAndFrequencyPenalty); // This is an inplace operation. @@ -614,7 +618,7 @@ void ApplySoftmaxWithTemperature(NDArray logits, double temperature) { } } -TVM_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature") .set_body_typed(ApplySoftmaxWithTemperature); } // namespace relax_vm diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc index fc6eb6bf360f..7341507e9a98 100644 --- a/src/runtime/relax_vm/ndarray_cache_support.cc +++ b/src/runtime/relax_vm/ndarray_cache_support.cc @@ -39,8 +39,8 @@ #define __STDC_FORMAT_MACROS #endif #include +#include #include -#include #include #include @@ -65,7 +65,7 @@ inline ValueType GetValue(const picojson::object& json, const std::string& key) } NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson::object& json) { - std::vector shape; + std::vector shape; { picojson::array shape_json = GetValue(json, "shape"); shape.reserve(shape_json.size()); @@ -80,7 +80,7 @@ NDArrayCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const picojson:: result.format = GetValue(json, "format"); result.nbytes = GetValue(json, "nbytes"); result.byte_offset = GetValue(json, "byteOffset"); - result.shape = ShapeTuple(std::move(shape)); + result.shape = ffi::Shape(std::move(shape)); return result; } @@ -153,7 +153,7 @@ void CopyNDArrayFromBytes(NDArray param, const void* data, size_t nbytes, if (staging_buffer->defined()) { size_t curr_size = runtime::GetDataSize(*(staging_buffer->value().operator->())); if (curr_size < nbytes) { - *staging_buffer = NullOpt; + *staging_buffer = std::nullopt; } } if (!staging_buffer->defined()) { @@ -162,7 +162,7 @@ void CopyNDArrayFromBytes(NDArray param, const void* data, size_t nbytes, NDArray staging_view = staging_buffer->value().CreateView(param.Shape(), param->dtype); staging_view.CopyFromBytes(data, nbytes); param.CopyFrom(staging_view); - TVMSynchronize(device.device_type, device.device_id, nullptr); + DeviceAPI::Get(device)->StreamSync(device, nullptr); } NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load( @@ -225,7 +225,7 @@ class NDArrayCache { if (it != pool->pool_.end()) { return (*it).second; } else { - return NullOpt; + return std::nullopt; } } @@ -266,9 +266,9 @@ class NDArrayCache { Map pool_; }; -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update") - .set_body_packed([](ffi::PackedArgs args, Any* rv) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CHECK(args.size() == 2 || args.size() == 3); String name = args[0].cast(); bool is_override = args.size() == 2 ? false : args[2].cast(); @@ -285,14 +285,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update") } arr = NDArray::Empty(shape, tensor->dtype, tensor->device); arr.CopyFrom(tensor); - TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); + DeviceAPI::Get(arr->device)->StreamSync(arr->device, nullptr); } NDArrayCache::Update(name, arr, is_override); }); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load); // This param module node can be useful to get param dict in RPC mode // when the remote already have loaded parameters from file. @@ -353,18 +353,20 @@ class ParamModuleNode : public runtime::ModuleNode { Array params_; }; -TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache").set_body_typed(ParamModuleNode::Create); -TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_module_from_cache") + .set_body_typed(ParamModuleNode::Create); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name") .set_body_typed(ParamModuleNode::CreateByName); -TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams); -TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache") + .set_body_typed(ParamModuleNode::GetParams); +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name") .set_body_typed(ParamModuleNode::GetParamByName); -TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked") - .set_body_packed([](ffi::PackedArgs args, Any* rv) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { Array names; names.reserve(args.size()); for (int i = 0; i < args.size(); ++i) { - if (!args[i].as()) { + if (!args[i].try_cast()) { LOG(FATAL) << "ValueError: Expect string as input, but get " << args[i].GetTypeKey() << " at " << i; } diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index c9fc851ea772..2f21e6978a68 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -20,12 +20,12 @@ * \file src/runtime/relax_vm/paged_kv_cache.cc * \brief Runtime paged KV cache object for language models. */ +#include #include #include #include #include #include -#include #include #include @@ -155,9 +155,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief The batch size of the current round of forwarding. */ int64_t cur_batch_size_; /*! \brief The ids of the sequences in the current round of forwarding. */ - IntTuple cur_seq_ids_; + ffi::Shape cur_seq_ids_; /*! \brief The append lengths of the sequences in the current round of forwarding. */ - IntTuple cur_append_lengths_; + ffi::Shape cur_append_lengths_; /*! \brief Whether the current batch of sequences are token chains (not token trees). */ std::vector is_chain_on_depths_; /*! \brief Number of fork depth in the current round of forward. */ @@ -244,7 +244,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { Optional f_transpose_append_mha_; Optional f_transpose_append_mla_; Optional f_transfer_kv_; - Optional f_transfer_kv_page_to_page_ = NullOpt; + Optional f_transfer_kv_page_to_page_ = std::nullopt; ffi::Function f_compact_copy_; std::unique_ptr f_attention_prefill_ragged_; std::unique_ptr f_attention_prefill_; @@ -343,7 +343,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK(f_nvshmem_empty.has_value()); nvshmem_pages_ = (*f_nvshmem_empty)( - ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}), + ffi::Shape({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}), dtype, device) .cast(); for (int i = 0; i < num_layers; ++i) { @@ -362,7 +362,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr; } else { for (int i = 0; i < num_layers; ++i) { - ShapeTuple kv_cache_shape = + ffi::Shape kv_cache_shape = GetKVCacheShape(attn_kinds_[layer_id_begin_offset_ + i], num_total_pages, reserved_num_seqs, num_kv_heads, page_size, qk_head_dim, v_head_dim); pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device)); @@ -821,8 +821,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /************** Attention **************/ - void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, - const Optional& opt_token_tree_parent_ptr) final { + void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths, + const Optional& opt_token_tree_parent_ptr) final { // Note: MLA does not supported tree attention for now. if (attn_kinds_[0] == AttnKind::kMLA) { CHECK(!opt_token_tree_parent_ptr.defined()) << "Tree attention is not supported yet for MLA"; @@ -1101,11 +1101,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - IntTuple DisaggPrepareRecv(int64_t seq_id, int append_length) final { + ffi::Shape DisaggPrepareRecv(int64_t seq_id, int append_length) final { // No CPU to GPU copy is needed. // Essentially we // (step 1.) redirect the preparation to BeginForward. - BeginForward({seq_id}, {append_length}, /*opt_token_tree_parent_ptr=*/NullOpt); + BeginForward({seq_id}, {append_length}, /*opt_token_tree_parent_ptr=*/std::nullopt); // (step 2.) fetch the append_position_map, compress and return. // Compression format: [n, begin_1, length_1, begin_2, length_2, ..., begin_n, length_n] // The compressed format will be decompressed to: @@ -1128,11 +1128,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { compressed_append_pos_map.back() + 1); // The compressed array size should be "num_segments * 2 + 1". CHECK_EQ(compressed_append_pos_map.size(), compressed_append_pos_map[0] * 2 + 1); - return IntTuple{compressed_append_pos_map}; + return ffi::Shape{compressed_append_pos_map}; } - void DisaggMarkSend(int64_t seq_id, int64_t begin, const IntTuple& compressed_remote_position_map, - int32_t recver_pe_offset) { + void DisaggMarkSend(int64_t seq_id, int64_t begin, + const ffi::Shape& compressed_remote_position_map, int32_t recver_pe_offset) { ICHECK(f_transfer_kv_.defined()); auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; @@ -1404,7 +1404,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Todo(ruihang): implement it } - void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& leaf_indices) final { + void CommitAcceptedTokenTreeNodes(const ffi::Shape& seq_ids, + const ffi::Shape& leaf_indices) final { CHECK_EQ(seq_ids.size(), leaf_indices.size()) << "The given seq_ids and leaf_indices have different size."; int num_seq_to_commit = seq_ids.size(); @@ -1631,7 +1632,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } void ConstructTokenTreeMask(const std::vector& sequences, - const IntTuple& token_tree_parent_ptr, + const ffi::Shape& token_tree_parent_ptr, const std::vector>& block_ids_on_depths, const std::vector>& trailing_blocks) { // Check whether the token tree of a sequence should be handled at the current depth. @@ -2283,13 +2284,13 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj); // Register runtime functions //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") - .set_body_packed([](ffi::PackedArgs args, Any* rv) { +TVM_FFI_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // Todo: cuda graph arg CHECK(args.size() == 28 || args.size() == 29) << "Invalid number of KV cache constructor args: " << args.size(); - ShapeTuple cache_config = args[0].cast(); - ShapeTuple layer_indptr_tuple = args[1].cast(); + ffi::Shape cache_config = args[0].cast(); + ffi::Shape layer_indptr_tuple = args[1].cast(); int num_groups = 1; int group_id = 0; if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) { @@ -2305,15 +2306,15 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") int64_t num_kv_heads = args[3].cast(); int64_t qk_head_dim = args[4].cast(); int64_t v_head_dim = args[5].cast(); - IntTuple attn_kinds = args[6].cast(); + ffi::Shape attn_kinds = args[6].cast(); bool enable_kv_transfer = args[7].cast(); int rope_mode = args[8].cast(); double rotary_scale = args[9].cast(); double rotary_theta = args[10].cast(); - Optional rope_ext_factors = NullOpt; // args[11] + Optional rope_ext_factors = std::nullopt; // args[11] NDArray init = args[12].cast(); - Optional f_transpose_append_mha = NullOpt; // args[13] - Optional f_transpose_append_mla = NullOpt; // args[14] + Optional f_transpose_append_mha = std::nullopt; // args[13] + Optional f_transpose_append_mla = std::nullopt; // args[14] std::unique_ptr f_attention_prefill_ragged = ConvertRaggedPrefillFunc(args[15].cast>(), AttnKind::kMHA); std::unique_ptr f_attention_prefill = @@ -2343,7 +2344,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create") if (auto opt_func = args[arg_idx].as()) { return opt_func.value(); } - return NullOpt; + return std::nullopt; }; f_transpose_append_mha = f_convert_optional_packed_func(13); f_transpose_append_mla = f_convert_optional_packed_func(14); diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc index 9468a50d2071..d431fdb2ae2f 100644 --- a/src/runtime/relax_vm/rnn_state.cc +++ b/src/runtime/relax_vm/rnn_state.cc @@ -103,9 +103,9 @@ class RNNStateImpObj : public RNNStateObj { /*! \brief The batch size of the current round of forwarding. */ int64_t cur_batch_size_; /*! \brief The append lengths of the sequences in the current round of forwarding. */ - IntTuple cur_append_lengths_; + ffi::Shape cur_append_lengths_; /*! \brief The sequence ids of the current round of forwarding. */ - IntTuple cur_seq_ids_; + ffi::Shape cur_seq_ids_; /**************** Auxiliary Arrays on Device *****************/ @@ -173,8 +173,8 @@ class RNNStateImpObj : public RNNStateObj { Array layer_storages; layer_storages.reserve(num_states_per_layer_); for (int64_t state_id = 0; state_id < num_states_per_layer_; ++state_id) { - ShapeTuple state_shape = init_layer_value[state_id].Shape(); - std::vector storage_shape = {reserved_num_seqs, max_history}; + ffi::Shape state_shape = init_layer_value[state_id].Shape(); + std::vector storage_shape = {reserved_num_seqs, max_history}; storage_shape.insert(storage_shape.end(), state_shape.begin(), state_shape.end()); NDArray state_storage = NDArray::Empty(storage_shape, init_layer_value[state_id].DataType(), device); @@ -205,14 +205,14 @@ class RNNStateImpObj : public RNNStateObj { /************** Interaction **************/ - void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, - const Optional& opt_token_tree_parent_ptr) final { + void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths, + const Optional& opt_token_tree_parent_ptr) final { CHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; if (opt_token_tree_parent_ptr.defined()) { - IntTuple token_tree_parent_ptr = opt_token_tree_parent_ptr.value(); + ffi::Shape token_tree_parent_ptr = opt_token_tree_parent_ptr.value(); int matched_pos = 0; for (int64_t append_length : append_lengths) { for (int64_t i = 0; i < append_length; ++i) { @@ -464,7 +464,7 @@ TVM_REGISTER_OBJECT_TYPE(RNNStateImpObj); // Register runtime functions //------------------------------------------------- -TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_create") +TVM_FFI_REGISTER_GLOBAL("vm.builtin.rnn_state_create") .set_body_typed([](int64_t num_layers, // int64_t reserved_num_seqs, // int64_t max_history, // diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index c92509923b7e..8d0b928f8592 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -21,9 +21,9 @@ * \file src/runtime/relax_vm/vm.cc */ #include +#include #include #include -#include #include #include @@ -52,9 +52,9 @@ VMClosure::VMClosure(String func_name, ffi::Function impl) { * \param last_args The arguments to bound to in the end of the function. * \note The new function takes in arguments and append the last_args in the end. */ -ffi::Function VMClosure::BindLastArgs(ffi::Function func, std::vector last_args) { +ffi::Function VMClosure::BindLastArgs(ffi::Function func, std::vector last_args) { return ffi::Function([func, last_args](ffi::PackedArgs args, ffi::Any* rv) { - std::vector packed_args(args.size() + last_args.size()); + std::vector packed_args(args.size() + last_args.size()); std::copy(args.data(), args.data() + args.size(), packed_args.data()); for (size_t i = 0; i < last_args.size(); ++i) { packed_args[args.size() + i] = last_args[i]; @@ -68,14 +68,14 @@ ffi::Function VMClosure::BindLastArgs(ffi::Function func, std::vector last_ //----------------------------------------------------------- // Use the args after `starting_arg_idx` as a series of indices into `obj`, // indexing into nested Array and returning the final indexed object. -Any IndexIntoNestedObject(Any obj, ffi::PackedArgs args, int starting_arg_idx) { +ffi::Any IndexIntoNestedObject(ffi::Any obj, ffi::PackedArgs args, int starting_arg_idx) { for (int i = starting_arg_idx; i < args.size(); i++) { // the object must be an Array to be able to index into it if (!obj.as()) { LOG(FATAL) << "ValueError: Attempted to index into an object that is not an Array."; } int index = args[i].cast(); - auto arr = Downcast>(obj); + auto arr = Downcast>(obj); // make sure the index is in bounds if (index >= static_cast(arr.size())) { LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << arr.size() << ")."; @@ -110,12 +110,12 @@ Any ConvertObjectToDevice(Any src, const Device& dev, Allocator* alloc) { } } -ffi::Any ConvertArgToDevice(AnyView input, Device dev, Allocator* alloc) { +ffi::Any ConvertArgToDevice(ffi::AnyView input, Device dev, Allocator* alloc) { // in terms of memory-behavior. // To be extra careful, we copy DLTensor. // The developer can still explicitly allocate NDArray // in TVM Native API or NDArray::FromDLPack to regain zero copy behavior. - Any ret; + ffi::Any ret; if (auto opt_obj = input.as()) { ret = ConvertObjectToDevice(opt_obj.value(), dev, alloc); } else if (auto opt_dltensor = input.as()) { @@ -131,7 +131,7 @@ ffi::Any ConvertArgToDevice(AnyView input, Device dev, Allocator* alloc) { } ffi::Any ConvertRegToDevice(ffi::Any input, Device dev, Allocator* alloc) { - Any ret; + ffi::Any ret; if (auto opt_obj = input.as()) { ret = ConvertObjectToDevice(opt_obj.value(), dev, alloc); } else { @@ -162,15 +162,8 @@ struct VMFrame { std::vector register_file; /*! \brief Register in caller's frame to put return value */ RegName caller_return_register; - // The following fields are used for ffi::Function call within - // a single function scope. The space is reused across multiple - // packed func calls to increase cache locality and avoid re-allocation - /*! \brief Temporary argument value stack for packed func call. */ - std::vector call_arg_values; /*! \brief Temporary argument tcode stack for packed func call. */ - std::vector call_arg_tcodes; - - std::vector call_args; + std::vector call_args; VMFrame(Index pc, Index register_file_size) : return_pc(pc), register_file(register_file_size), caller_return_register(0) {} @@ -541,7 +534,7 @@ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedf auto* clo = closure_or_packedfunc.as(); ICHECK(clo != nullptr) << "Function expects a closure or ffi::Function "; - std::vector packed_args(args.size() + 1); + std::vector packed_args(args.size() + 1); // per convention, ctx ptr must be VirtualMachine* casted to void. // this and VirtualMachine* may or maynot be the same // do first cast to VirtualMachine* then to void* @@ -561,7 +554,7 @@ RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_pa auto* clo = closure_or_packed.as(); int clo_offset = clo != nullptr ? 1 : 0; - std::vector packed_args(args.size() + clo_offset); + std::vector packed_args(args.size() + clo_offset); if (clo != nullptr) { packed_args[0] = static_cast(static_cast(this)); @@ -605,7 +598,7 @@ Optional VirtualMachineImpl::GetClosureInternal(const String& func_na } auto it = exec_->func_map.find(func_name); if (it == exec_->func_map.end()) { - if (allow_missing) return NullOpt; + if (allow_missing) return std::nullopt; LOG(FATAL) << "ValueError: Unknown function: " << func_name; } @@ -733,7 +726,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { // NOTE: no changes and resize to those vector ref(otherwise can leads to segfault) // in the remainder part of the function. - std::vector& call_args = curr_frame->call_args; + std::vector& call_args = curr_frame->call_args; for (Index i = 0; i < instr.num_args; ++i) { Instruction::Arg arg = instr.args[i]; @@ -775,7 +768,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { call_args[2] = true; call_args[3] = nullptr; - Any rv; + ffi::Any rv; // store dtype to str since py callback cannot handle dtype atm. std::vector> temp_dtype; for (int i = 0; i < instr.num_args; ++i) { @@ -788,7 +781,7 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { } int ret_kind = static_cast(VMInstrumentReturnKind::kNoOp); instrument_.CallPacked(call_args.data(), call_args.size(), &rv); - if (auto opt_int = rv.as()) { + if (auto opt_int = rv.try_cast()) { ret_kind = opt_int.value(); } if (ret_kind != static_cast(VMInstrumentReturnKind::kSkipRun)) { @@ -913,7 +906,7 @@ void VirtualMachineImpl::_SetInstrument(ffi::PackedArgs args, ffi::Any* rv) { void VirtualMachineImpl::_GetOutputArity(ffi::PackedArgs args, ffi::Any* rv) { std::string func_name = args[0].cast(); RegType out = LookupVMOutput(func_name); - Any obj = IndexIntoNestedObject(out, args, 1); + ffi::Any obj = IndexIntoNestedObject(out, args, 1); if (const auto* arr = obj.as()) { *rv = static_cast(arr->size()); } else { @@ -924,7 +917,7 @@ void VirtualMachineImpl::_GetOutputArity(ffi::PackedArgs args, ffi::Any* rv) { void VirtualMachineImpl::_GetOutput(ffi::PackedArgs args, ffi::Any* rv) { std::string func_name = args[0].cast(); RegType out = LookupVMOutput(func_name); - Any obj = IndexIntoNestedObject(out, args, 1); + ffi::Any obj = IndexIntoNestedObject(out, args, 1); if (obj.as()) { LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC compatibility. " "Please specify another index argument."; diff --git a/src/runtime/rocm/rocm_common.h b/src/runtime/rocm/rocm_common.h index b258e37508df..ec3e744d3034 100644 --- a/src/runtime/rocm/rocm_common.h +++ b/src/runtime/rocm/rocm_common.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 67991717552e..a5bc3b1a0da5 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -24,10 +24,10 @@ #include #include #include +#include #include #include #include -#include #include "rocm_common.h" @@ -251,15 +251,16 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.rocm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.rocm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { DeviceAPI* ptr = ROCMDeviceAPI::Global(); *rv = static_cast(ptr); }); -TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.rocm_host") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global(); + *rv = static_cast(ptr); + }); class ROCMTimerNode : public TimerNode { public: @@ -292,11 +293,11 @@ class ROCMTimerNode : public TimerNode { TVM_REGISTER_OBJECT_TYPE(ROCMTimerNode); -TVM_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](Device dev) { +TVM_FFI_REGISTER_GLOBAL("profiling.timer.rocm").set_body_typed([](Device dev) { return Timer(make_object()); }); -TVM_REGISTER_GLOBAL("runtime.get_rocm_stream").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("runtime.get_rocm_stream").set_body_typed([]() { return static_cast(ROCMThreadEntry::ThreadLocal()->stream); }); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 44c7483624e6..2d3ba16de247 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -23,7 +23,7 @@ #include "rocm_module.h" #include -#include +#include #include #include @@ -231,12 +231,12 @@ Module ROCMModuleLoadBinary(void* strm) { return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc index a9e897f1a3c9..50f7195a2224 100644 --- a/src/runtime/rpc/rpc_channel.cc +++ b/src/runtime/rpc/rpc_channel.cc @@ -22,6 +22,8 @@ */ #include "rpc_channel.h" +#include + #include namespace tvm { @@ -41,7 +43,7 @@ size_t CallbackChannel::Send(const void* data, size_t size) { size_t CallbackChannel::Recv(void* data, size_t size) { Any ret = frecv_(size); - auto opt_bytes = ret.as(); + auto opt_bytes = ret.try_cast(); CHECK(opt_bytes.has_value()) << "CallbackChannel::Recv"; ffi::Bytes bytes = std::move(opt_bytes.value()); diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h index 62af2d92a8ac..3c8f6b404cf4 100644 --- a/src/runtime/rpc/rpc_channel.h +++ b/src/runtime/rpc/rpc_channel.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_H_ #define TVM_RUNTIME_RPC_RPC_CHANNEL_H_ -#include +#include #include diff --git a/src/runtime/rpc/rpc_channel_logger.h b/src/runtime/rpc/rpc_channel_logger.h deleted file mode 100644 index 8fe68f669007..000000000000 --- a/src/runtime/rpc/rpc_channel_logger.h +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file rpc_channel_logger.h - * \brief A wrapper for RPCChannel with a NanoRPCListener for logging the commands. - */ -#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_ -#define TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_ - -#include - -#include -#include - -#include "../../support/ssize.h" -#include "../minrpc/minrpc_server_logging.h" -#include "rpc_channel.h" - -#define RX_BUFFER_SIZE 65536 - -namespace tvm { -namespace runtime { - -class Buffer { - public: - Buffer(uint8_t* data, size_t data_size_bytes) - : data_{data}, capacity_{data_size_bytes}, num_valid_bytes_{0}, read_cursor_{0} {} - - size_t Write(const uint8_t* data, size_t data_size_bytes) { - size_t num_bytes_available = capacity_ - num_valid_bytes_; - size_t num_bytes_to_copy = data_size_bytes; - if (num_bytes_available < num_bytes_to_copy) { - num_bytes_to_copy = num_bytes_available; - } - - memcpy(&data_[num_valid_bytes_], data, num_bytes_to_copy); - num_valid_bytes_ += num_bytes_to_copy; - return num_bytes_to_copy; - } - - size_t Read(uint8_t* data, size_t data_size_bytes) { - size_t num_bytes_to_copy = data_size_bytes; - size_t num_bytes_available = num_valid_bytes_ - read_cursor_; - if (num_bytes_available < num_bytes_to_copy) { - num_bytes_to_copy = num_bytes_available; - } - - memcpy(data, &data_[read_cursor_], num_bytes_to_copy); - read_cursor_ += num_bytes_to_copy; - return num_bytes_to_copy; - } - - void Clear() { - num_valid_bytes_ = 0; - read_cursor_ = 0; - } - - size_t Size() const { return num_valid_bytes_; } - - private: - /*! \brief pointer to data buffer. */ - uint8_t* data_; - - /*! \brief The total number of bytes available in data_.*/ - size_t capacity_; - - /*! \brief number of valid bytes in the buffer. */ - size_t num_valid_bytes_; - - /*! \brief Read cursor position. */ - size_t read_cursor_; -}; - -/*! - * \brief A simple IO handler for MinRPCSniffer. - * - * \tparam Buffer* buffer to store received data. - */ -class SnifferIOHandler { - public: - explicit SnifferIOHandler(Buffer* receive_buffer) : receive_buffer_(receive_buffer) {} - - void MessageStart(size_t message_size_bytes) {} - - ssize_t PosixWrite(const uint8_t* buf, size_t buf_size_bytes) { return 0; } - - void MessageDone() {} - - ssize_t PosixRead(uint8_t* buf, size_t buf_size_bytes) { - return receive_buffer_->Read(buf, buf_size_bytes); - } - - void Close() {} - - void Exit(int code) {} - - private: - Buffer* receive_buffer_; -}; - -/*! - * \brief A simple rpc session that logs the received commands. - */ -class NanoRPCListener { - public: - NanoRPCListener() - : receive_buffer_(receive_storage_, receive_storage_size_bytes_), - io_(&receive_buffer_), - rpc_server_(&io_) {} - - void Listen(const uint8_t* data, size_t size) { receive_buffer_.Write(data, size); } - - void ProcessTxPacket() { - rpc_server_.ProcessOnePacket(); - ClearBuffer(); - } - - void ProcessRxPacket() { - rpc_server_.ProcessOneResponse(); - ClearBuffer(); - } - - private: - void ClearBuffer() { receive_buffer_.Clear(); } - - private: - size_t receive_storage_size_bytes_ = RX_BUFFER_SIZE; - uint8_t receive_storage_[RX_BUFFER_SIZE]; - Buffer receive_buffer_; - SnifferIOHandler io_; - MinRPCSniffer rpc_server_; - - void HandleCompleteMessage() { rpc_server_.ProcessOnePacket(); } - - static void HandleCompleteMessageCb(void* context) { - static_cast(context)->HandleCompleteMessage(); - } -}; - -/*! - * \brief A wrapper for RPCChannel, that also logs the commands sent. - * - * \tparam std::unique_ptr&& underlying RPCChannel unique_ptr. - */ -class RPCChannelLogging : public RPCChannel { - public: - explicit RPCChannelLogging(std::unique_ptr&& next) { next_ = std::move(next); } - - size_t Send(const void* data, size_t size) { - listener_.ProcessRxPacket(); - listener_.Listen((const uint8_t*)data, size); - listener_.ProcessTxPacket(); - return next_->Send(data, size); - } - - size_t Recv(void* data, size_t size) { - size_t ret = next_->Recv(data, size); - listener_.Listen((const uint8_t*)data, size); - return ret; - } - - private: - std::unique_ptr next_; - NanoRPCListener listener_; -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_ diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 710965d07824..ffe031fadfb4 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -20,9 +20,9 @@ /*! * \file rpc_device_api.cc */ +#include #include #include -#include #include @@ -150,7 +150,7 @@ class RPCDeviceAPI final : public DeviceAPI { } }; -TVM_REGISTER_GLOBAL("device_api.rpc").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("device_api.rpc").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { static RPCDeviceAPI inst; DeviceAPI* ptr = &inst; *rv = static_cast(ptr); diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 23edfa9bb520..7ee721405619 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -23,10 +23,9 @@ */ #include "rpc_endpoint.h" -#include +#include +#include #include -#include -#include #include #include @@ -41,7 +40,6 @@ #include "../../support/arena.h" #include "../../support/ring_buffer.h" #include "../../support/utils.h" -#include "../object_internal.h" #include "rpc_local_session.h" namespace tvm { @@ -189,14 +187,14 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); } - uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, int num_args, - bool client_mode) { - return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes, num_args, client_mode, this); + uint64_t PackedSeqGetNumBytes(const ffi::AnyView* packed_args, int num_args, bool client_mode) { + return RPCReference::PackedSeqGetNumBytes(reinterpret_cast(packed_args), + num_args, client_mode, this); } - void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, - bool client_mode) { - RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode, this); + void SendPackedSeq(const ffi::AnyView* packed_args, int num_args, bool client_mode) { + RPCReference::SendPackedSeq(reinterpret_cast(packed_args), num_args, + client_mode, this); } // Endian aware IO handling @@ -228,7 +226,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // which is needed for wasm and other env that goes through C API if (obj->IsInstance()) { auto* ref = static_cast(obj); - this->template Write(kRuntimeRPCObjectRefTypeIndex); + this->template Write(runtime::TypeIndex::kRuntimeRPCObjectRef); uint64_t handle = reinterpret_cast(ref->object_handle()); this->template Write(handle); } else { @@ -246,7 +244,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { } } - void ReadObject(int* tcode, TVMValue* value) { + void ReadObject(TVMFFIAny* out) { // NOTE: for now all remote object are encoded as RPCObjectRef // follow the same disco protocol in case we would like to upgrade later // @@ -254,7 +252,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // which is needed for wasm and other env that goes through C API uint32_t type_index; this->template Read(&type_index); - if (type_index == kRuntimeRPCObjectRefTypeIndex) { + if (type_index == runtime::TypeIndex::kRuntimeRPCObjectRef) { uint64_t handle; this->template Read(&handle); // Always wrap things back in RPCObjectRef @@ -263,8 +261,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { RPCObjectRef rpc_obj(make_object(reinterpret_cast(handle), nullptr)); // Legacy ABI translation // TODO(tqchen): remove this once we have upgraded to new ABI - AnyView rpc_obj_view = rpc_obj; - AnyViewToLegacyTVMArgValue(rpc_obj_view.CopyToTVMFFIAny(), value, tcode); + *reinterpret_cast(out) = rpc_obj; object_arena_.push_back(rpc_obj); } else { LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " @@ -342,7 +339,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { return; } else { ICHECK_EQ(init_header_step_, 1); - this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); + this->ReadArray(remote_key_->data(), remote_key_->length()); this->SwitchToState(kRecvPacketNumBytes); } } @@ -351,7 +348,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) { RPCCode code = RPCCode::kNone; this->Read(&code); - if (code >= RPCCode::kSyscallCodeStart) { this->HandleSyscall(code); } else { @@ -397,15 +393,9 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { * \note The ffi::PackedArgs is available until we switchstate. */ ffi::PackedArgs RecvPackedSeq() { - TVMValue* values; - int* tcodes; + ffi::AnyView* packed_args; int num_args; - RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this); - - // Legacy ABI translation - // TODO(tqchen): remove this once we have upgraded to new ABI - AnyView* packed_args = reinterpret_cast(this->ArenaAlloc(num_args)); - LegacyTVMArgsToPackedArgs(values, tcodes, num_args, packed_args); + RPCReference::RecvPackedSeq(reinterpret_cast(&packed_args), &num_args, this); return ffi::PackedArgs(packed_args, num_args); } @@ -426,12 +416,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { * \param args The arguments. */ void ReturnPackedSeq(ffi::PackedArgs args) { - // Legacy ABI translation - // TODO(tqchen): remove this once we have upgraded to new ABI - TVMValue* values = this->ArenaAlloc(args.size()); - int* tcodes = this->ArenaAlloc(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes); - RPCReference::ReturnPackedSeq(values, tcodes, args.size(), this); + RPCReference::ReturnPackedSeq(reinterpret_cast(args.data()), args.size(), + this); } /*! @@ -745,20 +731,15 @@ void RPCEndpoint::Init() { std::lock_guard lock(mutex_); RPCCode code = static_cast(all_args[0].cast()); ffi::PackedArgs args = all_args.Slice(1); - // Legacy ABI translation - // TODO(tqchen): remove this once we have upgraded to new ABI - TVMValue* values = handler_->ArenaAlloc(args.size()); - int* tcodes = handler_->ArenaAlloc(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes); // run transmission uint64_t packet_nbytes = - sizeof(code) + handler_->PackedSeqGetNumBytes(values, tcodes, args.size(), true); + sizeof(code) + handler_->PackedSeqGetNumBytes(args.data(), args.size(), true); // All packet begins with packet nbytes handler_->Write(packet_nbytes); handler_->Write(code); - handler_->SendPackedSeq(values, tcodes, args.size(), true); + handler_->SendPackedSeq(args.data(), args.size(), true); code = HandleUntilReturnEvent(true, [rv](ffi::PackedArgs args) { ICHECK_EQ(args.size(), 1); @@ -838,8 +819,12 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int even writer_.bytes_available()); } ICHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); + // if the code is kShutdown, return 0 to indicate the server should exit if (code == RPCCode::kShutdown) return 0; + // if the writer has bytes available, return 2 to indicate the server should send data + // usually by calling the handler again if (writer_.bytes_available() != 0) return 2; + // otherwise, return 1 to indicate the server should and read return 1; } @@ -849,22 +834,16 @@ void RPCEndpoint::InitRemoteSession(ffi::PackedArgs args) { std::string protocol_ver = kRPCProtocolVer; uint64_t length = protocol_ver.length(); - // Legacy ABI translation - // TODO(tqchen): remove this once we have upgraded to new ABI - TVMValue* values = handler_->ArenaAlloc(args.size()); - int* tcodes = handler_->ArenaAlloc(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes); - // run transmission uint64_t packet_nbytes = sizeof(code) + sizeof(length) + length + - handler_->PackedSeqGetNumBytes(values, tcodes, args.size(), true); + handler_->PackedSeqGetNumBytes(args.data(), args.size(), true); // All packet begins with packet nbytes handler_->Write(packet_nbytes); handler_->Write(code); handler_->Write(length); handler_->WriteArray(protocol_ver.data(), length); - handler_->SendPackedSeq(values, tcodes, args.size(), true); + handler_->SendPackedSeq(args.data(), args.size(), true); code = HandleUntilReturnEvent(true, [](ffi::PackedArgs args) {}); ICHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); @@ -879,20 +858,14 @@ void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, ffi::PackedArgs args, RPCCode code = RPCCode::kCallFunc; uint64_t handle = reinterpret_cast(h); - // Legacy ABI translation - // TODO(tqchen): remove this once we have upgraded to new ABI - TVMValue* values = handler_->ArenaAlloc(args.size()); - int* tcodes = handler_->ArenaAlloc(args.size()); - PackedArgsToLegacyTVMArgs(args.data(), args.size(), values, tcodes); - // run transmission uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + - handler_->PackedSeqGetNumBytes(values, tcodes, args.size(), true); + handler_->PackedSeqGetNumBytes(args.data(), args.size(), true); handler_->Write(packet_nbytes); handler_->Write(code); handler_->Write(handle); - handler_->SendPackedSeq(values, tcodes, args.size(), true); + handler_->SendPackedSeq(args.data(), args.size(), true); code = HandleUntilReturnEvent(true, encode_return); ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code); diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index a420e6d92f41..195adef053bd 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ #define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ -#include +#include #include #include @@ -34,7 +34,6 @@ #include "../../support/ring_buffer.h" #include "../minrpc/rpc_reference.h" #include "rpc_channel.h" -#include "rpc_channel_logger.h" #include "rpc_session.h" namespace tvm { diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 97d62cd586fc..c178db59a230 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -21,7 +21,7 @@ * \file rpc_event_impl.cc * \brief Event driven RPC server implementation. */ -#include +#include #include @@ -44,6 +44,6 @@ ffi::Function CreateEventDrivenServer(ffi::Function fsend, std::string name, }); } -TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); +TVM_FFI_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 5761828876e1..a64bbb713250 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -23,8 +23,9 @@ */ #include "rpc_local_session.h" +#include #include -#include +#include #include #include @@ -47,10 +48,9 @@ void LocalSession::EncodeReturn(ffi::Any rv, const FEncodeReturn& encode_return) AnyView packed_args[3]; // NOTE: this is the place that we need to handle special RPC-related // ABI convention for return value passing that is built on top of Any FFI. - // We need to encode object pointers as opaque raw pointers for passing - // TODO(tqchen): move to RPC to new ABI + // first argument is always the type index. + packed_args[0] = rv.type_index(); if (rv == nullptr) { - packed_args[0] = static_cast(kTVMNullptr); packed_args[1] = rv; encode_return(ffi::PackedArgs(packed_args, 2)); } else if (rv.as()) { @@ -59,43 +59,25 @@ void LocalSession::EncodeReturn(ffi::Any rv, const FEncodeReturn& encode_return) // The second pack value is a customized deleter that deletes the NDArray. TVMFFIAny ret_any = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv)); void* opaque_handle = ret_any.v_obj; - packed_args[0] = static_cast(kTVMNDArrayHandle); - packed_args[1] = - static_cast(ObjectHandleToTVMArrayHandle(static_cast(opaque_handle))); + packed_args[1] = TVMFFINDArrayGetDLTensorPtr(opaque_handle); packed_args[2] = opaque_handle; encode_return(ffi::PackedArgs(packed_args, 3)); } else if (const auto* bytes = rv.as()) { // always pass bytes as byte array - packed_args[0] = static_cast(kTVMBytes); TVMFFIByteArray byte_arr; byte_arr.data = bytes->data; byte_arr.size = bytes->size; packed_args[1] = &byte_arr; encode_return(ffi::PackedArgs(packed_args, 2)); } else if (const auto* str = rv.as()) { - // always pass bytes as raw string - packed_args[0] = static_cast(kTVMStr); packed_args[1] = str->data; encode_return(ffi::PackedArgs(packed_args, 2)); } else if (rv.as()) { TVMFFIAny ret_any = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv)); void* opaque_handle = ret_any.v_obj; packed_args[1] = opaque_handle; - if (ret_any.type_index == ffi::TypeIndex::kTVMFFIModule) { - packed_args[0] = static_cast(kTVMModuleHandle); - } else if (ret_any.type_index == ffi::TypeIndex::kTVMFFIFunction) { - packed_args[0] = static_cast(kTVMPackedFuncHandle); - } else { - packed_args[0] = static_cast(kTVMObjectHandle); - } encode_return(ffi::PackedArgs(packed_args, 2)); } else { - AnyView temp = rv; - TVMValue val; - int type_code; - AnyViewToLegacyTVMArgValue(temp.CopyToTVMFFIAny(), &val, &type_code); - // normal POD encoding through rv - packed_args[0] = type_code; packed_args[1] = rv; encode_return(ffi::PackedArgs(packed_args, 2)); } @@ -139,7 +121,7 @@ void LocalSession::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) } void LocalSession::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) { - ICHECK_EQ(nbytes, GetDataSize(*from)); + ICHECK_EQ(nbytes, ffi::GetDataSize(*from)); DLTensor to; to.data = to_bytes; to.device = {kDLCPU, 0}; @@ -165,7 +147,7 @@ DeviceAPI* LocalSession::GetDeviceAPI(Device dev, bool allow_missing) { return DeviceAPI::Get(dev, allow_missing); } -TVM_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() { return CreateRPCSessionModule(std::make_shared()); }); diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h index 4019552ebcd1..9035b486c995 100644 --- a/src/runtime/rpc/rpc_local_session.h +++ b/src/runtime/rpc/rpc_local_session.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ +#include #include -#include #include #include diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index c50c92ee995a..67faa3329be5 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -21,10 +21,10 @@ * \file rpc_module.cc * \brief RPC runtime module. */ -#include +#include +#include #include #include -#include #include #include @@ -120,9 +120,7 @@ class RPCWrappedFunc : public Object { case ffi::TypeIndex::kTVMFFIFunction: case ffi::TypeIndex::kTVMFFIModule: { packed_args[i] = UnwrapRemoteValueToHandle(args[i]); - // hack, need to force set the type index to the correct one - // so legacy RPC ABI translation can work - // TODO(tqchen): remove this once we migrate to use new ABI as transport + // need to force set the type index to the correct one TVMFFIAny temp = packed_args[i].CopyToTVMFFIAny(); temp.type_index = args[i].type_index(); packed_args[i] = AnyView::CopyFromTVMFFIAny(temp); @@ -290,30 +288,34 @@ void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const AnyView& arg) const { } void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) const { - int tcode = args[0].cast(); - // TODO(tqchen): move to RPC to new ABI - if (tcode == kTVMNullptr) { + int type_index = args[0].cast(); + if (type_index == ffi::TypeIndex::kTVMFFINone) { *rv = nullptr; return; - } else if (tcode == kTVMPackedFuncHandle) { + } else if (type_index == ffi::TypeIndex::kTVMFFIFunction) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto wf = std::make_shared(handle, sess_); *rv = ffi::Function( [wf](ffi::PackedArgs args, ffi::Any* rv) { return wf->operator()(args, rv); }); - } else if (tcode == kTVMModuleHandle) { + } else if (type_index == ffi::TypeIndex::kTVMFFIModule) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto n = make_object(handle, sess_); *rv = Module(n); - } else if (tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle) { + } else if (type_index == ffi::TypeIndex::kTVMFFINDArray || + type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr) { ICHECK_EQ(args.size(), 3); auto tensor = args[1].cast(); void* nd_handle = args[2].cast(); *rv = NDArrayFromRemoteOpaqueHandle(sess_, tensor->data, tensor, AddRPCSessionMask(tensor->device, sess_->table_index()), nd_handle); - } else if (tcode == kTVMObjectHandle) { + } else if (type_index == ffi::TypeIndex::kTVMFFIBytes || + type_index == ffi::TypeIndex::kTVMFFIStr) { + ICHECK_EQ(args.size(), 2); + *rv = args[1]; + } else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto n = make_object(handle, sess_); @@ -387,7 +389,7 @@ inline void CPUCacheFlush(int begin_index, const ffi::PackedArgs& args) { } } -TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") +TVM_FFI_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") .set_body_typed([](Optional opt_mod, std::string name, int device_type, int device_id, int number, int repeat, int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, @@ -433,40 +435,40 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") } }); -TVM_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg") +TVM_FFI_REGISTER_GLOBAL("cache_flush_cpu_non_first_arg") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { CPUCacheFlush(1, args); }); // server function registration. -TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) { - parent->Import(child); -}); +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.ImportModule") + .set_body_typed([](Module parent, Module child) { parent->Import(child); }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") .set_body_typed([](Module parent, std::string name, bool query_imports) { return parent->GetFunction(name, query_imports); }); // functions to access an RPC module. -TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) { +TVM_FFI_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) { std::string tkey = sess->type_key(); ICHECK_EQ(tkey, "rpc"); return static_cast(sess.operator->())->LoadModule(name); }); -TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) { +TVM_FFI_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) { std::string tkey = parent->type_key(); ICHECK_EQ(tkey, "rpc"); static_cast(parent.operator->())->ImportModule(child); }); -TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Module m = args[0].cast(); - std::string tkey = m->type_key(); - ICHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->sess()->table_index(); -}); +TVM_FFI_REGISTER_GLOBAL("rpc.SessTableIndex") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + Module m = args[0].cast(); + std::string tkey = m->type_key(); + ICHECK_EQ(tkey, "rpc"); + *rv = static_cast(m.operator->())->sess()->table_index(); + }); -TVM_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.NDArrayFromRemoteOpaqueHandle") .set_body_typed([](Module mod, void* remote_array, DLTensor* template_tensor, Device dev, void* ndarray_handle) -> NDArray { return NDArrayFromRemoteOpaqueHandle(RPCModuleGetSession(mod), remote_array, template_tensor, diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index 25472de72777..b9121968137b 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include @@ -112,13 +112,14 @@ Module CreatePipeClient(std::vector cmd) { return CreateRPCSessionModule(CreateClientSession(endpt)); } -TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::vector cmd; - for (int i = 0; i < args.size(); ++i) { - cmd.push_back(args[i].cast()); - } - *rv = CreatePipeClient(cmd); -}); +TVM_FFI_REGISTER_GLOBAL("rpc.CreatePipeClient") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + std::vector cmd; + for (int i = 0; i < args.size(); ++i) { + cmd.push_back(args[i].cast()); + } + *rv = CreatePipeClient(cmd); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index 823fa232a953..eeb76c2b1512 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -21,7 +21,7 @@ * \file rpc_server_env.cc * \brief Server environment of the RPC. */ -#include +#include #include "../file_utils.h" @@ -35,14 +35,14 @@ std::string RPCGetPath(const std::string& name) { return (*f)(name).cast(); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.upload") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::string file_name = RPCGetPath(args[0].cast()); auto data = args[1].cast(); SaveBinaryToFile(file_name, data); }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.download") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.download") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::string file_name = RPCGetPath(args[0].cast()); std::string data; @@ -51,7 +51,7 @@ TVM_REGISTER_GLOBAL("tvm.rpc.server.download") *rv = ffi::Bytes(data); }); -TVM_REGISTER_GLOBAL("tvm.rpc.server.remove") +TVM_FFI_REGISTER_GLOBAL("tvm.rpc.server.remove") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { std::string file_name = RPCGetPath(args[0].cast()); RemoveFile(file_name); diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 76e07e00fb2b..ace9cf9b9485 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -23,8 +23,8 @@ */ #include "rpc_session.h" +#include #include -#include #include #include @@ -35,7 +35,7 @@ namespace runtime { bool RPCSession::IsAsync() const { return false; } void RPCSession::SendException(FAsyncCallback callback, const char* msg) { - AnyView packed_args[1] = {msg}; + ffi::AnyView packed_args[1] = {msg}; callback(RPCCode::kException, ffi::PackedArgs(packed_args, 1)); } @@ -51,7 +51,7 @@ void RPCSession::AsyncCallFunc(PackedFuncHandle func, ffi::PackedArgs packed_arg void RPCSession::AsyncCopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes, RPCSession::FAsyncCallback callback) { - AnyView packed_args[1] = {nullptr}; + ffi::AnyView packed_args[1] = {nullptr}; try { this->CopyToRemote(local_from_bytes, remote_to, nbytes); @@ -63,7 +63,7 @@ void RPCSession::AsyncCopyToRemote(void* local_from_bytes, DLTensor* remote_to, void RPCSession::AsyncCopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes, RPCSession::FAsyncCallback callback) { - AnyView packed_args[1] = {nullptr}; + ffi::AnyView packed_args[1] = {nullptr}; try { this->CopyFromRemote(remote_from, local_to_bytes, nbytes); diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 271e26dfd04e..c0ec2067eb5f 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -24,8 +24,10 @@ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ +#include #include -#include +#include +#include #include #include diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 286d143bad6c..2564242bdf0f 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -21,7 +21,7 @@ * \file rpc_socket_impl.cc * \brief Socket based RPC implementation. */ -#include +#include #include @@ -98,9 +98,6 @@ std::shared_ptr RPCConnect(std::string url, int port, std::string k } std::unique_ptr channel = std::make_unique(sock); - if (enable_logging) { - channel.reset(new RPCChannelLogging(std::move(channel))); - } auto endpt = RPCEndpoint::Create(std::move(channel), key, remote_key); endpt->InitRemoteSession(init_seq); @@ -124,7 +121,7 @@ void RPCServerLoop(ffi::Function fsend, ffi::Function frecv) { ->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto url = args[0].cast(); int port = args[1].cast(); auto key = args[2].cast(); @@ -132,7 +129,7 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi: *rv = RPCClientConnect(url, port, key, enable_logging, args.Slice(4)); }); -TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("rpc.ServerLoop").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_int = args[0].as()) { RPCServerLoop(opt_int.value()); } else { @@ -165,7 +162,7 @@ class SimpleSockHandler : public dmlc::Stream { support::TCPSocket sock_; }; -TVM_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, String msg) { +TVM_FFI_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, String msg) { auto handler = SimpleSockHandler(sockfd); RPCReference::ReturnException(msg.c_str(), &handler); return; diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h deleted file mode 100644 index 3037c8d84ff0..000000000000 --- a/src/runtime/runtime_base.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file runtime_base.h - * \brief Base of all C APIs - */ -#ifndef TVM_RUNTIME_RUNTIME_BASE_H_ -#define TVM_RUNTIME_RUNTIME_BASE_H_ - -#include - -#include - -/*! \brief macro to guard beginning and end section of all functions */ -#define API_BEGIN() try { -/*! \brief every function starts with API_BEGIN(); - and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() \ - } \ - catch (::tvm::runtime::EnvErrorAlreadySet & _except_) { \ - return -2; \ - } \ - catch (std::exception & _except_) { \ - return TVMAPIHandleException(_except_); \ - } \ - return 0; // NOLINT(*) -/*! - * \brief every function starts with API_BEGIN(); - * and finishes with API_END() or API_END_HANDLE_ERROR - * The finally clause contains procedure to cleanup states when an error happens. - */ -#define API_END_HANDLE_ERROR(Finalize) \ - } \ - catch (::tvm::runtime::EnvErrorAlreadySet & _except_) { \ - return -2; \ - } \ - catch (std::exception & _except_) { \ - Finalize; \ - return TVMAPIHandleException(_except_); \ - } \ - return 0; // NOLINT(*) - -/*! - * \brief handle exception throwed out - * \param e the exception - * \return the return value of API after exception is handled - */ -int TVMAPIHandleException(const std::exception& e); - -#endif // TVM_RUNTIME_RUNTIME_BASE_H_ diff --git a/src/runtime/spirv/spirv_shader.h b/src/runtime/spirv/spirv_shader.h index 293dc5b78638..06b331d3334e 100644 --- a/src/runtime/spirv/spirv_shader.h +++ b/src/runtime/spirv/spirv_shader.h @@ -20,10 +20,11 @@ #ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ #define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ -#include +#include +#include +#include #include #include -#include #include diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index 47633b326cba..08beb8cbc530 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -24,10 +24,9 @@ */ #include "./static_library.h" -#include +#include +#include #include -#include -#include #include @@ -127,8 +126,8 @@ Module LoadStaticLibrary(const std::string& filename, Array func_names) return Module(node); } -TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_static_library") +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_static_library") .set_body_typed(StaticLibraryNode::LoadFromBinary); } // namespace runtime diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h index 818e77ebd980..196d2448b93f 100644 --- a/src/runtime/static_library.h +++ b/src/runtime/static_library.h @@ -26,7 +26,7 @@ #ifndef TVM_RUNTIME_STATIC_LIBRARY_H_ #define TVM_RUNTIME_STATIC_LIBRARY_H_ -#include +#include #include #include diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc index 30fca708b8e8..46c08e4afd9a 100644 --- a/src/runtime/system_library.cc +++ b/src/runtime/system_library.cc @@ -21,9 +21,9 @@ * \file system_library.cc * \brief Create library module that directly get symbol from the system lib. */ +#include +#include #include -#include -#include #include @@ -112,13 +112,14 @@ class SystemLibModuleRegistry { std::unordered_map lib_map_; }; -TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - std::string symbol_prefix = ""; - if (args.size() != 0) { - symbol_prefix = args[0].cast(); - } - *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); -}); +TVM_FFI_REGISTER_GLOBAL("runtime.SystemLib") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + std::string symbol_prefix = ""; + if (args.size() != 0) { + symbol_prefix = args[0].cast(); + } + *rv = SystemLibModuleRegistry::Global()->GetOrCreateModule(symbol_prefix); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 457f44799d7c..d266fb7da8fa 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -22,12 +22,11 @@ * \brief Threadpool for multi-threading runtime. */ #include +#include +#include +#include #include -#include -#include #include -#include -#include #include #if TVM_THREADPOOL_USE_OPENMP #include @@ -379,7 +378,7 @@ class ThreadPool { * \brief args[0] is the AffinityMode, args[1] is the number of threads. * args2 is a list of CPUs which is used to set the CPU affinity. */ -TVM_REGISTER_GLOBAL("runtime.config_threadpool") +TVM_FFI_REGISTER_GLOBAL("runtime.config_threadpool") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { threading::ThreadGroup::AffinityMode mode = static_cast(args[0].cast()); @@ -395,7 +394,7 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool") threading::Configure(mode, nthreads, cpus); }); -TVM_REGISTER_GLOBAL("runtime.NumThreads").set_body_typed([]() -> int32_t { +TVM_FFI_REGISTER_GLOBAL("runtime.NumThreads").set_body_typed([]() -> int32_t { return threading::NumThreads(); }); diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 049e6467d1fc..914fe67819de 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -24,7 +24,7 @@ #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ -#include +#include #include #include diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 01c0f1603fa2..ef835f20d171 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -21,8 +21,9 @@ * \file threading_backend.cc * \brief Native threading backend */ +#include +#include #include -#include #include #if defined(__linux__) || defined(__ANDROID__) @@ -436,8 +437,8 @@ int MaxConcurrency() { // This global function can be used by disco runtime to bind processes // to CPUs. -TVM_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity") - .set_body_typed([](IntTuple cpu_ids) { +TVM_FFI_REGISTER_GLOBAL("tvm.runtime.threading.set_current_thread_affinity") + .set_body_typed([](ffi::Shape cpu_ids) { SetThreadAffinity(CURRENT_THREAD_HANDLE, std::vector{cpu_ids.begin(), cpu_ids.end()}); }); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index f1e0ef587ecc..fb4776c98afc 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -20,11 +20,10 @@ #ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_ #define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_ -#include +#include +#include #include #include -#include -#include #include #include diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index fcb3e764bf86..12181f8c159d 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -455,12 +455,13 @@ VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { return const_cast(const_cast(this)->device(device_id)); } -TVM_REGISTER_GLOBAL("device_api.vulkan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = VulkanDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.vulkan") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = VulkanDeviceAPI::Global(); + *rv = static_cast(ptr); + }); -TVM_REGISTER_GLOBAL("device_api.vulkan.get_target_property") +TVM_FFI_REGISTER_GLOBAL("device_api.vulkan.get_target_property") .set_body_typed([](Device dev, const std::string& property) { ffi::Any rv; VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 600d7d6f870c..063dc5bde009 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -20,7 +20,7 @@ #include "vulkan_module.h" #include -#include +#include #include "../file_utils.h" #include "vulkan_wrapped_func.h" @@ -64,9 +64,9 @@ Module VulkanModuleLoadBinary(void* strm) { return VulkanModuleCreate(smap, fmap, ""); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index ab212c2eade4..f4922a1bf01d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -289,7 +289,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, auto fit = fmap_.find(func_name); ICHECK(fit != fmap_.end()); for (DLDataType arg_type : fit->second.arg_types) { - if (arg_type.code == kTVMOpaqueHandle) { + if (arg_type.code == kDLOpaqueHandle) { push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); ++num_buffer; } else { diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 51df91238954..13f272d7c946 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include #include -#include #include namespace tvm { @@ -46,7 +46,7 @@ void IRBuilderFrameNode::AddCallback(ffi::TypedFunction callback) { IRBuilder::IRBuilder() { ObjectPtr n = make_object(); n->frames.clear(); - n->result = NullOpt; + n->result = std::nullopt; data_ = n; } @@ -60,7 +60,7 @@ void IRBuilder::EnterWithScope() { CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: " << n->frames.size() << ". Please use a fresh new builder every time building IRs"; - n->result = NullOpt; + n->result = std::nullopt; std::vector* stack = ThreadLocalBuilderStack(); stack->push_back(*this); } @@ -101,20 +101,24 @@ void Namer::Name(ObjectRef node, String name) { TVM_REGISTER_NODE_TYPE(IRBuilderFrameNode); TVM_REGISTER_NODE_TYPE(IRBuilderNode); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameEnter") .set_body_method(&IRBuilderFrameNode::EnterWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameExit") .set_body_method(&IRBuilderFrameNode::ExitWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderFrameAddCallback") .set_body_method(&IRBuilderFrameNode::AddCallback); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); }); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return IRBuilder(); }); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter") + .set_body_method(&IRBuilder::EnterWithScope); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit") + .set_body_method(&IRBuilder::ExitWithScope); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope") + .set_body_typed(IRBuilder::IsInScope); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method(&IRBuilderNode::Get); -TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.IRBuilderName") + .set_body_typed(IRBuilder::Name); } // namespace ir_builder } // namespace script diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index 2b02a80e3eaf..6cb61147a96a 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include #include -#include #include namespace tvm { diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 2fb8da1e9f24..270f4623ef0c 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -16,9 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include -#include #include #include #include @@ -106,7 +106,7 @@ Optional ModuleGetAttr(const String& key) { return frame->attrs[key].cast(); } } - return NullOpt; + return std::nullopt; } void ModuleSetAttr(const String& key, const Optional& value, bool allow_override) { @@ -165,14 +165,14 @@ VDevice LookupVDevice(String target_kind, int device_index) { return VDevice(); } -TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGetAttr").set_body_typed(ModuleGetAttr); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleSetAttr").set_body_typed(ModuleSetAttr); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); -TVM_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGetAttr").set_body_typed(ModuleGetAttr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleSetAttr").set_body_typed(ModuleSetAttr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/relax/distributed.cc b/src/script/ir_builder/relax/distributed.cc index 74caa95e9012..fcf9e0eb2c5b 100644 --- a/src/script/ir_builder/relax/distributed.cc +++ b/src/script/ir_builder/relax/distributed.cc @@ -54,7 +54,7 @@ Expr MakeCallTIRDist(Expr func, Tuple args, Array n = make_object(); const IRBuilder& ir_builder = IRBuilder::Current(); - Optional mod = NullOpt; + Optional mod = std::nullopt; if (const Optional mod_frame = ir_builder->GetLastFrame()) { mod = tvm::IRModule(mod_frame.value()->functions); } @@ -144,12 +144,13 @@ void FuncRetValue(const tvm::relax::Expr& value) { frame->output = std::move(normalized_value); } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo").set_body_typed(FuncRetStructInfo); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetStructInfo") + .set_body_typed(FuncRetStructInfo); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); ///////////////////////////// BindingBlock ////////////////////////////// @@ -191,9 +192,9 @@ void DataflowBlockOutput(const Array& vars) { } } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") .set_body_typed(DataflowBlockOutput); /////////////////////////////// Bindings /////////////////////////////// @@ -236,9 +237,9 @@ tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding) { return binding->var; } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.EmitVarBinding").set_body_typed(EmitVarBinding); /////////////////////////////// SeqExpr /////////////////////////////// @@ -247,15 +248,15 @@ SeqExprFrame SeqExpr() { return SeqExprFrame(n); } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.SeqExpr").set_body_typed(SeqExpr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.SeqExpr").set_body_typed(SeqExpr); ///////////////////////////// If Then Else ///////////////////////////// IfFrame If(tvm::relax::Expr condition) { ObjectPtr n = make_object(); n->condition = condition; - n->then_expr = NullOpt; - n->else_expr = NullOpt; + n->then_expr = std::nullopt; + n->else_expr = std::nullopt; return IfFrame(n); } @@ -269,9 +270,9 @@ ElseFrame Else() { return ElseFrame(n); } -TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); } // namespace relax } // namespace ir_builder diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 1d99a1b85252..da772f608579 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -57,10 +57,10 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt PrimFuncFrame PrimFunc(bool is_private) { ObjectPtr n = make_object(); - n->name = NullOpt; + n->name = std::nullopt; n->is_private = is_private; n->args.clear(); - n->ret_type = NullOpt; + n->ret_type = std::nullopt; n->buffer_map.clear(); n->attrs = {}; n->env_threads.clear(); @@ -157,14 +157,14 @@ BlockFrame Block(String name, bool no_realize) { ObjectPtr n = make_object(); n->name = name; n->iter_vars.clear(); - n->reads = NullOpt; - n->writes = NullOpt; - n->init = NullOpt; + n->reads = std::nullopt; + n->writes = std::nullopt; + n->init = std::nullopt; n->alloc_buffers.clear(); n->match_buffers.clear(); - n->annotations = NullOpt; + n->annotations = std::nullopt; n->iter_values.clear(); - n->predicate = NullOpt; + n->predicate = std::nullopt; n->no_realize = no_realize; return BlockFrame(n); } @@ -335,7 +335,7 @@ Array Remap(String kinds, Array bindings, DataType dtype) { n->f_make_for_loop = [annotations](Array vars, Array doms, tvm::tir::Stmt body) { \ ICHECK_EQ(vars.size(), 1); \ ICHECK_EQ(doms.size(), 1); \ - return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, \ + return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \ annotations.value_or(Map())); \ }; \ return ForFrame(n); \ @@ -386,7 +386,7 @@ ForFrame Grid(Array extents) { Range dom = doms[i]; Var var = vars[i]; body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body), - /*thread_binding=*/NullOpt, /*annotations=*/{}); + /*thread_binding=*/std::nullopt, /*annotations=*/{}); } return body; }; @@ -486,7 +486,7 @@ AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, AttrFrame Attr(ffi::Any node, String attr_key, PrimExpr value) { // convert POD value to PrimExpr if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - node = node.as().value(); + node = node.cast(); } ObjectPtr n = make_object(); n->node = node.cast(); @@ -504,8 +504,8 @@ WhileFrame While(PrimExpr condition) { IfFrame If(PrimExpr condition) { ObjectPtr n = make_object(); n->condition = condition; - n->then_stmts = NullOpt; - n->else_stmts = NullOpt; + n->then_stmts = std::nullopt; + n->else_stmts = std::nullopt; return IfFrame(n); } @@ -531,7 +531,7 @@ Var EnvThread(String thread_tag, DataType dtype) { } void BufferStore(Buffer buffer, PrimExpr value, Array indices, - Optional predicate = NullOpt) { + Optional predicate = std::nullopt) { runtime::DataType buffer_dtype = buffer->dtype; bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector(); bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector(); @@ -657,9 +657,9 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) Namer::Name(var->var, name); }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Buffer").set_body_typed(BufferDecl); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.PrimFunc").set_body_typed(PrimFunc); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Arg") .set_body_typed([](String name, ObjectRef obj) -> ObjectRef { using namespace tvm::tir; if (auto var = obj.as()) { @@ -671,45 +671,45 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Arg") LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); throw; }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncName").set_body_typed(FuncName); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncAttrs").set_body_typed(FuncAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.FuncRet").set_body_typed(FuncRet); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuffer); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Unroll").set_body_typed(Unroll); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(ThreadBinding); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Attr").set_body_typed(Attr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.While").set_body_typed(While); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") .set_body_typed([](ffi::Variant thread_tag_or_var, PrimExpr extent) { if (auto var = thread_tag_or_var.as()) { return LaunchThread(var.value(), extent); @@ -721,60 +721,60 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") throw; } }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); #define TVM_TMP_STR(x) #x -#define TVM_REGISTER_GLOBAL_SIZE(Prefix, DType) \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64); - -TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float); -TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt); -TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int); - -#define TVM_REGISTER_GLOBAL_LANES(Prefix, Func) \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \ - TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64); - -#define TVM_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \ - TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64); - -TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); -TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); -TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN); -TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); - -TVM_REGISTER_GLOBAL("script.ir_builder.tir.min") +#define TVM_FFI_REGISTER_GLOBAL_SIZE(Prefix, DType) \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64); + +TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float); +TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt); +TVM_FFI_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int); + +#define TVM_FFI_REGISTER_GLOBAL_LANES(Prefix, Func) \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \ + TVM_FFI_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64); + +#define TVM_FFI_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \ + TVM_FFI_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64); + +TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); +TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); +TVM_FFI_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Float4E2M1FN").set_body_typed(Float4E2M1FN); +TVM_FFI_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); + +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.min") .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.max") +TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.max") .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); } // namespace tir } // namespace ir_builder diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 615032ea0036..8f1fd77d782d 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -16,9 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include +#include #include -#include #include namespace tvm { @@ -177,7 +177,7 @@ ScopeDoc::ScopeDoc(Optional lhs, ExprDoc rhs, Array body) { ScopeDoc::ScopeDoc(ExprDoc rhs, Array body) { ObjectPtr n = make_object(); - n->lhs = NullOpt; + n->lhs = std::nullopt; n->rhs = rhs; n->body = body; this->data_ = std::move(n); @@ -234,49 +234,50 @@ DocStringDoc::DocStringDoc(String docs) { } TVM_REGISTER_NODE_TYPE(DocNode); -TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") +TVM_FFI_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") .set_body_typed([](Doc doc, Array source_paths) { doc->source_paths = source_paths; }); TVM_REGISTER_NODE_TYPE(ExprDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr") +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocAttr") .set_body_method(&ExprDocNode::Attr); -TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex").set_body_method(&ExprDocNode::operator[]); -TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocIndex").set_body_method(&ExprDocNode::operator[]); +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprDocCall") .set_body_method, Array, Array>( &ExprDocNode::Call); TVM_REGISTER_NODE_TYPE(StmtDocNode); -TVM_REGISTER_GLOBAL("script.printer.StmtDocSetComment") +TVM_FFI_REGISTER_GLOBAL("script.printer.StmtDocSetComment") .set_body_typed([](StmtDoc doc, Optional comment) { doc->comment = comment; }); TVM_REGISTER_NODE_TYPE(StmtBlockDocNode); -TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array stmts) { +TVM_FFI_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array stmts) { return StmtBlockDoc(stmts); }); TVM_REGISTER_NODE_TYPE(LiteralDocNode); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); +TVM_FFI_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); TVM_REGISTER_NODE_TYPE(IdDocNode); -TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); +TVM_FFI_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { + return IdDoc(name); +}); TVM_REGISTER_NODE_TYPE(AttrAccessDocNode); -TVM_REGISTER_GLOBAL("script.printer.AttrAccessDoc").set_body_typed([](ExprDoc value, String attr) { - return AttrAccessDoc(value, attr); -}); +TVM_FFI_REGISTER_GLOBAL("script.printer.AttrAccessDoc") + .set_body_typed([](ExprDoc value, String attr) { return AttrAccessDoc(value, attr); }); TVM_REGISTER_NODE_TYPE(IndexDocNode); -TVM_REGISTER_GLOBAL("script.printer.IndexDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.IndexDoc") .set_body_typed([](ExprDoc value, Array indices) { return IndexDoc(value, indices); }); TVM_REGISTER_NODE_TYPE(CallDocNode); -TVM_REGISTER_GLOBAL("script.printer.CallDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.CallDoc") .set_body_typed([](ExprDoc callee, // Array args, // Array kwargs_keys, // @@ -285,104 +286,103 @@ TVM_REGISTER_GLOBAL("script.printer.CallDoc") }); TVM_REGISTER_NODE_TYPE(OperationDocNode); -TVM_REGISTER_GLOBAL("script.printer.OperationDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.OperationDoc") .set_body_typed([](int32_t kind, Array operands) { return OperationDoc(OperationDocNode::Kind(kind), operands); }); TVM_REGISTER_NODE_TYPE(LambdaDocNode); -TVM_REGISTER_GLOBAL("script.printer.LambdaDoc").set_body_typed([](Array args, ExprDoc body) { - return LambdaDoc(args, body); -}); +TVM_FFI_REGISTER_GLOBAL("script.printer.LambdaDoc") + .set_body_typed([](Array args, ExprDoc body) { return LambdaDoc(args, body); }); TVM_REGISTER_NODE_TYPE(TupleDocNode); -TVM_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array elements) { +TVM_FFI_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array elements) { return TupleDoc(elements); }); TVM_REGISTER_NODE_TYPE(ListDocNode); -TVM_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array elements) { +TVM_FFI_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array elements) { return ListDoc(elements); }); TVM_REGISTER_NODE_TYPE(DictDocNode); -TVM_REGISTER_GLOBAL("script.printer.DictDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.DictDoc") .set_body_typed([](Array keys, Array values) { return DictDoc(keys, values); }); TVM_REGISTER_NODE_TYPE(SliceDocNode); -TVM_REGISTER_GLOBAL("script.printer.SliceDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.SliceDoc") .set_body_typed([](Optional start, Optional stop, Optional step) { return SliceDoc(start, stop, step); }); TVM_REGISTER_NODE_TYPE(AssignDocNode); -TVM_REGISTER_GLOBAL("script.printer.AssignDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.AssignDoc") .set_body_typed([](ExprDoc lhs, Optional rhs, Optional annotation) { return AssignDoc(lhs, rhs, annotation); }); TVM_REGISTER_NODE_TYPE(IfDocNode); -TVM_REGISTER_GLOBAL("script.printer.IfDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.IfDoc") .set_body_typed([](ExprDoc predicate, Array then_branch, Array else_branch) { return IfDoc(predicate, then_branch, else_branch); }); TVM_REGISTER_NODE_TYPE(WhileDocNode); -TVM_REGISTER_GLOBAL("script.printer.WhileDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.WhileDoc") .set_body_typed([](ExprDoc predicate, Array body) { return WhileDoc(predicate, body); }); TVM_REGISTER_NODE_TYPE(ForDocNode); -TVM_REGISTER_GLOBAL("script.printer.ForDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.ForDoc") .set_body_typed([](ExprDoc lhs, ExprDoc rhs, Array body) { return ForDoc(lhs, rhs, body); }); TVM_REGISTER_NODE_TYPE(ScopeDocNode); -TVM_REGISTER_GLOBAL("script.printer.ScopeDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.ScopeDoc") .set_body_typed([](Optional lhs, ExprDoc rhs, Array body) { return ScopeDoc(lhs, rhs, body); }); TVM_REGISTER_NODE_TYPE(ExprStmtDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprStmtDoc").set_body_typed([](ExprDoc expr) { +TVM_FFI_REGISTER_GLOBAL("script.printer.ExprStmtDoc").set_body_typed([](ExprDoc expr) { return ExprStmtDoc(expr); }); TVM_REGISTER_NODE_TYPE(AssertDocNode); -TVM_REGISTER_GLOBAL("script.printer.AssertDoc") - .set_body_typed([](ExprDoc test, Optional msg = NullOpt) { +TVM_FFI_REGISTER_GLOBAL("script.printer.AssertDoc") + .set_body_typed([](ExprDoc test, Optional msg = std::nullopt) { return AssertDoc(test, msg); }); TVM_REGISTER_NODE_TYPE(ReturnDocNode); -TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) { +TVM_FFI_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) { return ReturnDoc(value); }); TVM_REGISTER_NODE_TYPE(FunctionDocNode); -TVM_REGISTER_GLOBAL("script.printer.FunctionDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.FunctionDoc") .set_body_typed([](IdDoc name, Array args, Array decorators, Optional return_type, Array body) { return FunctionDoc(name, args, decorators, return_type, body); }); TVM_REGISTER_NODE_TYPE(ClassDocNode); -TVM_REGISTER_GLOBAL("script.printer.ClassDoc") +TVM_FFI_REGISTER_GLOBAL("script.printer.ClassDoc") .set_body_typed([](IdDoc name, Array decorators, Array body) { return ClassDoc(name, decorators, body); }); TVM_REGISTER_NODE_TYPE(CommentDocNode); -TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) { +TVM_FFI_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) { return CommentDoc(comment); }); TVM_REGISTER_NODE_TYPE(DocStringDocNode); -TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) { +TVM_FFI_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) { return DocStringDoc(docs); }); diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 2bb8e2a1dc51..85b5b755d253 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include #include -#include #include #include @@ -727,7 +727,7 @@ String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { return result.substr(0, last_space); } -TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript); +TVM_FFI_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript); } // namespace printer } // namespace script diff --git a/src/script/printer/ir/distributed.cc b/src/script/printer/ir/distributed.cc index 29e45bc5c598..194c8f52b1aa 100644 --- a/src/script/printer/ir/distributed.cc +++ b/src/script/printer/ir/distributed.cc @@ -26,16 +26,15 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](runtime::ShapeTuple n, ObjectPath n_p, IRDocsifier d) -> Doc { - int s = n.size(); - Array results; - results.reserve(s); - for (int i = 0; i < s; ++i) { - results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayIndex(i))); - } - return TupleDoc(results); - }); + .set_dispatch("", [](ffi::Shape n, ObjectPath n_p, IRDocsifier d) -> Doc { + int s = n.size(); + Array results; + results.reserve(s); + for (int i = 0; i < s; ++i) { + results.push_back(d->AsDoc(Integer(n[i]), n_p->ArrayIndex(i))); + } + return TupleDoc(results); + }); } // namespace printer } // namespace script diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 7187c2d512a4..c8f029d225a8 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -100,7 +100,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->stmts.push_back(func.value()); } else if (auto expr = doc.as()) { ExprDoc lhs = IdDoc(gv->name_hint); - AssignDoc assignment(lhs, expr.value(), NullOpt); + AssignDoc assignment(lhs, expr.value(), std::nullopt); (*f)->stmts.push_back(assignment); } else { LOG(FATAL) << "TypeError: " diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index 05d70ea4dab4..caa5cbe895bd 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -49,7 +49,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::vector items{dict.begin(), dict.end()}; bool is_str_map = true; for (const auto& kv : items) { - if (!kv.first.as()) { + if (!kv.first.as()) { is_str_map = false; break; } diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index f74e08848920..8c72eb4ef318 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include #include -#include #include #include @@ -62,14 +61,14 @@ IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const St void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) { ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; - obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}}); + obj2info.insert({obj, VariableInfo{std::move(doc_factory), std::nullopt}}); frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); } Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { auto it = obj2info.find(obj); if (it == obj2info.end()) { - return NullOpt; + return std::nullopt; } return it->second.creator(); } @@ -82,7 +81,8 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) { if (index == static_cast(array.size())) { array.push_back(obj); } - return IdDoc("metadata")[{LiteralDoc::Str(key, NullOpt)}][{LiteralDoc::Int(index, NullOpt)}]; + return IdDoc( + "metadata")[{LiteralDoc::Str(key, std::nullopt)}][{LiteralDoc::Int(index, std::nullopt)}]; } void IRDocsifierNode::AddGlobalInfo(const String& name, const GlobalInfo& ginfo) { @@ -138,13 +138,13 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, } visited_.insert(obj); stack_.push_back(obj); - if (obj->IsInstance()) { - const ArrayObj* array = static_cast(obj); + if (obj->IsInstance()) { + const ffi::ArrayObj* array = static_cast(obj); for (Any element : *array) { this->RecursiveVisitAny(&element); } - } else if (obj->IsInstance()) { - const MapObj* map = static_cast(obj); + } else if (obj->IsInstance()) { + const ffi::MapObj* map = static_cast(obj); for (std::pair kv : *map) { this->RecursiveVisitAny(&kv.first); this->RecursiveVisitAny(&kv.second); diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index 27301482ccc4..5e414e90c262 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -75,8 +75,8 @@ ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { // N } TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); (*p) << '['; for (size_t i = 0; i < op->size(); ++i) { if (i != 0) { @@ -88,8 +88,8 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); (*p) << '{'; for (auto it = op->begin(); it != op->end(); ++it) { if (it != op->begin()) { @@ -107,8 +107,8 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); + .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { + auto* op = static_cast(node.get()); (*p) << '['; for (size_t i = 0; i < op->size; ++i) { if (i != 0) { diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index c8b616b4bcb5..22baf1c21c74 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -44,7 +44,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc { using relax::StructInfo; using relax::MatchStructInfo; - Optional ann = NullOpt; + Optional ann = std::nullopt; if (d->cfg->show_all_struct_info) { ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); } @@ -83,7 +83,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](relax::If n, ObjectPath n_p, IRDocsifier d) -> Doc { - return PrintIfExpr(n, n_p, d, NullOpt, NullOpt); + return PrintIfExpr(n, n_p, d, std::nullopt, std::nullopt); }); TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax); diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index bd259609127f..82c2083044ec 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -122,7 +122,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op) && !n->op.same_as(call_tir_with_grad_op) && !n->op.same_as(call_tir_local_view) && !n->op.same_as(call_tir_inplace_op)) { - return NullOpt; + return std::nullopt; } ICHECK(n->args.size() == 2 || n->args.size() == 3); ICHECK(n->sinfo_args.size() == 1); @@ -206,7 +206,7 @@ Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& Optional PrintAssertOp(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { static const Op& assert_op = Op::Get("relax.assert_op"); if (!n->op.same_as(assert_op)) { - return NullOpt; + return std::nullopt; } ICHECK(n->args.size() >= 2); // special handling: it is important to indicate that the format string (second argument) @@ -226,7 +226,7 @@ Optional PrintHintOnDevice(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { static const Op& hint_on_device_op = Op::Get("relax.hint_on_device"); if (!n->op.same_as(hint_on_device_op)) { - return NullOpt; + return std::nullopt; } Array args; @@ -246,7 +246,7 @@ Optional PrintToVDevice(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { static const Op& to_vdevice_op = Op::Get("relax.to_vdevice"); if (!n->op.same_as(to_vdevice_op)) { - return NullOpt; + return std::nullopt; } Array args; @@ -269,7 +269,7 @@ Optional PrintRelaxPrint(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { static const Op& print_op = Op::Get("relax.print"); if (!n->op.same_as(print_op)) { - return NullOpt; + return std::nullopt; } ICHECK(n->args.size() >= 1); // special handling: it is important to indicate that the format string (first argument) diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index f6cbde0b4b23..808177b15020 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -83,7 +83,7 @@ Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) DataType dtype = n.DataType(); const void* data = n->data; if (n->ndim != 0 || n->device.device_type != kDLCPU) { - return NullOpt; + return std::nullopt; } if (dtype == DataType::Int(8)) { @@ -128,7 +128,7 @@ Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) } else if (dtype == DataType::Bool()) { return LiteralDoc::Boolean(*reinterpret_cast(data), p); } else { - return NullOpt; + return std::nullopt; } } @@ -154,7 +154,7 @@ Doc PrintRelaxVar(relax::Var n, ObjectPath p, IRDocsifier d) { ExprDoc ann = d->AsDoc(n->struct_info_, p->Attr("struct_info_")); Frame f = d->frames.back(); ExprDoc var = DefineVar(n, f, d); - f->stmts.push_back(AssignDoc(var, NullOpt, ann)); + f->stmts.push_back(AssignDoc(var, std::nullopt, ann)); } return d->GetVarDoc(n).value(); } diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 655e03d3b2f5..99e30ab520a5 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->is_func = true; (*f)->func_vars = &func_vars; // Step 1. Print the return type - Optional ret_type = NullOpt; + Optional ret_type = std::nullopt; if (const auto& func_sinfo = relax::MatchStructInfo(n)) { ret_type = d->AsDoc(func_sinfo.value()->ret, // n_p->Attr("struct_info_")->Attr("ret")); @@ -68,7 +68,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0, l = n->params.size(); i < l; ++i) { params.push_back(AssignDoc( /*lhs=*/DefineVar(n->params[i], *f, d), - /*rhs=*/NullOpt, StructInfoAsAnn(n->params[i], params_p->ArrayIndex(i), d, NullOpt))); + /*rhs=*/std::nullopt, + StructInfoAsAnn(n->params[i], params_p->ArrayIndex(i), d, std::nullopt))); } } // Step 3. Clean up func variables diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc index 1ac0b5ba14df..c0010034e436 100644 --- a/src/script/printer/relax/region.cc +++ b/src/script/printer/relax/region.cc @@ -88,7 +88,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array non_dataflow_vars; Array stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars); stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars))); - return ScopeDoc(NullOpt, Relax(d, "dataflow")->Call({}), stmts); + return ScopeDoc(std::nullopt, Relax(d, "dataflow")->Call({}), stmts); }); TVM_SCRIPT_REPR(relax::SeqExprNode, ReprPrintRelax); diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 2ae5663385f3..eafd67365dad 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -60,7 +60,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { } IdDoc var = d->Define(n, GetRef(f), n->name_hint.empty() ? "v" : n->name_hint); var->source_paths.push_back(n_p); - f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), NullOpt)); + f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), std::nullopt)); } if (Optional doc = d->GetVarDoc(n)) { return doc.value(); @@ -110,7 +110,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "under relax's dispatch token"; if (!f->module_alias_printed) { // If the module_alias is not defined before, define it. - f->stmts.push_back(AssignDoc(IdDoc(d->cfg->module_alias), doc.value(), NullOpt)); + f->stmts.push_back(AssignDoc(IdDoc(d->cfg->module_alias), doc.value(), std::nullopt)); f->module_alias_printed = true; } return IdDoc(d->cfg->module_alias); diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc index 9b26a942be82..3d7abe821745 100644 --- a/src/script/printer/relax/type.cc +++ b/src/script/printer/relax/type.cc @@ -82,7 +82,7 @@ TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::TensorTypeNode, ReprPrintRelax); TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax); -TVM_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); +TVM_FFI_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax); } // namespace printer } // namespace script diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 989e9a63b1d9..e28fd9c8036b 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -82,7 +82,7 @@ inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsif inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& v_p, const IRDocsifier& d, const Optional& rhs) { if (!v->struct_info_.defined()) { - return NullOpt; + return std::nullopt; } bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info; @@ -94,7 +94,7 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& } } if (attempt_to_hide_struct_info) { - Optional inferred_sinfo = NullOpt; + Optional inferred_sinfo = std::nullopt; if (auto opt = rhs.as()) { auto call = opt.value(); if (auto opt = call->op.as()) { @@ -103,10 +103,10 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& static auto op_map_infer_struct_info = Op::GetAttrMap("FInferStructInfo"); - auto temp_builder = relax::BlockBuilder::Create(NullOpt); + auto temp_builder = relax::BlockBuilder::Create(std::nullopt); inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder); } else if (auto opt = call->op.as()) { - auto temp_builder = relax::BlockBuilder::Create(NullOpt); + auto temp_builder = relax::BlockBuilder::Create(std::nullopt); inferred_sinfo = DeriveCallRetStructInfo(opt.value(), call, temp_builder, temp_builder->GetAnalyzer()); } @@ -125,7 +125,7 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& } if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) { - return NullOpt; + return std::nullopt; } } return d->AsDoc(v->struct_info_, v_p->Attr("struct_info_")); diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 178ed4fe75fa..519bc9d66ca6 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -99,7 +99,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // } else { rhs = rhs->Call({dom}); } - (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, NullOpt)); + (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, std::nullopt)); }; auto print_remapped_iter_var = [&]() { @@ -129,10 +129,10 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R"; } ExprDoc rhs = TIR(d, "axis")->Attr("remap"); - ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt); + ExprDoc binding_str = LiteralDoc::Str(binding_type, std::nullopt); binding_str->source_paths = std::move(binding_paths); rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)}); - (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt)); + (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, std::nullopt)); remap_vars_indices.clear(); } }; @@ -182,7 +182,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // IdDoc lhs = DefineBuffer(buffer, *frame, d); ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d, BufferVarDefinition::DataPointer); - (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + (*frame)->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } // Step 6. Handle `match_buffer` for (int i = 0, n = block->match_buffers.size(); i < n; ++i) { @@ -196,7 +196,8 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // tir::Stmt init = block->init.value(); With init_frame(d, init); AsDocBody(init, block_p->Attr("init"), init_frame->get(), d); - (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR(d, "init")->Call({}), (*init_frame)->stmts)); + (*frame)->stmts.push_back( + ScopeDoc(std::nullopt, TIR(d, "init")->Call({}), (*init_frame)->stmts)); } // Step 8. Handle block body AsDocBody(block->body, block_p->Attr("body"), frame->get(), d); @@ -204,9 +205,9 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // Array kwargs_values; if (!realize) { kwargs_keys.push_back("no_realize"); - kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt)); + kwargs_values.push_back(LiteralDoc::Boolean(true, std::nullopt)); } - return ScopeDoc(NullOpt, + return ScopeDoc(std::nullopt, TIR(d, "block") // ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))}, kwargs_keys, kwargs_values), @@ -225,7 +226,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Block block, ObjectPath p, IRDocsifier d) -> Doc { - return PrintBlock(d, block, p, NullOpt, NullOpt); + return PrintBlock(d, block, p, std::nullopt, std::nullopt); }); TVM_SCRIPT_REPR(tir::BlockNode, ReprPrintTIR); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 18c7afe50454..0427c359049b 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -119,7 +119,7 @@ Map BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, if (is_new_var(e)) { if (try_inline_def(e, e_p, [=]() { return d->AsDoc(buffer, buffer_p) - ->Attr("strides")[{LiteralDoc::Int(i, NullOpt)}]; + ->Attr("strides")[{LiteralDoc::Int(i, std::nullopt)}]; })) { results.push_back(LiteralDoc::Str(Downcast(e)->name_hint, e_p)); continue; @@ -175,9 +175,9 @@ Map BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, d->AsDoc(buffer->axis_separators, buffer_p->Attr("axis_separators"))); } if (var_def_lhs.size() == 1) { - frame->stmts.push_back(AssignDoc(var_def_lhs[0], var_def_rhs[0], NullOpt)); + frame->stmts.push_back(AssignDoc(var_def_lhs[0], var_def_rhs[0], std::nullopt)); } else if (var_def_lhs.size() > 1) { - frame->stmts.push_back(AssignDoc(TupleDoc(var_def_lhs), TupleDoc(var_def_rhs), NullOpt)); + frame->stmts.push_back(AssignDoc(TupleDoc(var_def_lhs), TupleDoc(var_def_rhs), std::nullopt)); } return kwargs; } @@ -231,7 +231,7 @@ Array BufferIndices(const Array& indices, const ObjectPath& p, ramp_p->Attr("base")); ExprDoc stop = d->AsDoc(ramp->base + ramp->lanes * ramp->stride, // ramp_p->Attr("lanes")); - Optional step = NullOpt; + Optional step = std::nullopt; if (stride->value != 1) { step = d->AsDoc(ramp->stride, ramp_p->Attr("stride")); } @@ -256,7 +256,7 @@ Array BufferSlices(const Array& region, const ObjectPath& p, const I indices.push_back(min); } else { ExprDoc max = d->AsDoc(range->min + range->extent, range_p->Attr("extent")); - indices.push_back(SliceDoc(min, max, NullOpt)); + indices.push_back(SliceDoc(min, max, std::nullopt)); } } return indices; @@ -285,7 +285,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return AssignDoc( /*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)], - /*rhs=*/value, NullOpt); + /*rhs=*/value, std::nullopt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -310,7 +310,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); ExprDoc rhs = BufferDecl(buffer, "Buffer", {}, p, opt_f.value(), d, BufferVarDefinition::DataPointer); - opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } } if (Optional doc = d->GetVarDoc(buffer)) { @@ -328,7 +328,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc src_buffer = d->AsDoc(stmt->source, p->Attr("source")); ExprDoc rhs = BufferDecl(stmt->buffer, "match_buffer", {src_buffer}, p->Attr("buffer"), d->frames.back(), d, BufferVarDefinition::MatchBuffer); - return AssignDoc(lhs, rhs, NullOpt); + return AssignDoc(lhs, rhs, std::nullopt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -343,7 +343,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](tir::ProducerStore store, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc prefix = IdDoc(store->producer->GetNameHint()); prefix = prefix[BufferIndices(store->indices, p->Attr("indices"), d)]; - return AssignDoc(prefix, d->AsDoc(store->value, p->Attr("value")), NullOpt); + return AssignDoc(prefix, d->AsDoc(store->value, p->Attr("value")), std::nullopt); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -355,7 +355,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ->Call({prefix, d->AsDoc(stmt->condition, p->Attr("condition"))}); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - return ScopeDoc(NullOpt, prefix, (*f)->stmts); + return ScopeDoc(std::nullopt, prefix, (*f)->stmts); }); TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index b81c8ef5af36..549247449e33 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -33,7 +33,7 @@ ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const IRD if (var->IsInstance()) { kwargs_keys.push_back("is_size_var"); - kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt)); + kwargs_values.push_back(LiteralDoc::Boolean(true, std::nullopt)); } if (const auto* ptr_type = type.as()) { @@ -65,7 +65,7 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) if (Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); ExprDoc rhs = PrintVarCreation(var, var_p, d); - opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } else { LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint; } diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 107521c94791..0df53c481f0c 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -65,10 +65,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 3. If not `T.grid`, print loop kind accordingly ExprDoc lhs = DefineVar(loop->loop_var, *f, d); - Optional min = NullOpt; - Optional max = NullOpt; - Optional annotations = NullOpt; - Optional thread = NullOpt; + Optional min = std::nullopt; + Optional max = std::nullopt; + Optional annotations = std::nullopt; + Optional thread = std::nullopt; if (tir::is_zero(loop->min)) { max = d->AsDoc(loop->extent, loop_p->Attr("extent")); } else { diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 61b1c54cf5dd..10f7bd74520b 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -95,13 +95,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc annotation = BufferAttn(buffer, buffer_p, *f, d); - args.push_back(AssignDoc(lhs, NullOpt, annotation)); + args.push_back(AssignDoc(lhs, std::nullopt, annotation)); buffer_inlined.insert(buffer.get()); continue; } } ExprDoc a = d->AsDoc(var->type_annotation, var_p->Attr("type_annotation")); - args.push_back(AssignDoc(DefineVar(var, *f, d), NullOpt, a)); + args.push_back(AssignDoc(DefineVar(var, *f, d), std::nullopt, a)); } // Step 2. Handle `func->attrs` if (func->attrs.defined() && !func->attrs->dict.empty()) { @@ -138,7 +138,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *f, d, BufferVarDefinition::MatchBuffer); - (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + (*f)->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } } // Step 4. Handle `func->body` @@ -159,7 +159,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } } - return NullOpt; + return std::nullopt; }(); if (d->cfg->syntax_sugar && implicit_root_block) { tir::Block root_block = implicit_root_block.value(); @@ -172,13 +172,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *f, d, BufferVarDefinition::DataPointer); - (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + (*f)->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } AsDocBody(root_block->body, root_block_p->Attr("body"), f->get(), d); } else { AsDocBody(func->body, p->Attr("body"), f->get(), d); } - Optional ret_type = NullOpt; + Optional ret_type = std::nullopt; if (func->ret_type.defined()) { const auto* as_tuple = func->ret_type.as(); if (!as_tuple || as_tuple->fields.size()) { diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index b7ba456dc2b5..1d310c2a5a9f 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -27,7 +27,7 @@ Doc DoConciseScoping(const Optional& lhs, const ExprDoc& rhs, Arrayinsert(stmts->begin(), AssignDoc(lhs.value(), rhs, NullOpt)); + stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, std::nullopt)); } else { stmts->insert(stmts->begin(), ExprStmtDoc(rhs)); } @@ -66,14 +66,14 @@ bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IR Optional FindReturnValue(const tir::Stmt& node) { auto eval = node.as(); - if (!eval) return NullOpt; + if (!eval) return std::nullopt; auto call = eval->value.as(); - if (!call) return NullOpt; + if (!call) return std::nullopt; - if (!call->op.same_as(tir::builtin::ret())) return NullOpt; + if (!call->op.same_as(tir::builtin::ret())) return std::nullopt; - if (call->args.size() != 1) return NullOpt; + if (call->args.size() != 1) return std::nullopt; return call->args[0]; } @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) p->Attr("var")->Attr("type_annotation")); if (const auto* tuple_type = stmt->var->type_annotation.as()) { if (tuple_type->fields.empty()) { - type_doc = NullOpt; + type_doc = std::nullopt; } } // Step 2. RHS @@ -119,7 +119,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) AsDocBody(stmt->body, p->Attr("body"), f->get(), d); // Step 4. Dispatch if (var_defined) { - return ScopeDoc(NullOpt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts); + return ScopeDoc(std::nullopt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts); } else if (concise) { stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); return StmtBlockDoc(*stmts); @@ -143,7 +143,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) stmts->insert(stmts->begin(), AssertDoc(cond, msg)); return StmtBlockDoc(*stmts); } - return ScopeDoc(NullOpt, TIR(d, "Assert")->Call({cond, msg}), (*f)->stmts); + return ScopeDoc(std::nullopt, TIR(d, "Assert")->Call({cond, msg}), (*f)->stmts); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -275,9 +275,9 @@ ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { runtime::DataType dtype = arr.DataType(); for (int i = 0; i < tot_dim; i++) { if (dtype.is_float()) { - result.push_back(LiteralDoc::Float(data_ptr[i], NullOpt)); + result.push_back(LiteralDoc::Float(data_ptr[i], std::nullopt)); } else { - result.push_back(LiteralDoc::Int(data_ptr[i], NullOpt)); + result.push_back(LiteralDoc::Int(data_ptr[i], std::nullopt)); } if (i == NUM_PRINT) { break; @@ -354,7 +354,7 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, OptionalAsDoc(range->min, range_p->Attr("min")), d->AsDoc(range->min + range->extent, range_p->Attr("extent")), // - NullOpt)); + std::nullopt)); } buffer = buffer[bounds]; } @@ -379,7 +379,7 @@ void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p, ->Call({LiteralDoc::Str(iter_var->thread_tag, // iter_var_p->Attr("thread_tag"))}); ExprDoc lhs = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); - f->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); + f->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p, @@ -408,19 +408,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); - ExprDoc rhs = DocsifyBufferRealize(stmt.get(), NullOpt, p, d); + ExprDoc rhs = DocsifyBufferRealize(stmt.get(), std::nullopt, p, d); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - return DoConciseScoping(NullOpt, rhs, &(*f)->stmts, concise); + return DoConciseScoping(std::nullopt, rhs, &(*f)->stmts, concise); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); - Optional lhs = NullOpt; - Optional rhs = NullOpt; - Optional define_var = NullOpt; + Optional lhs = std::nullopt; + Optional rhs = std::nullopt; + Optional define_var = std::nullopt; tir::Stmt body = stmt->body; ObjectPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 7bdba2b1c6d5..d1bc56d13960 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -139,7 +139,7 @@ inline void AsDocBody(const tir::Stmt& stmt, ObjectPath p, TIRFrameNode* f, cons */ inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) { if (!d->common_prefix.count(var.get())) { - return NullOpt; + return std::nullopt; } int n_frames = d->frames.size(); std::unordered_map tir_to_frame; @@ -163,7 +163,7 @@ inline Optional FindLowestVarDef(const ObjectRef& var, const IRDocsifier& if (fallback_frame != nullptr) { return GetRef(fallback_frame); } - return NullOpt; + return std::nullopt; } /*! \brief Redirected method for the ReprPrinter */ diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 9ca9e87c1f14..03341c4cd90f 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -40,7 +40,7 @@ namespace printer { inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { try { - p->stream << TVMScriptPrinter::Script(obj, NullOpt); + p->stream << TVMScriptPrinter::Script(obj, std::nullopt); } catch (const tvm::Error& e) { if (ReprLegacyPrinter::CanDispatch(obj)) { LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" @@ -148,7 +148,8 @@ inline bool HasMultipleLines(const std::string& str) { } inline Optional GetBindingName(const IRDocsifier& d) { - return d->cfg->binding_names.empty() ? Optional(NullOpt) : d->cfg->binding_names.back(); + return d->cfg->binding_names.empty() ? Optional(std::nullopt) + : d->cfg->binding_names.back(); } inline Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f) { @@ -158,7 +159,7 @@ inline Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f if (Optional sym = f->GetAttr(tvm::attr::kGlobalSymbol)) { return sym.value(); } - return NullOpt; + return std::nullopt; } inline String GenerateUniqueName(std::string name_hint, diff --git a/src/support/array.h b/src/support/array.h index 6fd30503f016..f49439aeb3ff 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -18,8 +18,8 @@ */ #ifndef TVM_SUPPORT_ARRAY_H_ #define TVM_SUPPORT_ARRAY_H_ +#include #include -#include #include #include @@ -70,7 +70,7 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector } /*! - * \brief Convert a tvm::runtime::Array to std::vector + * \brief Convert a tvm::Array to std::vector * \tparam TSrc The type of elements in the source Array * \tparam TDst The type of elements in the result vector * \return The result vector @@ -79,7 +79,7 @@ template inline std::vector AsVector(const Array& vec); /*! - * \brief Convert a std::vector to tvm::runtime::Array + * \brief Convert a std::vector to tvm::Array * \tparam TSrc The type of elements in the source vector * \tparam TDst The type of elements in the result Array * \return The result Array @@ -88,7 +88,7 @@ template inline Array AsArray(const std::vector& vec); /*! - * \brief Convert a tvm::runtime::Array to std::list + * \brief Convert a tvm::Array to std::list * \tparam T The type of elements in the source array * \return The result list */ @@ -100,7 +100,7 @@ inline std::list AsList(const Array& array) { } /*! - * \brief Convert a std::list to tvm::runtime::Array + * \brief Convert a std::list to tvm::Array * \tparam T The type of elements in the source list * \return The result list */ @@ -116,10 +116,10 @@ inline Array AsArray(const std::list& list) { * \param shape The shape tuple * \return An array of the shape tuple */ -inline Array AsArray(const ShapeTuple& shape) { +inline Array AsArray(const ffi::Shape& shape) { Array result; result.reserve(shape->size); - for (ShapeTuple::index_type i : shape) { + for (ffi::Shape::index_type i : shape) { result.push_back(Integer(i)); } return result; diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 4482272ced53..d0d9a35db83e 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -21,11 +21,11 @@ * FFI registration code used for frontend testing purposes. * \file ffi_testing.cc */ +#include +#include #include #include -#include #include -#include #include #include @@ -53,13 +53,22 @@ struct TestAttrs : public AttrsNode { TVM_REGISTER_NODE_TYPE(TestAttrs); -TVM_REGISTER_GLOBAL("testing.test_wrap_callback") +TVM_FFI_REGISTER_GLOBAL("testing.GetShapeSize").set_body_typed([](ffi::Shape shape) { + return static_cast(shape.size()); +}); + +TVM_FFI_REGISTER_GLOBAL("testing.GetShapeElem").set_body_typed([](ffi::Shape shape, int idx) { + ICHECK_LT(idx, shape.size()); + return shape[idx]; +}); + +TVM_FFI_REGISTER_GLOBAL("testing.test_wrap_callback") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ffi::Function pf = args[0].cast(); *ret = ffi::TypedFunction([pf]() { pf(); }); }); -TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") +TVM_FFI_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ffi::Function pf = args[0].cast(); auto result = ffi::TypedFunction([pf]() { @@ -71,22 +80,23 @@ TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err") *ret = result; }); -TVM_REGISTER_GLOBAL("testing.test_check_eq_callback") +TVM_FFI_REGISTER_GLOBAL("testing.test_check_eq_callback") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { auto msg = args[0].cast(); *ret = ffi::TypedFunction([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); }); -TVM_REGISTER_GLOBAL("testing.device_test").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto dev = args[0].cast(); - int dtype = args[1].cast(); - int did = args[2].cast(); - CHECK_EQ(static_cast(dev.device_type), dtype); - CHECK_EQ(static_cast(dev.device_id), did); - *ret = dev; -}); +TVM_FFI_REGISTER_GLOBAL("testing.device_test") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + auto dev = args[0].cast(); + int dtype = args[1].cast(); + int did = args[2].cast(); + CHECK_EQ(static_cast(dev.device_type), dtype); + CHECK_EQ(static_cast(dev.device_id), did); + *ret = dev; + }); -TVM_REGISTER_GLOBAL("testing.identity_cpp") +TVM_FFI_REGISTER_GLOBAL("testing.identity_cpp") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py"); ICHECK(identity_func.has_value()) @@ -105,7 +115,7 @@ void ErrorTest(int x, int y) { } } -TVM_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); +TVM_FFI_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); class FrontendTestModuleNode : public runtime::ModuleNode { public: @@ -145,22 +155,23 @@ runtime::Module NewFrontendTestModule() { return runtime::Module(n); } -TVM_REGISTER_GLOBAL("testing.FrontendTestModule").set_body_typed(NewFrontendTestModule); +TVM_FFI_REGISTER_GLOBAL("testing.FrontendTestModule").set_body_typed(NewFrontendTestModule); -TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { +TVM_FFI_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { std::chrono::duration duration(static_cast(timeout * 1e9)); std::this_thread::sleep_for(duration); }); -TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant { - if (x % 2 == 0) { - return IntImm(DataType::Int(64), x / 2); - } else { - return String("argument was odd"); - } -}); +TVM_FFI_REGISTER_GLOBAL("testing.ReturnsVariant") + .set_body_typed([](int x) -> Variant { + if (x % 2 == 0) { + return IntImm(DataType::Int(64), x / 2); + } else { + return String("argument was odd"); + } + }); -TVM_REGISTER_GLOBAL("testing.AcceptsVariant") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { if (auto opt_str = arg.as()) { return opt_str.value()->GetTypeKey(); @@ -169,25 +180,25 @@ TVM_REGISTER_GLOBAL("testing.AcceptsVariant") } }); -TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); -TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); -TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray").set_body_typed([](Array arg) -> Any { +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsObjectRefArray").set_body_typed([](Array arg) -> Any { return arg[0]; }); -TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") .set_body_typed([](Map map, Any key) -> Any { return map[key]; }); -TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") .set_body_typed([](Map map) -> ObjectRef { return map; }); -TVM_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { return expr; }); -TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") .set_body_typed([](Array arr) -> ObjectRef { for (ObjectRef item : arr) { CHECK(item->IsInstance()) @@ -196,7 +207,7 @@ TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") return arr; }); -TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") .set_body_typed([](Array> arr) -> ObjectRef { for (auto item : arr) { CHECK(item.as() || item.as()) @@ -205,7 +216,7 @@ TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") return arr; }); -TVM_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") +TVM_FFI_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") .set_body_typed([](Map map) -> ObjectRef { for (const auto& kv : map) { ObjectRef value = kv.second; @@ -254,19 +265,21 @@ class TestingEventLogger { std::vector entries_; }; -TVM_REGISTER_GLOBAL("testing.record_event").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (args.size() != 0 && args[0].as()) { - TestingEventLogger::ThreadLocal()->Record(args[0].cast()); - } else { - TestingEventLogger::ThreadLocal()->Record("X"); - } -}); +TVM_FFI_REGISTER_GLOBAL("testing.record_event") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + if (args.size() != 0 && args[0].try_cast()) { + TestingEventLogger::ThreadLocal()->Record(args[0].cast()); + } else { + TestingEventLogger::ThreadLocal()->Record("X"); + } + }); -TVM_REGISTER_GLOBAL("testing.reset_events").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - TestingEventLogger::ThreadLocal()->Reset(); -}); +TVM_FFI_REGISTER_GLOBAL("testing.reset_events") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + TestingEventLogger::ThreadLocal()->Reset(); + }); -TVM_REGISTER_GLOBAL("testing.dump_events").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("testing.dump_events").set_body_typed([]() { TestingEventLogger::ThreadLocal()->Dump(); }); } // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index e78cfdd016f1..01b49bb92e79 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -17,9 +17,9 @@ * under the License. */ #include +#include #include #include -#include #ifndef TVM_INFO_GIT_COMMIT_HASH #define TVM_INFO_GIT_COMMIT_HASH "NOT-FOUND" @@ -367,6 +367,6 @@ TVM_DLL ffi::Map GetLibInfo() { return result; } -TVM_REGISTER_GLOBAL("support.GetLibInfo").set_body_typed(GetLibInfo); +TVM_FFI_REGISTER_GLOBAL("support.GetLibInfo").set_body_typed(GetLibInfo); } // namespace tvm diff --git a/src/support/socket.h b/src/support/socket.h index e3972488d4b8..e9e2f87f9dbf 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -47,8 +47,9 @@ #include #include #endif + +#include #include -#include #include #include diff --git a/src/support/utils.h b/src/support/utils.h index c855ae736ade..eb0d4b9a8827 100644 --- a/src/support/utils.h +++ b/src/support/utils.h @@ -32,7 +32,7 @@ #endif // __hexagon__ #endif // _WIN32 -#include +#include #include #include diff --git a/src/target/build_common.h b/src/target/build_common.h index 7c9ad8cb3c68..70f15d091ed2 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -24,8 +24,8 @@ #ifndef TVM_TARGET_BUILD_COMMON_H_ #define TVM_TARGET_BUILD_COMMON_H_ +#include #include -#include #include #include #include diff --git a/src/target/codegen.cc b/src/target/codegen.cc index d7db5cc8e9e2..8ddc071cba0f 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -22,10 +22,10 @@ * \brief Common utilities to generated C style code. */ #include +#include #include -#include +#include #include -#include #include #include #include @@ -361,17 +361,17 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, .cast(); } -TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); +TVM_FFI_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export a few auxiliary function to the runtime namespace. -TVM_REGISTER_GLOBAL("runtime.ModuleImportsBlobName").set_body_typed([]() -> std::string { +TVM_FFI_REGISTER_GLOBAL("runtime.ModuleImportsBlobName").set_body_typed([]() -> std::string { return runtime::symbol::tvm_dev_mblob; }); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") +TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") .set_body_typed([](const runtime::Module& mod) { std::string buffer = PackImportsToBytes(mod); - ShapeTuple::index_type size = buffer.size(); + ffi::Shape::index_type size = buffer.size(); DLDataType uchar; uchar.code = kDLUInt; uchar.bits = 8; @@ -384,8 +384,8 @@ TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToNDArray") return array; }); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); +TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); +TVM_FFI_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); } // namespace codegen } // namespace tvm diff --git a/src/target/datatype/myfloat/myfloat.cc b/src/target/datatype/myfloat/myfloat.cc index c0c2fffa03da..afee8a7c4bf0 100644 --- a/src/target/datatype/myfloat/myfloat.cc +++ b/src/target/datatype/myfloat/myfloat.cc @@ -26,7 +26,7 @@ * * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist? */ -#include +#include #include #include diff --git a/src/target/datatype/posit/posit-wrapper.cc b/src/target/datatype/posit/posit-wrapper.cc index 700c5cb9dbe9..bb2af37ec921 100644 --- a/src/target/datatype/posit/posit-wrapper.cc +++ b/src/target/datatype/posit/posit-wrapper.cc @@ -28,7 +28,7 @@ * * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist? */ -#include +#include #include diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 79065d0024c5..88f96b6a707b 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -18,7 +18,8 @@ */ #include "registry.h" -#include +#include +#include namespace tvm { namespace datatype { @@ -26,23 +27,23 @@ namespace datatype { using ffi::Any; using ffi::PackedArgs; -TVM_REGISTER_GLOBAL("dtype.register_custom_type") +TVM_FFI_REGISTER_GLOBAL("dtype.register_custom_type") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { datatype::Registry::Global()->Register(args[0].cast(), static_cast(args[1].cast())); }); -TVM_REGISTER_GLOBAL("dtype.get_custom_type_code") +TVM_FFI_REGISTER_GLOBAL("dtype.get_custom_type_code") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = datatype::Registry::Global()->GetTypeCode(args[0].cast()); }); -TVM_REGISTER_GLOBAL("dtype.get_custom_type_name") +TVM_FFI_REGISTER_GLOBAL("dtype.get_custom_type_name") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = Registry::Global()->GetTypeName(args[0].cast()); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered") +TVM_FFI_REGISTER_GLOBAL("runtime._datatype_get_type_registered") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { *ret = Registry::Global()->GetTypeRegistered(args[0].cast()); }); diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index 46b189880f64..b1a1a4a7f50b 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -20,8 +20,7 @@ #ifndef TVM_TARGET_DATATYPE_REGISTRY_H_ #define TVM_TARGET_DATATYPE_REGISTRY_H_ -#include -#include +#include #include #include @@ -38,7 +37,7 @@ namespace datatype { * directly---see the TVM globals registered in the corresponding .cc file. * Currently, user should manually choose a type name and a type code, * ensuring that neither conflict with existing types. - * 2. Use TVM_REGISTER_GLOBAL to register the lowering functions needed to + * 2. Use TVM_FFI_REGISTER_GLOBAL to register the lowering functions needed to * lower the custom datatype. In general, these will look like: * For Casts: tvm.datatype.lower..Cast.. * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index ea8ccd98b1af..ac45476f7702 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -24,7 +24,7 @@ #ifndef TVM_TARGET_INTRIN_RULE_H_ #define TVM_TARGET_INTRIN_RULE_H_ -#include +#include #include #include diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 1399fc083a08..9d968cdb6478 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -25,7 +25,7 @@ #include #include -#include +#include #include "../../arith/scalable_expression.h" #include "codegen_cpu.h" @@ -57,8 +57,8 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) { #if TVM_LLVM_VERSION >= 130 // Add vscale_range() function attribute when appropriate. if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) { - unsigned int max_val = - *std::max_element(arith::kAArch64VScaleValues.begin(), arith::kAArch64VScaleValues.end()); + auto kVScaleValues = arith::GetVScaleValues(Target::Current()); + unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); func->addFnAttr( llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val)); } @@ -106,7 +106,7 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { this->VisitStmt(op->body); } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_aarch64") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAArch64()); }); diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index a1cff52beb2a..048c4160b118 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -47,9 +47,9 @@ #endif #include #include -#include +#include +#include #include -#include #include "../../runtime/rocm/rocm_module.h" #include "../build_common.h" @@ -273,13 +273,13 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { #endif auto cg = std::make_unique(); - cg->Init("TVMAMDGPUModule", llvm_target.get(), NullOpt, false, false); + cg->Init("TVMAMDGPUModule", llvm_target.get(), std::nullopt, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); auto fbitcode = tvm::ffi::Function::GetGlobalRequired("tvm_callback_rocm_bitcode_path"); - auto bitcode_files = fbitcode().cast>(); + auto bitcode_files = fbitcode().cast>(); for (auto& bitcode_path : bitcode_files) { std::unique_ptr mlib = llvm_instance.LoadIR(bitcode_path); @@ -356,9 +356,9 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); +TVM_FFI_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenAMDGPU()); }); diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 3abebec2a36e..03ef982d1308 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -24,7 +24,7 @@ #ifdef TVM_LLVM_VERSION #include -#include +#include #if TVM_LLVM_VERSION >= 100 #include #endif @@ -132,7 +132,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenARM()); }); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 0a5223ae029b..bfbd65e524fb 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -49,7 +49,7 @@ #include #include #include -#include +#include #include #include @@ -75,12 +75,10 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, CodeGenLLVM::Init(module_name, llvm_target, system_lib_prefix, dynamic_lookup, target_c_runtime); system_lib_prefix_ = system_lib_prefix; dbg_info_ = CreateDebugInfo(module_.get()); - static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); func_handle_map_.clear(); export_system_symbols_.clear(); // Runtime types. - t_tvm_shape_index_ = llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits()); // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: @@ -89,7 +87,7 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: // typedef struct { uint8_t code; uint8_t bits; uint16_t lanes; } DLDataType; t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); - // Defined in include/tvm/runtime/c_runtime_api.h: + // Defined in include/tvm/runtime/base.h: // typedef void* TVMFunctionHandle; t_tvm_func_handle_ = t_void_p_; // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: @@ -1158,7 +1156,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { } } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenCPU()); }); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 22708c61178a..baf7497bc0d1 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -499,7 +499,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { } } - cg->Init("TVMHexagonModule", llvm_target.get(), NullOpt, false, false); + cg->Init("TVMHexagonModule", llvm_target.get(), std::nullopt, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); @@ -589,9 +589,9 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str); } -TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); +TVM_FFI_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_hexagon") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenHexagon()); }); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f77e32efd587..634c9c2b57a5 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -90,7 +90,7 @@ #include #include #include -#include +#include #include #include @@ -2325,19 +2325,18 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) return nullptr; } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetDefaultTargetTriple").set_body_typed([]() -> std::string { - return llvm::sys::getDefaultTargetTriple(); -}); +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetDefaultTargetTriple") + .set_body_typed([]() -> std::string { return llvm::sys::getDefaultTargetTriple(); }); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetProcessTriple").set_body_typed([]() -> std::string { +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetProcessTriple").set_body_typed([]() -> std::string { return llvm::sys::getProcessTriple(); }); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUName").set_body_typed([]() -> std::string { +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUName").set_body_typed([]() -> std::string { return llvm::sys::getHostCPUName().str(); }); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.GetHostCPUFeatures") .set_body_typed([]() -> Map { #if TVM_LLVM_VERSION >= 190 Map ret; diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 530d1772df31..f7e4e819030e 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -117,7 +117,7 @@ class CodeGenLLVM : public ExprFunctor, * \param module_name The name of the module. * \param tm Target machine model * \param ctx The context. - * \param system_lib_prefix If the value is not NullOpt, insert system lib registration. + * \param system_lib_prefix If the value is not std::nullopt, insert system lib registration. * The value corresponds to the prefix of the system lib symbols. * \param dynamic_lookup Whether dynamically lookup runtime function * or use the runtime function table passed by caller. diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index fa035367408c..a0ffb5a1ce10 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -324,7 +324,7 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { int compute_ver = GetCUDAComputeVersion(target); auto cg = std::make_unique(); - cg->Init("TVMPTXModule", llvm_target.get(), NullOpt, false, false); + cg->Init("TVMPTXModule", llvm_target.get(), std::nullopt, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); @@ -368,9 +368,9 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); +TVM_FFI_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenNVPTX()); }); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 954d4e7efd56..435b453d49ba 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -30,7 +30,7 @@ #include #endif #include -#include +#include #include #include @@ -132,7 +132,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems); } -TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") +TVM_FFI_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") .set_body_packed([](const ffi::PackedArgs& targs, ffi::Any* rv) { *rv = static_cast(new CodeGenX86_64()); }); diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 2730c0a34d63..e519c9eef397 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -25,9 +25,13 @@ #include "intrin_rule_llvm.h" #include +#define _USE_MATH_DEFINES +#include #include #include +#include + #include "../intrin_rule.h" namespace tvm { @@ -160,6 +164,84 @@ TVM_REGISTER_OP("tir.sinh") return ret; }); +TVM_REGISTER_OP("tir.asin") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr x2 = x * x; + PrimExpr term1 = x; + PrimExpr term3 = term1 * x2 / make_const(x.dtype(), 6); + PrimExpr term5 = term3 * x2 * make_const(x.dtype(), 9) / make_const(x.dtype(), 40); + PrimExpr term7 = term5 * x2 * make_const(x.dtype(), 25) / make_const(x.dtype(), 112); + PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456); + PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160); + PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11; + /* --- domain limit check --- */ + PrimExpr lower = make_const(x.dtype(), -1.0); + PrimExpr upper = make_const(x.dtype(), 1.0); + PrimExpr out_range = tir::Or(x upper); + // Use a quiet NaN constant + PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); + // select: if out of [-1,1] → NaN, else → series + return tir::Select(out_range, nan_const, series); + }); + +TVM_REGISTER_OP("tir.acos") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in acos legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr half_pi = make_const(x.dtype(), M_PI / 2); + PrimExpr asin_x = asin(x); + return half_pi - asin_x; + }); + +TVM_REGISTER_OP("tir.atan") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in atan legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + PrimExpr denom = sqrt(x * x + one); + return asin(x / denom); + }); + +TVM_REGISTER_OP("tir.asinh") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + PrimExpr sqrt_val = sqrt(x * x + one); + return log(x + sqrt_val); + }); + +TVM_REGISTER_OP("tir.acosh") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + PrimExpr sqrt_val = sqrt(x * x - one); + return log(x + sqrt_val); + }); + +TVM_REGISTER_OP("tir.atanh") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr) << "Invalid call node in atanh legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1.0); + return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5); + }); + TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index a0e040a2048e..4b64e92127d3 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -26,7 +26,7 @@ #ifdef TVM_LLVM_VERSION -#include +#include #include #include #include diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index e7be40fb9041..48fc64172215 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -22,7 +22,7 @@ */ #ifdef TVM_LLVM_VERSION -#include +#include #include #include #include diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index c80d8388da9c..30afcee92acc 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -23,7 +23,7 @@ #ifdef TVM_LLVM_VERSION #include -#include +#include #include #include #include diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index b6191f5e854c..ebf9a754b101 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -53,22 +53,10 @@ #include #include #include -#if TVM_LLVM_VERSION >= 190 -#include -#else -#if TVM_LLVM_VERSION >= 140 -#include -#endif -#endif -#if TVM_LLVM_VERSION >= 160 -#include -#else -#include -#endif -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include @@ -299,34 +287,25 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // code model code_model_ = llvm::CodeModel::Medium; #if TVM_LLVM_VERSION >= 140 - // VLEN inference - const auto cpu_name = GetOrCreateTargetMachine(false)->getMCSubtargetInfo()->getCPU(); - const auto canon_arch = llvm::RISCV::getMArchFromMcpu(cpu_name); - auto ISAInfo = - llvm::RISCVISAInfo::parseArchString(canon_arch, /*EnableExperimentalExtensions=*/true); - // infer VLEN from LLVM RISCVInfo parser - if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) { - vector_width_ = (*ISAInfo)->getMinVLen(); - } - // infer VLEN from LLVM options (zvlXXXb override) - for (const auto& attr : attrs_) { - if (attr.find("zvl") != std::string::npos) { - std::string vec; - for (char c : attr) { - if (std::isdigit(c)) vec += c; + // get VLEN from the LLVM backend (zvlXXXb) + Map features = GetAllLLVMCpuFeatures(); + // check vector ISA + if (features.count("v") > 0) { + vector_width_ = 0; + int zvlbits = 0; + for (const auto& [attr, val] : features) { + if (std::string(attr).find("zvl") != std::string::npos) { + std::string vec; + for (char c : std::string(attr)) { + if (std::isdigit(c)) vec += c; + } + zvlbits = std::stoi(vec); + // max of the multiple zvlXXXb + if (vector_width_ < zvlbits) vector_width_ = zvlbits; } - vector_width_ = std::stoi(vec); } } #endif - if (vector_width_ > 0) { - // push cl-opt to LLVM - llvm_options_.push_back( - ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_))); - } else { - // fallback default (codegen will warn) - llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256")); - } } // Target options @@ -943,9 +922,7 @@ const int LLVMTargetInfo::GetVectorWidth() { } else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) { vector_width_ = 128; } else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) { - vector_width_ = 256; - LOG(WARNING) << "LLVM RVV VLEN inference failed, " - << "using 256 bits, set -vector-width=XXX to override"; + vector_width_ = 128; } else { // fallback default vector_width_ = 128; diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index 36d15b3a7715..f2468a8ef99f 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -32,10 +32,10 @@ #endif #include #include +#include +#include +#include #include -#include -#include -#include #include #include diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 396b02063b34..ed70d8692635 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -31,8 +31,10 @@ #include #include #include +#if _WIN32 #include #include +#endif #include #include #include @@ -53,14 +55,13 @@ #include #include #include +#include +#include +#include #include -#include -#include #include #include #include -#include -#include #include #include #include @@ -503,9 +504,13 @@ void LLVMModuleNode::InitORCJIT() { const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple& triple) -> std::unique_ptr { +#if _WIN32 auto GetMemMgr = []() { return std::make_unique(); }; auto ObjLinkingLayer = std::make_unique(session, std::move(GetMemMgr)); +#else + auto ObjLinkingLayer = std::make_unique(session); +#endif if (triple.isOSBinFormatCOFF()) { ObjLinkingLayer->setOverrideObjectFlagsWithResponsibilityFlags(true); ObjLinkingLayer->setAutoClaimResponsibilityForObjectSymbols(true); @@ -615,14 +620,14 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name, return nullptr; } -TVM_REGISTER_GLOBAL("target.build.llvm") +TVM_FFI_REGISTER_GLOBAL("target.build.llvm") .set_body_typed([](IRModule mod, Target target) -> runtime::Module { auto n = make_object(); n->Init(mod, target); return runtime::Module(n); }); -TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") +TVM_FFI_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { auto llvm_instance = std::make_unique(); With llvm_target(*llvm_instance, target_str); @@ -637,7 +642,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") return runtime::Module(n); }); -TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") +TVM_FFI_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") .set_body_typed([](std::string name) -> int64_t { #if TVM_LLVM_VERSION >= 200 return static_cast(llvm::Intrinsic::lookupIntrinsicID(name)); @@ -646,7 +651,7 @@ TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") #endif }); -TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t id) -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t id) -> String { #if TVM_LLVM_VERSION >= 130 return std::string(llvm::Intrinsic::getBaseName(static_cast(id))); #elif TVM_LLVM_VERSION >= 40 @@ -661,7 +666,7 @@ TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t #endif }); -TVM_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() -> String { #if TVM_LLVM_VERSION >= 120 #if defined(__i386__) || defined(_M_IX86) || defined(__x86_64__) || defined(_M_X64) using namespace llvm::sys::detail::x86; @@ -677,34 +682,35 @@ TVM_REGISTER_GLOBAL("target.llvm_get_system_x86_vendor").set_body_typed([]() -> return "unimplemented"; }); -TVM_REGISTER_GLOBAL("target.llvm_get_vector_width").set_body_typed([](const Target& target) -> int { - auto use_target = target.defined() ? target : Target::Current(false); - // ignore non "llvm" target - if (target.defined()) { - if (target->kind->name != "llvm") { - return -1; - } - } - auto llvm_instance = std::make_unique(); - LLVMTargetInfo llvm_backend(*llvm_instance, use_target); - return llvm_backend.GetVectorWidth(); -}); +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_vector_width") + .set_body_typed([](const Target& target) -> int { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return -1; + } + } + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetVectorWidth(); + }); -TVM_REGISTER_GLOBAL("target.llvm_get_system_triple").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_triple").set_body_typed([]() -> String { return llvm::sys::getDefaultTargetTriple(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_system_cpu").set_body_typed([]() -> String { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_system_cpu").set_body_typed([]() -> String { return llvm::sys::getHostCPUName().str(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { auto llvm_instance = std::make_unique(); LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); return llvm_backend.GetAllLLVMTargets(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") .set_body_typed([](const Target& target) -> Array { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -718,7 +724,7 @@ TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") return llvm_backend.GetAllLLVMTargetArches(); }); -TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features") +TVM_FFI_REGISTER_GLOBAL("target.llvm_get_cpu_features") .set_body_typed([](const Target& target) -> Map { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -732,7 +738,7 @@ TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features") return llvm_backend.GetAllLLVMCpuFeatures(); }); -TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature") +TVM_FFI_REGISTER_GLOBAL("target.llvm_cpu_has_feature") .set_body_typed([](const String feature, const Target& target) -> bool { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -748,7 +754,7 @@ TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature") return has_feature; }); -TVM_REGISTER_GLOBAL("target.target_has_feature") +TVM_FFI_REGISTER_GLOBAL("target.target_has_feature") .set_body_typed([](const String feature, const Target& target) -> bool { auto use_target = target.defined() ? target : Target::Current(false); // ignore non "llvm" target @@ -762,11 +768,11 @@ TVM_REGISTER_GLOBAL("target.target_has_feature") return llvm_target.TargetHasCPUFeature(feature); }); -TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { +TVM_FFI_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { return TVM_LLVM_VERSION / 10; }); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadfile_ll") .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module { auto n = make_object(); n->SetJITEngine("orcjit"); @@ -774,7 +780,7 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") return runtime::Module(n); }); -TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") +TVM_FFI_REGISTER_GLOBAL("codegen.llvm_target_enabled") .set_body_typed([](std::string target_str) -> bool { LLVMInstance llvm_instance; auto* tm = With(llvm_instance, target_str) @@ -782,7 +788,7 @@ TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") return tm != nullptr; }); -TVM_REGISTER_GLOBAL("codegen.codegen_blob") +TVM_FFI_REGISTER_GLOBAL("codegen.codegen_blob") .set_body_typed([](std::string data, bool system_lib, std::string llvm_target_string, std::string c_symbol_prefix) -> runtime::Module { auto n = make_object(); diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 716a0a6c7007..2070d7da3e0c 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -27,7 +27,7 @@ #ifdef TVM_LLVM_VERSION -#include +#include #include #include #include diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 58f80858d087..068f6c2f7196 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -169,7 +169,7 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); +TVM_FFI_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", String); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ad9456a1e9a5..344c0857c4d4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -338,22 +338,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << ")"; return os.str(); } else { - ICHECK_LT(kind, builtin::kTVMValueKindBound_); - std::ostringstream os; - os << "(((TVMValue*)"; - this->PrintExpr(buffer, os); - os << ")[" << index << "]."; - if (t.is_handle()) { - os << "v_handle"; - } else if (t.is_float()) { - os << "v_float64"; - } else if (t.is_int()) { - os << "v_int64"; - } else { - LOG(FATAL) << "Do not know how to handle type" << t; - } - os << ")"; - return os.str(); + TVM_FFI_THROW(RuntimeError) << "Unsupported type index: " << kind; } } diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index ef86f09ca28e..ad73fc9079e9 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -43,7 +43,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d declared_globals_.clear(); decl_stream << "// tvm target: " << target_str << "\n"; decl_stream << "#define TVM_EXPORTS\n"; - decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; + decl_stream << "#include \"tvm/runtime/base.h\"\n"; decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; decl_stream << "#include \"tvm/ffi/c_api.h\"\n"; decl_stream << "#include \n"; @@ -285,24 +285,20 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); ICHECK(num != nullptr); - static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant"); - size_t unit = sizeof(TVMValue); + static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant"); + size_t unit = sizeof(TVMFFIAny); size_t size = 0; if (type == "shape") { - size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit; - } else if (type == "arg_value") { - size = (num->value * sizeof(TVMValue) + unit - 1) / unit; + size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit; } else if (type == "tvm_ffi_any") { size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit; - } else if (type == "arg_tcode") { - size = (num->value * sizeof(int) + unit - 1) / unit; } else if (type == "array") { size = (num->value * sizeof(DLTensor) + unit - 1) / unit; } else { LOG(FATAL) << "Unknown stack alloca type " << type; } this->PrintIndent(); - this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; + this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n"; os << stack_name; } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { this->PrintCallPacked(op); @@ -408,6 +404,6 @@ runtime::Module BuildCHost(IRModule mod, Target target) { return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); } -TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost); +TVM_FFI_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6bbb388a94cb..c3014b11a5be 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -24,7 +24,7 @@ #include "codegen_cuda.h" #include -#include +#include #include #include diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index d115916def54..0f87a16c449b 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -466,6 +466,6 @@ runtime::Module BuildMetal(IRModule mod, Target target) { return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); } -TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); +TVM_FFI_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 9814696b3728..b94dc17bff33 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -672,7 +672,7 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str()); } -TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); +TVM_FFI_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { auto prototype_keys = target->GetKeys(); @@ -684,7 +684,7 @@ String DeviceScopeCompatibilityFromTarget(Target target, String memory_scope) { return memory_scope; } -TVM_REGISTER_GLOBAL("DeviceScopeCompatibility.opencl") +TVM_FFI_REGISTER_GLOBAL("DeviceScopeCompatibility.opencl") .set_body_typed(DeviceScopeCompatibilityFromTarget); } // namespace codegen diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 8d1ad91746b6..995eddee027e 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -779,7 +779,7 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { return runtime::Module(n); } -TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { +TVM_FFI_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { return BuildWebGPU(mod, target); }); diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ec3cadb8c8e4..5e1f132fb5a5 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,10 +23,9 @@ */ #include +#include #include #include -#include -#include #include #include @@ -175,7 +174,7 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_c") +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_c") .set_body_typed(CSourceModuleNode::LoadFromBinary); /*! @@ -249,9 +248,9 @@ runtime::Module DeviceSourceModuleCreate( return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); +TVM_FFI_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") +TVM_FFI_REGISTER_GLOBAL("runtime.CSourceModuleCreate") .set_body_typed([](String code, String fmt, Optional> func_names, Optional> const_vars) { return CSourceModuleCreate(code, fmt, func_names.value_or({}), const_vars.value_or({})); diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 5690ef05de5c..f3dbd624ec00 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -35,7 +35,7 @@ runtime::Module BuildSPIRV(IRModule mod, Target target) { return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), spirv_text); } -TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { +TVM_FFI_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { return BuildSPIRV(mod, target); }); diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index e5f869de1718..3010b74dd976 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -21,7 +21,7 @@ * \file intrin_rule_spirv.cc */ #include -#include +#include #include #include #include @@ -91,6 +91,39 @@ TVM_REGISTER_OP("tir.sin").set_attr("vulkan.FLowerIntrinsic", TVM_REGISTER_OP("tir.cos").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.tan").set_attr("vulkan.FLowerIntrinsic", + DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.asin") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.acos") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.atan") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.sinh") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.cosh") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.tanh") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.asinh") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.acosh") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.atanh") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + +TVM_REGISTER_OP("tir.atan2") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + TVM_REGISTER_OP("tir.log").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); @@ -103,9 +136,6 @@ TVM_REGISTER_OP("tir.sqrt") TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.tanh") - .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); - TVM_REGISTER_OP("tir.erf").set_attr("vulkan.FLowerIntrinsic", codegen::intrin ::DispatchFastErf); } // namespace intrin diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 057007be72ac..5df779c59547 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -24,7 +24,7 @@ #ifndef TVM_TARGET_SPIRV_IR_BUILDER_H_ #define TVM_TARGET_SPIRV_IR_BUILDER_H_ -#include +#include #include // clang-format off diff --git a/src/target/tag.cc b/src/target/tag.cc index 0f398d6e19b4..f6e2307b75e1 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -22,8 +22,8 @@ * \brief Target tag registry */ +#include #include -#include #include #include @@ -33,8 +33,8 @@ namespace tvm { TVM_REGISTER_NODE_TYPE(TargetTagNode); -TVM_REGISTER_GLOBAL("target.TargetTagListTags").set_body_typed(TargetTag::ListTags); -TVM_REGISTER_GLOBAL("target.TargetTagAddTag").set_body_typed(TargetTag::AddTag); +TVM_FFI_REGISTER_GLOBAL("target.TargetTagListTags").set_body_typed(TargetTag::ListTags); +TVM_FFI_REGISTER_GLOBAL("target.TargetTagAddTag").set_body_typed(TargetTag::AddTag); /********** Registry-related code **********/ @@ -47,7 +47,7 @@ TargetTagRegEntry& TargetTagRegEntry::RegisterOrGet(const String& target_tag_nam Optional TargetTag::Get(const String& target_tag_name) { const TargetTagRegEntry* reg = TargetTagRegistry::Global()->Get(target_tag_name); if (reg == nullptr) { - return NullOpt; + return std::nullopt; } return Target(reg->tag_->config); } diff --git a/src/target/target.cc b/src/target/target.cc index 7a29ca2ef537..d9e3f9b51ee7 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,10 +21,10 @@ * \file src/target/target.cc */ #include +#include #include #include #include -#include #include #include #include @@ -40,8 +40,6 @@ #include #include -#include "../runtime/object_internal.h" - namespace tvm { TVM_REGISTER_NODE_TYPE(TargetNode); @@ -76,7 +74,7 @@ class TargetInternal { static std::string Interpret(const std::string& str); static std::string Uninterpret(const std::string& str); static std::string StringifyAtomicType(const Any& obj); - static std::string StringifyArray(const ArrayObj& array); + static std::string StringifyArray(const ffi::ArrayObj& array); static constexpr char quote = '\''; static constexpr char escape = '\\'; @@ -111,7 +109,7 @@ static std::vector DeduplicateKeys(const std::vector& keys) { template static T ObjTypeCheck(const Any& obj, const std::string& expected_type) { - auto opt = obj.as(); + auto opt = obj.try_cast(); if (!opt.has_value()) { TVM_FFI_THROW(TypeError) << "Expects type \"" << expected_type << "\", but gets \"" << obj.GetTypeKey() << "\" for object: " << obj; @@ -394,7 +392,7 @@ Any TargetInternal::ParseType(const std::string& str, const TargetKindNode::Valu } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target return Target(TargetInternal::FromString(interp_str)); - } else if (info.type_index == ArrayObj::RuntimeTypeIndex()) { + } else if (info.type_index == ffi::ArrayObj::RuntimeTypeIndex()) { // Parsing array std::vector result; for (const std::string& substr : SplitString(interp_str, ',')) { @@ -426,11 +424,11 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf // Parsing target if (auto opt = obj.as()) { return opt.value(); - } else if (auto str = obj.as()) { + } else if (auto str = obj.try_cast()) { return Target(TargetInternal::FromString(str.value())); - } else if (const auto* ptr = obj.as()) { + } else if (const auto* ptr = obj.as()) { for (const auto& kv : *ptr) { - if (!kv.first.as()) { + if (!kv.first.as()) { TVM_FFI_THROW(TypeError) << "Target object requires key of dict to be str, but get: " << kv.first.GetTypeKey(); } @@ -440,9 +438,9 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf } TVM_FFI_THROW(TypeError) << "Expect type 'dict' or 'str' to construct Target, but get: " + obj.GetTypeKey(); - } else if (info.type_index == ArrayObj::RuntimeTypeIndex()) { + } else if (info.type_index == ffi::ArrayObj::RuntimeTypeIndex()) { // Parsing array - const auto* array = ObjTypeCheck(obj, "Array"); + const auto* array = ObjTypeCheck(obj, "Array"); std::vector result; for (const Any& e : *array) { try { @@ -453,9 +451,9 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf } } return Array(result); - } else if (info.type_index == MapObj::RuntimeTypeIndex()) { + } else if (info.type_index == ffi::MapObj::RuntimeTypeIndex()) { // Parsing map - const auto* map = ObjTypeCheck(obj, "Map"); + const auto* map = ObjTypeCheck(obj, "Map"); std::unordered_map result; for (const auto& kv : *map) { Any key, val; @@ -487,9 +485,9 @@ Any TargetInternal::ParseType(const Any& obj, const TargetKindNode::ValueTypeInf std::string TargetInternal::StringifyAtomicType(const Any& obj) { if (obj.type_index() == ffi::TypeIndex::kTVMFFIBool) { - return std::to_string(obj.as().value()); + return std::to_string(obj.cast()); } else if (obj.type_index() == ffi::TypeIndex::kTVMFFIInt) { - return std::to_string(obj.as().value()); + return std::to_string(obj.cast()); } else if (auto opt_str = obj.as()) { std::string s = opt_str.value(); auto u = Uninterpret(s); @@ -502,7 +500,7 @@ std::string TargetInternal::StringifyAtomicType(const Any& obj) { TVM_FFI_UNREACHABLE(); } -std::string TargetInternal::StringifyArray(const ArrayObj& array) { +std::string TargetInternal::StringifyArray(const ffi::ArrayObj& array) { std::vector elements; for (const Any& item : array) { @@ -531,7 +529,7 @@ Optional TargetInternal::StringifyAttrsToRaw(const Map std::string value; // skip undefined attrs if (obj == nullptr) continue; - if (const auto* array = obj.as()) { + if (const auto* array = obj.as()) { value = String(StringifyArray(*array)); } else { value = StringifyAtomicType(obj); @@ -602,7 +600,7 @@ Target::Target(Target target, Target host) { Target::Target(TargetKind kind, Optional host, String tag, Array keys, Map attrs) { - auto data = runtime::make_object(); + auto data = ffi::make_object(); data->kind = std::move(kind); data->host = std::move(host); data->tag = std::move(tag); @@ -651,7 +649,7 @@ Optional TargetNode::GetHost() const { return this->host.as(); } Target Target::WithoutHost() const { if ((*this)->GetHost()) { auto output = make_object(*get()); - output->host = NullOpt; + output->host = std::nullopt; return Target(output); } else { return *this; @@ -761,9 +759,9 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { const auto& arg = args[0]; if (auto opt_target = arg.as()) { *rv = Target(opt_target.value()); - } else if (auto opt_str = arg.as()) { + } else if (auto opt_str = arg.try_cast()) { *rv = Target(opt_str.value()); - } else if (auto opt_map = arg.as>()) { + } else if (auto opt_map = arg.try_cast>()) { *rv = Target(opt_map.value()); } else { LOG(FATAL) << "TypeError: Cannot create target with type: " << args[0].GetTypeKey(); @@ -850,7 +848,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // parse 'kind' if (config.count(kKind)) { - if (auto kind = config[kKind].as()) { + if (auto kind = config[kKind].try_cast()) { target->kind = GetTargetKind(kind.value()); ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr)) << "Cannot use both set_attrs_preprocessor and set_target_parser"; @@ -875,7 +873,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // parse "tag" if (config.count(kTag)) { - if (auto tag = config[kTag].as()) { + if (auto tag = config[kTag].try_cast()) { target->tag = tag.value(); config.erase(kTag); } else { @@ -891,9 +889,9 @@ ObjectPtr TargetInternal::FromConfig(Map config) { bool has_user_keys = config.count(kKeys); if (has_user_keys) { // user provided keys - if (const auto* cfg_keys = config[kKeys].as()) { + if (const auto* cfg_keys = config[kKeys].as()) { for (const Any& e : *cfg_keys) { - if (auto key = e.as()) { + if (auto key = e.try_cast()) { keys.push_back(key.value()); } else { TVM_FFI_THROW(TypeError) << "Expect 'keys' to be an array of strings, but it " @@ -907,7 +905,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { } // add device name if (config.count(kDeviceName)) { - if (auto device = config.at(kDeviceName).as()) { + if (auto device = config.at(kDeviceName).try_cast()) { keys.push_back(device.value()); } } @@ -926,7 +924,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { target->host = ffi::Function(ConstructorDispatcher)(config[kHost]).cast(); config.erase(kHost); } else { - target->host = NullOpt; + target->host = std::nullopt; } // parse attrs std::unordered_map attrs; @@ -945,7 +943,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // If requested, query attributes from the device. User-specified // parameters take precedence over queried parameters. if (attrs.count("from_device")) { - int device_id = attrs.at("from_device").as().value(); + int device_id = attrs.at("from_device").cast(); attrs.erase("from_device"); auto device_params = QueryDevice(device_id, target.get()); @@ -1010,16 +1008,16 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, /********** Registry **********/ -TVM_REGISTER_GLOBAL("target.Target").set_body_packed(TargetInternal::ConstructorDispatcher); -TVM_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope); -TVM_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); -TVM_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); -TVM_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); -TVM_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); -TVM_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const Target& target) { +TVM_FFI_REGISTER_GLOBAL("target.Target").set_body_packed(TargetInternal::ConstructorDispatcher); +TVM_FFI_REGISTER_GLOBAL("target.TargetEnterScope").set_body_typed(TargetInternal::EnterScope); +TVM_FFI_REGISTER_GLOBAL("target.TargetExitScope").set_body_typed(TargetInternal::ExitScope); +TVM_FFI_REGISTER_GLOBAL("target.TargetCurrent").set_body_typed(Target::Current); +TVM_FFI_REGISTER_GLOBAL("target.TargetExport").set_body_typed(TargetInternal::Export); +TVM_FFI_REGISTER_GLOBAL("target.WithHost").set_body_typed(TargetInternal::WithHost); +TVM_FFI_REGISTER_GLOBAL("target.TargetGetDeviceType").set_body_typed([](const Target& target) { return target->GetTargetDeviceType(); }); -TVM_REGISTER_GLOBAL("target.TargetGetFeature") +TVM_FFI_REGISTER_GLOBAL("target.TargetGetFeature") .set_body_typed([](const Target& target, const String& feature_key) -> Any { if (auto opt_any = target->GetFeature(feature_key)) { return opt_any.value(); diff --git a/src/target/target_info.cc b/src/target/target_info.cc index a63e45a81a4a..6e673905d3c2 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -20,8 +20,8 @@ /*! * \file target/target_info.cc */ +#include #include -#include #include namespace tvm { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 2a8dad4162cf..cdec2ede0643 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -21,9 +21,9 @@ * \file src/target/target_kind.cc * \brief Target kind registry */ +#include #include #include -#include #include #include @@ -90,7 +90,7 @@ const AttrRegistryMapContainerMap& TargetKind::GetAttrMapContainer( Optional TargetKind::Get(const String& target_kind_name) { const TargetKindRegEntry* reg = TargetKindRegistry::Global()->Get(target_kind_name); if (reg == nullptr) { - return NullOpt; + return std::nullopt; } return reg->kind_; } @@ -145,7 +145,7 @@ void CheckOrSetAttr(Map* attrs, const String& name, const Stri if (iter == attrs->end()) { attrs->Set(name, value); } else { - auto str = (*iter).second.as(); + auto str = (*iter).second.try_cast(); ICHECK(str && str.value() == value) << "ValueError: Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second; } @@ -446,7 +446,7 @@ TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break /********** Registry **********/ -TVM_REGISTER_GLOBAL("target.TargetKindGetAttr") +TVM_FFI_REGISTER_GLOBAL("target.TargetKindGetAttr") .set_body_typed([](TargetKind kind, String attr_name) -> ffi::Any { auto target_attr_map = TargetKind::GetAttrMap(attr_name); ffi::Any rv; @@ -455,10 +455,11 @@ TVM_REGISTER_GLOBAL("target.TargetKindGetAttr") } return rv; }); -TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds); -TVM_REGISTER_GLOBAL("target.ListTargetKindOptions") +TVM_FFI_REGISTER_GLOBAL("target.ListTargetKinds") + .set_body_typed(TargetKindRegEntry::ListTargetKinds); +TVM_FFI_REGISTER_GLOBAL("target.ListTargetKindOptions") .set_body_typed(TargetKindRegEntry::ListTargetKindOptions); -TVM_REGISTER_GLOBAL("target.ListTargetKindOptionsFromName") +TVM_FFI_REGISTER_GLOBAL("target.ListTargetKindOptionsFromName") .set_body_typed([](String target_kind_name) { TargetKind kind = TargetKind::Get(target_kind_name).value(); return TargetKindRegEntry::ListTargetKindOptions(kind); diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index 3842776a6fd4..a39756662621 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -191,7 +191,7 @@ VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) { virtual_device->target, virtual_device->memory_scope); } -TVM_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope") +TVM_FFI_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope") .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope); } // namespace tvm diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 765113cddc35..294b34bf5d2e 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -23,7 +23,7 @@ */ #include -#include +#include #include #include #include @@ -148,7 +148,7 @@ ComputeOp::ComputeOp(std::string name, std::string tag, Map at data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ComputeOp") +TVM_FFI_REGISTER_GLOBAL("te.ComputeOp") .set_body_typed([](std::string name, std::string tag, Optional> attrs, Array axis, Array body) { return ComputeOp(name, tag, attrs.value_or({}), axis, body); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index d087f845cc0f..1534cfc35889 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -20,8 +20,8 @@ #include "create_primfunc.h" #include +#include #include -#include #include #include #include @@ -299,7 +299,7 @@ Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, CreateFuncInfo* info) { Map annotations; auto mutate_attr = [&info](const ffi::Any& value) -> ffi::Any { - if (auto tensor_value = value.as()) { + if (auto tensor_value = value.try_cast()) { return info->tensor2buffers.at(tensor_value.value()); } else { return value; @@ -309,7 +309,7 @@ Map GenerateBlockAnnotations(const te::ComputeOp& compute_op, const String& key = pair.first; const Any& value = pair.second; // TensorIR will not allow Tensor data structure - if (value.as()) { + if (value.as()) { const auto array_value = Downcast>(value); annotations.Set(key, array_value.Map(mutate_attr)); } else { @@ -337,7 +337,7 @@ Stmt GenerateInitStmt(const Array& indices, const Array& buffe auto f_transform_and_remap = [&](const PrimExpr& e) { return Substitute(info->transformer(e), var_map); }; - Optional init = NullOpt; + Optional init = std::nullopt; Stmt body; int n_buffers = buffers.size(); Array init_stmts; @@ -521,7 +521,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in // for the leaf scope, we ensure at least one block var exists IterVar dummy(Range::FromMinExtent(0, 1), Var("vi", DataType::Int(32)), IterVarType::kDataPar); - cur_scope.AddBlockIter(NullOpt, dummy, 0); + cur_scope.AddBlockIter(std::nullopt, dummy, 0); } scopes.push_back(cur_scope); } @@ -569,7 +569,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in /*writes=*/{}, /*name_hint=*/info->FreshName(buffers[i]->name), /*body=*/body, - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/annotations))); @@ -584,7 +584,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in auto block_name = info->FreshName(compute_op->name + "_l" + std::to_string(i)); const auto& block_iters = cur.block_iters; - Optional init{NullOpt}; + Optional init{std::nullopt}; if (reduce && std::any_of(block_iters.begin(), block_iters.end(), [](const IterVar& iter) { return iter->iter_type == IterVarType::kCommReduce; })) { @@ -659,7 +659,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf /*writes=*/{}, /*name_hint=*/info->FreshName(extern_op->name), /*body=*/std::move(body), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/extern_op->attrs)); @@ -784,15 +784,16 @@ PrimFunc CreatePrimFunc(const Array& arg_list, return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); } -TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - Array arg_list = args[0].cast>(); - std::optional index_dtype_override{std::nullopt}; - // Add conversion to make std::optional compatible with FFI. - if (args[1] != nullptr) { - index_dtype_override = args[1].cast(); - } - *ret = CreatePrimFunc(arg_list, index_dtype_override); -}); +TVM_FFI_REGISTER_GLOBAL("te.CreatePrimFunc") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + Array arg_list = args[0].cast>(); + std::optional index_dtype_override{std::nullopt}; + // Add conversion to make std::optional compatible with FFI. + if (args[1] != nullptr) { + index_dtype_override = args[1].cast(); + } + *ret = CreatePrimFunc(arg_list, index_dtype_override); + }); // Relax version impl PrimFunc GenerateAndCompletePrimFunc(const Array& arg_tir_var_list, diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index dc045156d114..eb4a6183dd5c 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -20,7 +20,7 @@ #ifndef TVM_TE_OPERATION_CREATE_PRIMFUNC_H_ #define TVM_TE_OPERATION_CREATE_PRIMFUNC_H_ -#include +#include #include #include diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 2bee0555570e..9f8531998e88 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -22,7 +22,7 @@ * \file extern_op.cc */ #include -#include +#include #include #include @@ -70,7 +70,7 @@ ExternOp::ExternOp(std::string name, std::string tag, Map attr data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ExternOp") +TVM_FFI_REGISTER_GLOBAL("te.ExternOp") .set_body_typed([](std::string name, std::string tag, Optional> attrs, Array inputs, Array input_placeholders, Array output_placeholders, Stmt body) { diff --git a/src/te/operation/graph.cc b/src/te/operation/graph.cc index aee7f2afb188..e2bbced85f89 100644 --- a/src/te/operation/graph.cc +++ b/src/te/operation/graph.cc @@ -23,7 +23,7 @@ */ #include "graph.h" -#include +#include #include #include #include @@ -80,9 +80,9 @@ Array PostDFSOrder(const Array& roots, const ReadGraph& g) return post_order; } -TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); +TVM_FFI_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); -TVM_REGISTER_GLOBAL("schedule.PostDFSOrder") +TVM_FFI_REGISTER_GLOBAL("schedule.PostDFSOrder") .set_body_typed([](const Array& roots, const ReadGraph& g) { return PostDFSOrder(roots, g); }); diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 53a001a5bb4d..cce70420c0bd 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -21,8 +21,8 @@ * \brief Placeholder op. * \file placeholder_op.cc */ -#include -#include +#include +#include #include namespace tvm { @@ -61,7 +61,7 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { return PlaceholderOp(name, shape, dtype).output(0); } -TVM_REGISTER_GLOBAL("te.Placeholder") +TVM_FFI_REGISTER_GLOBAL("te.Placeholder") .set_body_typed([](Variant> shape_arg, DataType dtype, std::string name) { auto shape = [&]() -> Array { diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 5e6ad5c78f38..f4860cf71ef7 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -21,7 +21,7 @@ * \brief Scan Operator. * \file scan_op.cc */ -#include +#include #include #include @@ -97,7 +97,7 @@ ScanOp::ScanOp(std::string name, std::string tag, Optional data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.ScanOp") +TVM_FFI_REGISTER_GLOBAL("te.ScanOp") .set_body_typed([](std::string name, std::string tag, Optional> attrs, IterVar axis, Array init, Array update, Array state_placeholder, Array inputs) { diff --git a/src/te/tensor.cc b/src/te/tensor.cc index f46c095f3b08..a23f4b494ece 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -20,7 +20,7 @@ /*! * \file tensor.cc */ -#include +#include #include #include @@ -98,7 +98,7 @@ Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_in data_ = std::move(n); } -TVM_REGISTER_GLOBAL("te.Tensor") +TVM_FFI_REGISTER_GLOBAL("te.Tensor") .set_body_typed([](Array shape, DataType dtype, Operation op, int value_index) { return Tensor(shape, dtype, op, value_index); }); @@ -112,19 +112,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Other tensor ops. -TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); +TVM_FFI_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); -TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { +TVM_FFI_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { return static_cast(std::hash()(tensor)); }); -TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { +TVM_FFI_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); -TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); +TVM_FFI_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); -TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); +TVM_FFI_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); } // namespace te } // namespace tvm diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index d8fcee859f03..ce13ac56c81d 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -402,8 +402,9 @@ Array> GetBlockReadWriteRegion(const Block& block, return {reads, writes}; } -TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); -TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion") + .set_body_typed(GetBlockReadWriteRegion); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index dd1fce0fbef7..aca4c99e1197 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -62,7 +62,8 @@ class LCADetector : public StmtExprVisitor { Map> buffer_lca; for (const auto& kv : detector.buffer_lca_) { const Buffer& buffer = GetRef(kv.first); - const Optional stmt = kv.second ? GetRef>(kv.second->stmt) : NullOpt; + const Optional stmt = + kv.second ? GetRef>(kv.second->stmt) : std::nullopt; buffer_lca.Set(buffer, stmt); } return buffer_lca; @@ -345,6 +346,7 @@ Map> DetectBufferAccessLCA(const PrimFunc& func) { return LCADetector::Detect(func); } -TVM_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca").set_body_typed(DetectBufferAccessLCA); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.detect_buffer_access_lca") + .set_body_typed(DetectBufferAccessLCA); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 04d3fc729c71..de208ce9c1e0 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -22,7 +22,7 @@ * \brief Calculate allocated memory per memory scope required by PrimFuncs. */ #include -#include +#include #include #include #include @@ -96,7 +96,7 @@ tvm::Map > CalculateAllocatedBytes(const IRMod return results; } -TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes") .set_body_typed([](ObjectRef obj) -> tvm::Map > { if (auto func = obj.as()) { return CalculateAllocatedBytes(func.value()); @@ -155,7 +155,7 @@ Array GetVTCMCompactionPasses() { return pass_list; } -TVM_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]() { return GetVTCMCompactionPasses(); }); @@ -191,7 +191,7 @@ Pass VerifyVTCMLimit(Optional default_target) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 1065ad3bf1e0..a9c2b9ecc609 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -24,7 +24,7 @@ #include "control_flow_graph.h" -#include +#include #include #include #include @@ -163,7 +163,7 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { continue; } - Optional lane_var = NullOpt; + Optional lane_var = std::nullopt; IntImm num_lanes; Array indices = op->indices.Map([&](const auto& index) { @@ -522,9 +522,9 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { ICHECK_LE(to_block, out_->control_flow_.size()); auto& forward = out_->control_flow_[from_block].successors.emplace_back( - ControlFlowGraph::ControlFlowEdge{to_block, {}, NullOpt}); + ControlFlowGraph::ControlFlowEdge{to_block, {}, std::nullopt}); auto& backward = out_->control_flow_[to_block].predecessors.emplace_back( - ControlFlowGraph::ControlFlowEdge{from_block, {}, NullOpt}); + ControlFlowGraph::ControlFlowEdge{from_block, {}, std::nullopt}); return {forward, backward}; } @@ -554,7 +554,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { With analyzer_context; size_t old_num_constraints{0}; size_t new_num_constraints{0}; - Optional assume{NullOpt}; + Optional assume{std::nullopt}; // Disable default-generated copy/move assignment and constructors InternalConstraintContext(const InternalConstraintContext&) = delete; @@ -642,7 +642,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock::Make Analyzer local_analyzer; - Optional lane_var = NullOpt; + Optional lane_var = std::nullopt; IntImm num_lanes; Array index_expressions = indices.Map([&](const auto& index) { @@ -1090,7 +1090,7 @@ class BufferRegionCollector : public ExprVisitor { Analyzer local_analyzer; if (!is_zero(unknown_region)) { - new_regions.insert(new_regions.begin(), Known{unknown_region, NullOpt}); + new_regions.insert(new_regions.begin(), Known{unknown_region, std::nullopt}); } std::vector updated_regions; @@ -1329,7 +1329,7 @@ Optional> ControlFlowGraph::GetIndexVariables(const Buffer& buf) cons if (auto it = axis_var_lookup_.find(buf); it != axis_var_lookup_.end()) { return (*it).second; } else { - return NullOpt; + return std::nullopt; } } diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index 543feeecfea1..f4babffbb74c 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 86431f1ac2f2..07d6500570f8 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -21,10 +21,10 @@ * \file tir/analysis/deep_equal.cc * \brief Deep equality checking. */ +#include #include #include #include -#include #include namespace tvm { @@ -65,10 +65,10 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { auto* prhs = rhs.as(); return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } - return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt); + return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, std::nullopt); } -TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { return ExprDeepEqual()(lhs, rhs); }); diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index c4851e255f0e..df41e7da1807 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -173,6 +173,13 @@ class FlopEstimator : private ExprFunctor, return cond; } + TResult VisitStmt_(const WhileNode* op) override { + // TODO(jikechao): Improve while loop FLOP estimation with loop bound analysis + TResult result = VisitExpr(op->condition); + result += VisitStmt(op->body); + return result; + } + TResult VisitStmt_(const LetStmtNode* let) override { TResult value = VisitExpr(let->value); value += VisitStmt(let->body); @@ -193,6 +200,7 @@ class FlopEstimator : private ExprFunctor, TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const DeclBufferNode* op) override { return VisitStmt(op->body); } + TResult VisitStmt_(const EvaluateNode* op) override { return TResult(); } TResult VisitStmt_(const SeqStmtNode* seq) override { TResult result; @@ -238,17 +246,18 @@ double EstimateTIRFlops(const IRModule& mod) { return PostprocessResults(result) + cached_result; } -TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double { - if (auto mod = obj.as()) { - return EstimateTIRFlops(mod.value()); - } else if (auto stmt = obj.as()) { - return EstimateTIRFlops(stmt.value()); - } else { - LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " - << obj->GetTypeKey(); - throw; - } -}); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops") + .set_body_typed([](ObjectRef obj) -> double { + if (auto mod = obj.as()) { + return EstimateTIRFlops(mod.value()); + } else if (auto stmt = obj.as()) { + return EstimateTIRFlops(stmt.value()); + } else { + LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " + << obj->GetTypeKey(); + throw; + } + }); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/identify_memcpy.cc b/src/tir/analysis/identify_memcpy.cc index e36bb3a4f379..dcffe1c1d6b8 100644 --- a/src/tir/analysis/identify_memcpy.cc +++ b/src/tir/analysis/identify_memcpy.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include #include @@ -282,7 +282,7 @@ std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* an } // Expose the IdentifyMemCpy functionality to Python API for purpose of unit testing. -TVM_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { +TVM_FFI_REGISTER_GLOBAL("tir.analysis._identify_memcpy").set_body_typed([](const Stmt& stmt) { Array output; struct Visitor : arith::IRVisitorWithAnalyzer { diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc index ee893987c91e..4af823604971 100644 --- a/src/tir/analysis/is_pure_function.cc +++ b/src/tir/analysis/is_pure_function.cc @@ -91,7 +91,7 @@ bool IsPureFunction(const PrimFunc& func, bool assert_on_error) { return PurityChecker::Check(func, assert_on_error); } -TVM_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc index dbe114df4973..898a92adc7db 100644 --- a/src/tir/analysis/oob_checker.cc +++ b/src/tir/analysis/oob_checker.cc @@ -123,7 +123,7 @@ transform::Pass OOBChecker() { return transform::CreatePrimFuncPass(pass_func, 0, "tir.analysis.OOBChecker", {}); } -TVM_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index 527c7d04c6f3..b5a23e35d276 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -139,12 +139,12 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) { return nullptr; } -TVM_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) { +TVM_FFI_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) { auto ret = FindAnchorBlock(mod); if (ret) { return Optional(GetRef(ret)); } - return Optional(NullOpt); + return Optional(std::nullopt); }); } // namespace tir diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 654d3332c755..0d75cebac798 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -199,7 +199,7 @@ Array UndefinedVars(const PrimExpr& expr, const Array& args) { return m.undefined_; } -TVM_REGISTER_GLOBAL("tir.analysis.UndefinedVars") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.UndefinedVars") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_stmt = args[0].as()) { *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index d109736863c7..ef46a41687ad 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -24,7 +24,7 @@ * in a block exceeds the limit */ -#include +#include #include #include #include @@ -321,7 +321,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { return errs.size() == 0; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); namespace transform { @@ -346,7 +346,7 @@ Pass VerifyGPUCode(Map constraints) { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index f8681189a1e6..bc567879c22b 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -21,8 +21,8 @@ * \file verify_memory.cc * \brief Pass to check if memory accesses are legal. */ +#include #include -#include #include #include #include @@ -186,7 +186,7 @@ std::vector VerifyMemory_(const PrimFunc& func) { bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); namespace transform { @@ -211,7 +211,7 @@ Pass VerifyMemory() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); } // namespace transform } // namespace tir diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index f238ffd763b1..33abb39c367f 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -23,7 +23,7 @@ * SSA requires each varaible to be only defined once. * \file verify_ssa.cc */ -#include +#include #include #include #include @@ -139,7 +139,7 @@ bool VerifySSA(const PrimFunc& func) { return visitor.is_ssa_; } -TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); +TVM_FFI_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); namespace transform { @@ -155,7 +155,7 @@ Pass VerifySSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); } // namespace transform diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index cfdc2f35515a..a0c5f4829bf8 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -22,7 +22,7 @@ * \brief Check if schedulable tir is well-formed. */ -#include +#include #include #include @@ -368,7 +368,7 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { return true; } -TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed") +TVM_FFI_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed") .set_body_typed([](const ObjectRef& obj, bool assert_mode) { if (auto opt = obj.as()) { return VerifyWellFormed(opt.value(), assert_mode); diff --git a/src/tir/ir/block_dependence_info.cc b/src/tir/ir/block_dependence_info.cc index 13f38c86bf2a..dc1e3c48c924 100644 --- a/src/tir/ir/block_dependence_info.cc +++ b/src/tir/ir/block_dependence_info.cc @@ -85,14 +85,14 @@ BlockDependenceInfo::BlockDependenceInfo(IRModule mod) { } TVM_REGISTER_NODE_TYPE(BlockDependenceInfoNode); -TVM_REGISTER_GLOBAL("tir.BlockDependenceInfo") +TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfo") .set_body_typed([](IRModule mod) -> BlockDependenceInfo { return BlockDependenceInfo(mod); }); -TVM_REGISTER_GLOBAL("tir.BlockDependenceInfoGetBlockScope") +TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfoGetBlockScope") .set_body_method(&BlockDependenceInfoNode::GetBlockScope); -TVM_REGISTER_GLOBAL("tir.BlockDependenceInfoGetSRef") +TVM_FFI_REGISTER_GLOBAL("tir.BlockDependenceInfoGetSRef") .set_body_typed([](BlockDependenceInfo self, Stmt stmt) -> Optional { auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(NullOpt); + return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); }); } // namespace tir diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc index 5320c5d68a37..381fae73a475 100644 --- a/src/tir/ir/block_scope.cc +++ b/src/tir/ir/block_scope.cc @@ -190,18 +190,21 @@ TVM_REGISTER_NODE_TYPE(StmtSRefNode); TVM_REGISTER_NODE_TYPE(DependencyNode); TVM_REGISTER_NODE_TYPE(BlockScopeNode); -TVM_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) -> Optional { +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) -> Optional { return GetRef>(sref->stmt); }); -TVM_REGISTER_GLOBAL("tir.StmtSRefParent").set_body_typed([](StmtSRef sref) -> Optional { - return GetRef>(sref->parent); -}); -TVM_REGISTER_GLOBAL("tir.StmtSRefRootMark") // +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefParent") + .set_body_typed([](StmtSRef sref) -> Optional { + return GetRef>(sref->parent); + }); +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefRootMark") // .set_body_typed(StmtSRef::RootMark); -TVM_REGISTER_GLOBAL("tir.StmtSRefInlineMark") // +TVM_FFI_REGISTER_GLOBAL("tir.StmtSRefInlineMark") // .set_body_typed(StmtSRef::InlineMark); -TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc").set_body_method(&BlockScopeNode::GetDepsBySrc); -TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst").set_body_method(&BlockScopeNode::GetDepsByDst); +TVM_FFI_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc") + .set_body_method(&BlockScopeNode::GetDepsBySrc); +TVM_FFI_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst") + .set_body_method(&BlockScopeNode::GetDepsByDst); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 3b94c2ae757a..bce9c2c4e1a8 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -21,8 +21,8 @@ * \file buffer.cc */ #include +#include #include -#include #include #include #include @@ -640,7 +640,7 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std TVM_REGISTER_NODE_TYPE(BufferNode); -TVM_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { +TVM_FFI_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ICHECK_EQ(args.size(), 11); auto buffer_type = args[8].cast(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; @@ -658,17 +658,18 @@ TVM_REGISTER_GLOBAL("tir.Buffer").set_body_packed([](ffi::PackedArgs args, ffi:: axis_separators, span); }); -TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); +TVM_FFI_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); -TVM_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer").set_body_method(&Buffer::GetFlattenedBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.BufferGetFlattenedBuffer") + .set_body_method(&Buffer::GetFlattenedBuffer); -TVM_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf); +TVM_FFI_REGISTER_GLOBAL("tir.BufferOffsetOf").set_body_method(&Buffer::OffsetOf); -TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); +TVM_FFI_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); -TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); +TVM_FFI_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); -TVM_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); +TVM_FFI_REGISTER_GLOBAL("tir.BufferStorageScope").set_body_method(&Buffer::scope); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 119322455b1c..96f87344cbea 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -22,7 +22,7 @@ * \brief Data Layout expression. */ #include -#include +#include #include #include @@ -427,43 +427,45 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) { +TVM_FFI_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) { return Layout(name, dtype); }); -TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { - return layout.IndexOf(LayoutAxis::Get(axis)); -}); +TVM_FFI_REGISTER_GLOBAL("tir.LayoutIndexOf") + .set_body_typed([](Layout layout, std::string axis) -> int { + return layout.IndexOf(LayoutAxis::Get(axis)); + }); -TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") +TVM_FFI_REGISTER_GLOBAL("tir.LayoutFactorOf") .set_body_typed([](Layout layout, std::string axis) -> int { return layout.FactorOf(LayoutAxis::Get(axis)); }); -TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { +TVM_FFI_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { return layout.ndim(); }); -TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string { - const LayoutAxis& axis = layout[idx]; - return axis.name(); -}); +TVM_FFI_REGISTER_GLOBAL("tir.LayoutGetItem") + .set_body_typed([](Layout layout, int idx) -> std::string { + const LayoutAxis& axis = layout[idx]; + return axis.name(); + }); -TVM_REGISTER_GLOBAL("tir.BijectiveLayout") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayout") .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { return BijectiveLayout(src_layout, dst_layout); }); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") .set_body_method(&BijectiveLayout::ForwardIndex); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") .set_body_method(&BijectiveLayout::BackwardIndex); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") .set_body_method(&BijectiveLayout::ForwardShape); -TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") +TVM_FFI_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") .set_body_method(&BijectiveLayout::BackwardShape); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 8a496d06bed8..304fe0bf820f 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -340,7 +340,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) { [this](const BufferRegion& buffer_region) { return this->VisitBufferRegion(buffer_region); }); Array new_iter_vars = op->iter_vars.Map([this](const IterVar& iter_var) { return this->VisitIterVar(iter_var); }); - Optional new_init = NullOpt; + Optional new_init = std::nullopt; if (op->init.defined()) { new_init = this->VisitStmt(op->init.value()); } @@ -381,7 +381,7 @@ Map IndexDataTypeRewriter::VisitBlockAnnotations( if (Buffer new_buffer = GetRemappedBuffer(buffer); !new_buffer.same_as(buffer)) { return new_buffer; } - } else if (obj->IsInstance()) { + } else if (obj->IsInstance()) { return Downcast>(obj).Map(f_mutate_obj); } return obj; @@ -521,7 +521,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) { Stmt then_case = VisitStmt(op->then_case); Optional else_case = - op->else_case.defined() ? Optional{VisitStmt(op->else_case.value())} : NullOpt; + op->else_case.defined() ? Optional{VisitStmt(op->else_case.value())} : std::nullopt; if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) || !else_case.same_as(op->else_case)) { IfThenElse new_stmt = GetRef(op); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index e89162239a63..0ac59b160200 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -20,7 +20,7 @@ /*! * \file expr.cc */ -#include +#include #include #include #include @@ -43,7 +43,7 @@ namespace tir { * `expr.dtype` field), this function allows the FFI conversions to be * explicitly invoked. */ -TVM_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { +TVM_FFI_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { return expr; }); @@ -127,7 +127,8 @@ Var Var::copy_with_dtype(DataType dtype) const { return Var(new_ptr); } -TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, ffi::AnyView type, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, ffi::AnyView type, + Span span) { if (type.as()) { return Var(name_hint, type.cast(), span); } else { @@ -156,7 +157,7 @@ SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { return SizeVar(s, t, span); }); @@ -182,7 +183,7 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.IterVar") +TVM_FFI_REGISTER_GLOBAL("tir.IterVar") .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag, Span span) { return IterVar(dom, var, static_cast(iter_type), thread_tag, span); }); @@ -198,7 +199,7 @@ StringImm::StringImm(String value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) { return StringImm(value, span); }); @@ -216,7 +217,7 @@ Cast::Cast(DataType t, PrimExpr value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value, Span span) { return Cast(dtype, value, span); }); @@ -225,7 +226,7 @@ TVM_REGISTER_NODE_TYPE(CastNode); // Add TVM_DEFINE_BINOP_CONSTRUCTOR(Add); -TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Add(a, b, span); }); @@ -234,7 +235,7 @@ TVM_REGISTER_NODE_TYPE(AddNode); // Sub TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); -TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Sub(a, b, span); }); @@ -243,7 +244,7 @@ TVM_REGISTER_NODE_TYPE(SubNode); // Mul TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); -TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Mul(a, b, span); }); @@ -252,7 +253,7 @@ TVM_REGISTER_NODE_TYPE(MulNode); // Div TVM_DEFINE_BINOP_CONSTRUCTOR(Div); -TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Div(a, b, span); }); @@ -261,7 +262,7 @@ TVM_REGISTER_NODE_TYPE(DivNode); // Mod TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); -TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Mod(a, b, span); }); @@ -270,7 +271,7 @@ TVM_REGISTER_NODE_TYPE(ModNode); // FloorDiv TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); -TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return FloorDiv(a, b, span); }); @@ -279,7 +280,7 @@ TVM_REGISTER_NODE_TYPE(FloorDivNode); // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); -TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return FloorMod(a, b, span); }); @@ -288,7 +289,7 @@ TVM_REGISTER_NODE_TYPE(FloorModNode); // Min TVM_DEFINE_BINOP_CONSTRUCTOR(Min); -TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Min(a, b, span); }); @@ -297,7 +298,7 @@ TVM_REGISTER_NODE_TYPE(MinNode); // Max TVM_DEFINE_BINOP_CONSTRUCTOR(Max); -TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Max(a, b, span); }); @@ -306,7 +307,7 @@ TVM_REGISTER_NODE_TYPE(MaxNode); // EQ TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); -TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return EQ(a, b, span); }); @@ -315,7 +316,7 @@ TVM_REGISTER_NODE_TYPE(EQNode); // NE TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); -TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return NE(a, b, span); }); @@ -324,7 +325,7 @@ TVM_REGISTER_NODE_TYPE(NENode); // LT TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); -TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return LT(a, b, span); }); @@ -333,7 +334,7 @@ TVM_REGISTER_NODE_TYPE(LTNode); // LE TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); -TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return LE(a, b, span); }); @@ -342,7 +343,7 @@ TVM_REGISTER_NODE_TYPE(LENode); // GT TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); -TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return GT(a, b, span); }); @@ -351,7 +352,7 @@ TVM_REGISTER_NODE_TYPE(GTNode); // GE TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); -TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return GE(a, b, span); }); @@ -374,7 +375,7 @@ And::And(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return And(a, b, span); }); @@ -397,7 +398,7 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { return Or(a, b, span); }); @@ -416,7 +417,9 @@ Not::Not(PrimExpr a, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { return Not(a, span); }); +TVM_FFI_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { + return Not(a, span); +}); TVM_REGISTER_NODE_TYPE(NotNode); @@ -442,7 +445,7 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Select") +TVM_FFI_REGISTER_GLOBAL("tir.Select") .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { return Select(condition, true_value, false_value, span); }); @@ -481,7 +484,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Ramp") +TVM_FFI_REGISTER_GLOBAL("tir.Ramp") .set_body_typed([](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { return Ramp(base, stride, lanes, span); }); @@ -514,9 +517,10 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { data_ = node; } -TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, PrimExpr lanes, Span span) { - return Broadcast(value, lanes, span); -}); +TVM_FFI_REGISTER_GLOBAL("tir.Broadcast") + .set_body_typed([](PrimExpr value, PrimExpr lanes, Span span) { + return Broadcast(value, lanes, span); + }); TVM_REGISTER_NODE_TYPE(BroadcastNode); @@ -535,8 +539,8 @@ Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body, - Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body, + Span span) { return Let(var, value, body, span); }); @@ -556,38 +560,37 @@ Call::Call(DataType dtype, RelaxExpr op, Array args, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed( - [](Optional dtype, RelaxExpr op, - Array> args, - Span span) { - Array prim_expr_args; - for (const auto& it : args) { - if (auto opt_str = it.as()) { - prim_expr_args.push_back(StringImm(opt_str.value())); - } else if (auto opt_dtype = it.as()) { - prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); - } else if (const auto* iter_var = it.as()) { - prim_expr_args.push_back(iter_var->var); - } else if (const auto* br = it.as()) { - Array indices; - for (Range r : br->region) { - if (is_one(r->extent)) { - indices.push_back(r->min); - } else if (r->extent.as()) { - indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); - } else { - LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " - << GetRef(br); - } - } - prim_expr_args.push_back(BufferLoad(br->buffer, indices)); +TVM_FFI_REGISTER_GLOBAL("tir.Call") + .set_body_typed([](Optional dtype, RelaxExpr op, + Array> args, + Span span) { + Array prim_expr_args; + for (const auto& it : args) { + if (auto opt_str = it.as()) { + prim_expr_args.push_back(StringImm(opt_str.value())); + } else if (auto opt_dtype = it.as()) { + prim_expr_args.push_back(StringImm(ffi::DLDataTypeToString(opt_dtype.value()))); + } else if (const auto* iter_var = it.as()) { + prim_expr_args.push_back(iter_var->var); + } else if (const auto* br = it.as()) { + Array indices; + for (Range r : br->region) { + if (is_one(r->extent)) { + indices.push_back(r->min); + } else if (r->extent.as()) { + indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); } else { - prim_expr_args.push_back(Downcast(it)); + LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " + << GetRef(br); } } - return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); - }); + prim_expr_args.push_back(BufferLoad(br->buffer, indices)); + } else { + prim_expr_args.push_back(Downcast(it)); + } + } + return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); + }); TVM_REGISTER_NODE_TYPE(CallNode); @@ -632,7 +635,7 @@ PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { return Shuffle({vector}, {Integer(index)}, span); } -TVM_REGISTER_GLOBAL("tir.Shuffle") +TVM_FFI_REGISTER_GLOBAL("tir.Shuffle") .set_body_typed([](Array vectors, Array indices, Span span) { return Shuffle(vectors, indices, span); }); @@ -651,8 +654,8 @@ CommReducer::CommReducer(Array lhs, Array rhs, Array result, << "ValueError: The number of identities must equal to the number of elements in `results`"; // Change the dtype of input vars to adapt to the dtype of identities - ArrayObj* p_lhs = lhs.CopyOnWrite(); - ArrayObj* p_rhs = rhs.CopyOnWrite(); + ffi::ArrayObj* p_lhs = lhs.CopyOnWrite(); + ffi::ArrayObj* p_rhs = rhs.CopyOnWrite(); std::unordered_map var_map; var_map.reserve(n_group * 2); for (int i = 0; i < static_cast(n_group); ++i) { @@ -666,7 +669,7 @@ CommReducer::CommReducer(Array lhs, Array rhs, Array result, p_rhs->SetItem(i, r); } - ArrayObj* p_result = result.CopyOnWrite(); + ffi::ArrayObj* p_result = result.CopyOnWrite(); for (int i = 0; i < static_cast(n_group); ++i) { p_result->SetItem(i, Substitute(result[i], var_map)); } @@ -692,13 +695,14 @@ Array CommReducerNode::operator()(Array a, Array b return Substitute(this->result, value_map); } -TVM_REGISTER_GLOBAL("tir.CommReducer") +TVM_FFI_REGISTER_GLOBAL("tir.CommReducer") .set_body_typed([](Array lhs, Array rhs, Array result, Array identity_element, Span span) { return CommReducer(lhs, rhs, result, identity_element, span); }); -TVM_REGISTER_GLOBAL("tir.CommReducerCombine").set_body_method(&tir::CommReducerNode::operator()); +TVM_FFI_REGISTER_GLOBAL("tir.CommReducerCombine") + .set_body_method(&tir::CommReducerNode::operator()); TVM_REGISTER_NODE_TYPE(CommReducerNode); @@ -737,7 +741,7 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.Reduce") +TVM_FFI_REGISTER_GLOBAL("tir.Reduce") .set_body_typed([](CommReducer combiner, Array source, Array axis, PrimExpr condition, int value_index, Array init, Span span) { return Reduce(combiner, source, axis, condition, value_index, init, span); @@ -812,7 +816,7 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices, Optional indices, Optional predicate, Span span) { return BufferLoad(buffer, indices, predicate, span); }); @@ -828,7 +832,7 @@ ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.ProducerLoad") +TVM_FFI_REGISTER_GLOBAL("tir.ProducerLoad") .set_body_typed([](DataProducer producer, Array indices, Span span) { return ProducerLoad(producer, indices, span); }); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 126b97dcfda3..2312d31fd276 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -21,14 +21,12 @@ * \file src/tir/ir/function.cc * \brief The function data structure. */ +#include #include -#include #include #include #include -#include "utils.h" - namespace tvm { namespace tir { namespace { @@ -147,7 +145,7 @@ Optional TensorIntrin::Get(String name, bool allow_missing) { auto it = manager->reg.find(name); if (it == manager->reg.end()) { if (allow_missing) { - return NullOpt; + return std::nullopt; } else { LOG(FATAL) << "ValueError: TensorIntrin '" << name << "' is not registered"; } @@ -157,19 +155,19 @@ Optional TensorIntrin::Get(String name, bool allow_missing) { TVM_REGISTER_NODE_TYPE(TensorIntrinNode); -TVM_REGISTER_GLOBAL("tir.PrimFunc") +TVM_FFI_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, Map buffer_map, DictAttrs attrs, Span span) { return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); -TVM_REGISTER_GLOBAL("tir.TensorIntrin") +TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrin") .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { return TensorIntrin(desc_func, intrin_func); }); -TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); -TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); +TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); +TVM_FFI_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h index b9bb43ca6ba6..901a5d5234ca 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tir/ir/functor_common.h @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include /*! * \file tir/ir/functor_common.h diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index ff948da01289..7297b62bf36d 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -186,7 +186,7 @@ Array IndexMapNode::MapRanges(const Array& ranges, arith::Analyzer // affine sum. Since the terms are orthogonal, the extent of the // sum is the extent of the largest term. for (const auto& index : iter_map->indices) { - Optional extent = NullOpt; + Optional extent = std::nullopt; for (const auto& term : index->args) { PrimExpr term_extent = term->extent * term->scale; if (extent.defined()) { @@ -351,7 +351,7 @@ IndexMap IndexMap::RenameVariables( [&](const Var& var) { return Downcast(Substitute(var, var_remap)); }); auto new_final_indices = n->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, var_remap); }); - Optional new_inverse_index_map = NullOpt; + Optional new_inverse_index_map = std::nullopt; if (n->inverse_index_map.defined()) { new_inverse_index_map = Downcast(n->inverse_index_map).RenameVariables(f_name_map); } @@ -410,7 +410,7 @@ IndexMap Substitute(const IndexMap& index_map, std::function(const Var& var)> f_subst) { Array new_output = index_map->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, f_subst); }); - Optional new_inverse_map = NullOpt; + Optional new_inverse_map = std::nullopt; if (index_map->inverse_index_map.defined()) { new_inverse_map = Substitute(Downcast(index_map->inverse_index_map.value()), f_subst); } @@ -419,33 +419,34 @@ IndexMap Substitute(const IndexMap& index_map, TVM_REGISTER_NODE_TYPE(IndexMapNode); -TVM_REGISTER_GLOBAL("tir.IndexMap") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMap") .set_body_typed([](Array initial_indices, Array final_indices, Optional inverse_index_map) { return IndexMap(initial_indices, final_indices, inverse_index_map); }); -TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapIndices") .set_body_typed([](IndexMap map, Array indices) { arith::Analyzer analyzer; return map->MapIndices(indices, &analyzer); }); -TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, Array shape) { - arith::Analyzer analyzer; - return map->MapShape(shape, &analyzer); -}); +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapShape") + .set_body_typed([](IndexMap map, Array shape) { + arith::Analyzer analyzer; + return map->MapShape(shape, &analyzer); + }); -TVM_REGISTER_GLOBAL("tir.IndexMapInverse") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapInverse") .set_body_typed([](IndexMap map, Array initial_ranges) { arith::Analyzer analyzer; return map.Inverse(initial_ranges, &analyzer); }); -TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapMapNDArray") .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }); -TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") +TVM_FFI_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse") .set_body_typed([](IndexMap forward, Array initial_ranges) { arith::Analyzer analyzer; auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index e6e942a87ba6..7e8c2913e55f 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -143,7 +143,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { }(); if (should_insert_root) { - Block root_block({}, {}, {}, "root", std::move(res), NullOpt, root_allocates); + Block root_block({}, {}, {}, "root", std::move(res), std::nullopt, root_allocates); res = BlockRealize({}, Bool(true), std::move(root_block)); } @@ -160,7 +160,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { } } -TVM_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete); +TVM_FFI_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/script/script_complete.h b/src/tir/ir/script/script_complete.h index 8df04566460a..273ca946a7ff 100644 --- a/src/tir/ir/script/script_complete.h +++ b/src/tir/ir/script/script_complete.h @@ -23,7 +23,7 @@ */ #ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_ #define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_ -#include +#include #include #include diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index c7e254faacfe..86ed65c4905d 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -21,7 +21,7 @@ * \file src/tir/ir/specialize.cc * \brief Specialize parameters of PrimFunc. */ -#include +#include #include #include #include @@ -432,7 +432,7 @@ PrimFunc Specialize(PrimFunc func, const Map>& pa /**************** FFI ****************/ -TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); +TVM_FFI_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 0a7817c02a3e..62baf45bc78e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -21,13 +21,12 @@ * \file tvm/tir/stmt.cc */ #include -#include +#include #include #include #include #include "buffer_common.h" -#include "utils.h" namespace tvm { namespace tir { @@ -53,7 +52,7 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.LetStmt") +TVM_FFI_REGISTER_GLOBAL("tir.LetStmt") .set_body_typed([](Var var, PrimExpr value, Stmt body, Span span) { return LetStmt(var, value, body, span); }); @@ -71,13 +70,13 @@ AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, S data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.AttrStmt") +TVM_FFI_REGISTER_GLOBAL("tir.AttrStmt") .set_body_typed([](Any node, String attr_key, PrimExpr value, Stmt body, Span span) { // when node is a POD data type like int or bool, first convert to primexpr. if (node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - return AttrStmt(node.as().value(), attr_key, value, body, span); + return AttrStmt(node.cast(), attr_key, value, body, span); } - return AttrStmt(node.as().value(), attr_key, value, body, span); + return AttrStmt(node.cast(), attr_key, value, body, span); }); TVM_REGISTER_NODE_TYPE(AttrStmtNode); @@ -101,7 +100,7 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span spa TVM_REGISTER_NODE_TYPE(AssertStmtNode); -TVM_REGISTER_GLOBAL("tir.AssertStmt") +TVM_FFI_REGISTER_GLOBAL("tir.AssertStmt") .set_body_typed([](PrimExpr condition, StringImm message, Stmt body, Span span) { return AssertStmt(condition, message, body, span); }); @@ -156,7 +155,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.For").set_body_typed( +TVM_FFI_REGISTER_GLOBAL("tir.For").set_body_typed( [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, Optional thread_binding, Optional> annotations, Span span) { return For(loop_var, min, extent, static_cast(kind), body, thread_binding, @@ -200,7 +199,7 @@ While::While(PrimExpr condition, Stmt body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body, Span span) { return While(condition, body, span); }); @@ -217,7 +216,7 @@ ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices, Span span) { return ProducerStore(producer, value, indices, span); }); @@ -268,7 +267,7 @@ int64_t AllocateNode::ConstantAllocationSize(const Array& extents) { return static_cast(result); } -TVM_REGISTER_GLOBAL("tir.Allocate") +TVM_FFI_REGISTER_GLOBAL("tir.Allocate") .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { return Allocate(buffer_var, type, extents, condition, body, annotations, span); @@ -329,7 +328,7 @@ int64_t AllocateConstNode::ConstantAllocationSize(const Array& extents } return static_cast(result); } -TVM_REGISTER_GLOBAL("tir.AllocateConst") +TVM_FFI_REGISTER_GLOBAL("tir.AllocateConst") .set_body_typed([](Var buffer_var, DataType dtype, Array extents, ObjectRef data_or_idx, Stmt body, Optional> annotations, Span span) { @@ -348,7 +347,7 @@ DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body, Span span) { return DeclBuffer(buffer, body, span); }); @@ -377,7 +376,7 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.ProducerRealize") +TVM_FFI_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { return ProducerRealize(producer, bounds, condition, body, storage_scope, span); @@ -390,7 +389,7 @@ Prefetch::Prefetch(Buffer buffer, Array bounds, Span span) { data_ = make_object(buffer, bounds, span); } -TVM_REGISTER_GLOBAL("tir.Prefetch") +TVM_FFI_REGISTER_GLOBAL("tir.Prefetch") .set_body_typed([](Buffer buffer, Array bounds, Span span) { return Prefetch(buffer, bounds, span); }); @@ -424,7 +423,7 @@ SeqStmt::SeqStmt(Array seq, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq, Span span) { return SeqStmt(std::move(seq), span); }); @@ -445,7 +444,7 @@ IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional else_c TVM_REGISTER_NODE_TYPE(IfThenElseNode); -TVM_REGISTER_GLOBAL("tir.IfThenElse") +TVM_FFI_REGISTER_GLOBAL("tir.IfThenElse") .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { return IfThenElse(condition, then_case, else_case, span); }); @@ -460,7 +459,7 @@ Evaluate::Evaluate(PrimExpr value, Span span) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) { return Evaluate(value, span); }); @@ -542,7 +541,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BufferStore") +TVM_FFI_REGISTER_GLOBAL("tir.BufferStore") .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Optional predicate, Span span) { return BufferStore(buffer, value, indices, predicate, span); }); @@ -555,7 +554,7 @@ BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condit data_ = make_object(buffer, bounds, condition, body, span); } -TVM_REGISTER_GLOBAL("tir.BufferRealize") +TVM_FFI_REGISTER_GLOBAL("tir.BufferRealize") .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body, Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); @@ -609,7 +608,7 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { return BufferRegion(buffer, region); } -TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { +TVM_FFI_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { return BufferRegion(buffer, region); }); @@ -666,9 +665,10 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) { - return MatchBufferRegion(buffer, source); -}); +TVM_FFI_REGISTER_GLOBAL("tir.MatchBufferRegion") + .set_body_typed([](Buffer buffer, BufferRegion source) { + return MatchBufferRegion(buffer, source); + }); TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode); @@ -690,7 +690,7 @@ Block::Block(Array iter_vars, Array reads, Array iter_vars, Array reads, Array writes, String name_hint, Stmt body, Optional init, Array alloc_buffers, Array match_buffers, @@ -714,7 +714,7 @@ BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block blo data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BlockRealize") +TVM_FFI_REGISTER_GLOBAL("tir.BlockRealize") .set_body_typed([](Array iter_values, PrimExpr predicate, Block block, Span span) { return BlockRealize(iter_values, predicate, block, span); }); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 23dba3ef7233..85d347172702 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -19,8 +19,8 @@ /*! * \file stmt_functor.cc */ +#include #include -#include #include #include #include @@ -349,7 +349,7 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); - Optional else_case = NullOpt; + Optional else_case = std::nullopt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } @@ -520,7 +520,7 @@ Stmt StmtMutator::VisitStmt_(const BlockNode* op) { Array reads = Internal::Mutate(this, op->reads); Array writes = Internal::Mutate(this, op->writes); Array match_buffers = Internal::Mutate(this, op->match_buffers); - Optional init = NullOpt; + Optional init = std::nullopt; if (op->init.defined()) { init = VisitStmt(op->init.value()); } @@ -892,17 +892,17 @@ PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); } -TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); +TVM_FFI_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); -TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { +TVM_FFI_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); -TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { +TVM_FFI_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, ffi::Function f) { tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); }); -TVM_REGISTER_GLOBAL("tir.Substitute") +TVM_FFI_REGISTER_GLOBAL("tir.Substitute") .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { if (node->IsInstance()) { return Substitute(Downcast(node), vmap); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index f724b6a74598..6a5e1191d219 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -21,9 +21,9 @@ * \file tir/ir/transform.cc * \brief TIR specific transformation passes. */ +#include #include #include -#include #include namespace tvm { @@ -144,7 +144,7 @@ Pass CreatePrimFuncPass(std::function TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); -TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") +TVM_FFI_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") .set_body_typed( [](ffi::TypedFunction, IRModule, PassContext)> pass_func, PassInfo pass_info) { diff --git a/src/tir/ir/utils.cc b/src/tir/ir/utils.cc deleted file mode 100644 index 65495bd8dd98..000000000000 --- a/src/tir/ir/utils.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/tir/ir/utils.cc - * \brief Utilities for manipulating TIR - */ -#include "utils.h" - -#include - -namespace tvm { -namespace tir { - -ffi::Any NormalizeAttributeObject(ffi::Any obj) { - if (obj.type_index() == ffi::TypeIndex::kTVMFFIBool) { - return Bool(obj.cast()); - } else if (auto opt_int = obj.as()) { - return Integer(opt_int.value()); - } else if (auto opt_float = obj.as()) { - return FloatImm(DataType::Float(32), opt_float.value()); - } else if (auto opt_array = obj.as>()) { - return opt_array.value().Map(NormalizeAttributeObject); - } else if (auto opt_map = obj.as>()) { - Map new_map; - bool is_same = true; - - for (const auto& [key, obj] : opt_map.value()) { - ObjectRef new_obj = NormalizeAttributeObject(obj.cast()).cast(); - is_same = is_same && obj.same_as(new_obj); - new_map.Set(key, new_obj); - } - - if (is_same) { - return obj; - } else { - return new_map; - } - } else if (auto dict_attrs = obj.as()) { - auto new_attrs = Downcast>(NormalizeAttributeObject(dict_attrs->dict)); - if (new_attrs.same_as(dict_attrs->dict)) { - return GetRef(dict_attrs); - } else { - return DictAttrs(new_attrs); - } - } else { - return obj; - } -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/ir/utils.h b/src/tir/ir/utils.h deleted file mode 100644 index c19a850a702c..000000000000 --- a/src/tir/ir/utils.h +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tir/ir/utils.h - * \brief Utilities for manipulating TIR - */ -#ifndef TVM_TIR_IR_UTILS_H_ -#define TVM_TIR_IR_UTILS_H_ - -#include - -namespace tvm { -namespace tir { - -/* \brief Normalize an ObjectRef held - * - * Where possible, the IR should be normalized contain IR types. For - * example, holding a `tir::IntImm` instead of a `runtime::Int`. In - * attributes, this is not always possible, as attributes may refer to - * non-IR objects. - * - * \param obj The attribute object to be normalized - * - * \returns The normalized attribute - */ -ffi::Any NormalizeAttributeObject(ffi::Any obj); - -} // namespace tir -} // namespace tvm -#endif // TVM_TIR_IR_UTILS_H_ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index a336688622a4..70614dfeebd7 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -22,7 +22,7 @@ * * builtin intrinsic operators. */ -#include +#include #include #include #include diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 838af436a6cd..341a96cae697 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -23,7 +23,7 @@ * Common operator definitions for ops in tir/op.h */ -#include +#include #include #include #include @@ -239,7 +239,7 @@ PrimExpr ret(PrimExpr value, Span span) { return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } -TVM_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); +TVM_FFI_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { @@ -761,7 +761,7 @@ PrimExpr bitwise_neg(PrimExpr a, Span span) { return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); } -TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) { return bitwise_neg(a, span); }); @@ -1071,10 +1071,10 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); // expose basic functions to node namespace -TVM_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - if (auto opt = args[0].as()) { +TVM_FFI_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { + if (auto opt = args[0].try_cast()) { *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); - } else if (auto opt = args[0].as()) { + } else if (auto opt = args[0].try_cast()) { *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); } else { LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " @@ -1082,55 +1082,55 @@ TVM_REGISTER_GLOBAL("node._const").set_body_packed([](ffi::PackedArgs args, ffi: } }); -TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); +TVM_FFI_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); -TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); +TVM_FFI_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); -TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); +TVM_FFI_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); -TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity); +TVM_FFI_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity); -TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs); +TVM_FFI_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs); -TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely); +TVM_FFI_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely); -TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan); +TVM_FFI_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan); -TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite); +TVM_FFI_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite); -TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf); +TVM_FFI_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf); -TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor); +TVM_FFI_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor); -TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil); +TVM_FFI_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil); -TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round); +TVM_FFI_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round); -TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint); +TVM_FFI_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint); -TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); +TVM_FFI_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); -TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_FFI_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); -TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); +TVM_FFI_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); // operator overloading, smarter than make -#define REGISTER_MAKE_BINARY_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ - return (Func(a, b, span)); \ +#define REGISTER_MAKE_BINARY_OP(Node, Func) \ + TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ + return (Func(a, b, span)); \ }) -#define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir." #Node).set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { \ - bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ - bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ - if (lhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } else if (rhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } else { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ - } \ +#define REGISTER_MAKE_BIT_OP(Node, Func) \ + TVM_FFI_REGISTER_GLOBAL("tir." #Node).set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { \ + bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ + bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ + if (lhs_is_int) { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } else if (rhs_is_int) { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } else { \ + *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + } \ }) REGISTER_MAKE_BINARY_OP(_OpAdd, add); @@ -1163,12 +1163,12 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor); REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, right_shift); -TVM_REGISTER_GLOBAL("tir._OpIfThenElse") +TVM_FFI_REGISTER_GLOBAL("tir._OpIfThenElse") .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { return if_then_else(cond, true_value, false_value, span); }); -TVM_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) { +TVM_FFI_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t, Span span) { return const_true(t.lanes(), span); }); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9cf96bbd6b68..ad890ecb404e 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -264,7 +264,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, * \return The loop domain */ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, - const Optional& high_exclusive = NullOpt, + const Optional& high_exclusive = std::nullopt, const runtime::StorageScope& extra_relax_scope = // runtime::StorageScope{runtime::StorageRank::kGlobal, ""}); @@ -762,7 +762,7 @@ class TensorizeInfo : public ObjectRef { * \param block_sref The target block to match against * \param desc_func The prim func describing the computation to be tensorized * \param allow_padding Whether to allow padding the block iters to match the intrinsic description - * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise + * \return TensorizeInfo structure if a valid mapping is found, std::nullopt otherwise */ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, @@ -809,10 +809,10 @@ class AutoTensorizeMappingInfo : public ObjectRef { * \param self The schedule state * \param block_sref The compute block for auto tensorization * \param desc_func The prim func describing the computation to be tensorized - * \return AutoTensorizeMappingInfo structure if a potential mapping is found, NullOpt otherwise. - * \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be tensorized. - * We will need to apply the suggested layout transformations and then match against the tensor - * intrinsics. + * \return AutoTensorizeMappingInfo structure if a potential mapping is found, std::nullopt + * otherwise. \note Returning a valid AutoTensorizeMappingInfo doesn't guarantee the block can be + * tensorized. We will need to apply the suggested layout transformations and then match against the + * tensor intrinsics. */ Optional GetAutoTensorizeMappingInfo(const ScheduleState& self, const StmtSRef& block_sref, diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 043bb92ade8d..99f4050a84e5 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -328,7 +328,7 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; } -TVM_REGISTER_GLOBAL("tir.schedule.IsReductionBlock") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsReductionBlock") .set_body_typed([](Schedule sch, BlockRV block_rv, BlockRV scope_block_rv) { return IsReductionBlock(sch->state(), sch->GetSRef(block_rv), sch->GetSRef(scope_block_rv)); }); @@ -614,7 +614,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, Block block, } void CheckAffineBinding(const ScheduleState& self, Block block) { - CheckPartialAffineBinding(self, std::move(block), NullOpt); + CheckPartialAffineBinding(self, std::move(block), std::nullopt); } void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { @@ -864,7 +864,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } -TVM_REGISTER_GLOBAL("tir.schedule.GetBlockRealize") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetBlockRealize") .set_body_typed([](Schedule sch, BlockRV block_rv) { return GetBlockRealize(sch->state(), sch->GetSRef(block_rv)); }); @@ -1267,7 +1267,7 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ } // If we cannot find the defining site block, it means that the buffer must be in the function's // buffer_map, which isn't an intermediate buffer. - return {NullOpt, false}; + return {std::nullopt, false}; } /******** SRef Tree Related ********/ @@ -1385,7 +1385,7 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri } // Case 2. Read index cannot be recognized as `var +/- const` // where `var` is a write index and `const` is an optional constant shift - Optional opt_const = NullOpt; + Optional opt_const = std::nullopt; const VarNode* var = static_cast(AnalyzeVarWithShift(dom->min, &opt_const).get()); if (var == nullptr || !var2idx.count(var)) { @@ -1483,7 +1483,7 @@ bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { return true; } -TVM_REGISTER_GLOBAL("tir.schedule.IsTrivialBinding") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsTrivialBinding") .set_body_typed([](Schedule sch, BlockRV block_rv) { return IsTrivialBinding(sch->state(), sch->GetSRef(block_rv)); }); @@ -1752,7 +1752,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, block_loops.push_back(loop); block_loop_vars.insert(loop->loop_var.get()); if (!analyzer.CanProve(loop->min == 0)) { - return NullOpt; + return std::nullopt; } } std::reverse(block_loops.begin(), block_loops.end()); @@ -1769,7 +1769,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, std::unordered_map block_index_to_padding; // padding of each block iter if necessary if (offset < 0) { - return NullOpt; + return std::nullopt; } const std::vector iter_types_block = GetBlockVarTypes(block_sref); @@ -1811,7 +1811,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, } } if (desc_loop == nullptr || desc_loop->extent.as() == nullptr) { - return NullOpt; + return std::nullopt; } const IntImmNode* int_desc_extent = desc_loop->extent.as(); @@ -1827,7 +1827,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, } } - if (!block_bind.defined()) return NullOpt; + if (!block_bind.defined()) return std::nullopt; // Step 3.3. Find the corresponding loop of the target block for (int i = 0, n = block_loops.size(); i < n; ++i) { @@ -1851,7 +1851,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, // Check divisibility if (!int_block_extent) { - return NullOpt; + return std::nullopt; } int64_t remainder = int_block_extent->value % int_desc_extent->value; if (remainder != 0) { @@ -1860,7 +1860,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, // divisible if padding is allowed. block_index_to_padding[current_block_ind] = int_desc_extent->value; } else { - return NullOpt; + return std::nullopt; } } @@ -1874,7 +1874,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, } if (!block_index_to_padding.empty()) { if (!allow_padding) { - return NullOpt; + return std::nullopt; } Array paddings; for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) { @@ -1891,8 +1891,8 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, return TensorizeInfo(ret); } -TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc); -TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) { return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); }); @@ -2103,12 +2103,12 @@ Optional GetAutoTensorizeMappingInfo(const tir::Schedu const tir::PrimFunc& desc_func) { AutoTensorizeComparator extractor(self->mod); if (!CheckAutoTensorizeApplicable(self, block_sref, desc_func, &extractor)) { - return NullOpt; + return std::nullopt; } arith::Analyzer analyzer; Array mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); if (mappings.empty()) { - return NullOpt; + return std::nullopt; } ObjectPtr ret = make_object(); ret->mappings = std::move(mappings); @@ -2121,19 +2121,20 @@ Optional GetAutoTensorizeMappingInfo(const tir::Schedu TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode); -TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo") .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); }); -TVM_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); -TVM_REGISTER_GLOBAL("tir.schedule.IsOutputBlock").set_body_typed([](Schedule sch, BlockRV block) { - auto state = sch->state(); - auto block_sref = sch->GetSRef(block); - return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); -}); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.HasBlock").set_body_typed(HasBlock); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.IsOutputBlock") + .set_body_typed([](Schedule sch, BlockRV block) { + auto state = sch->state(); + auto block_sref = sch->GetSRef(block); + return IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false)); + }); -TVM_REGISTER_GLOBAL("tir.schedule.GetLoopIterType") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetLoopIterType") .set_body_typed([](Schedule sch, LoopRV loop) -> String { IterVarType kind = GetLoopIterType(sch->GetSRef(loop)); if (kind == kDataPar) { diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index c31516234131..13b35582eefc 100644 --- a/src/tir/schedule/analysis/layout.cc +++ b/src/tir/schedule/analysis/layout.cc @@ -152,7 +152,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& /*index=*/f_flatten_index(indices), input_iters, predicate, /*check_level=*/arith::IterMapLevel::Surjective, analyzer); if (split_exprs.empty()) { - return NullOpt; + return std::nullopt; } // Step 4. Sort the order of the split expressions std::vector order(split_exprs.size(), 0); @@ -238,7 +238,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map); } -TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") .set_body_typed([](Buffer buffer, Array indices, Array loops, PrimExpr predicate) { arith::Analyzer analyzer; diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc index b29d13c3b9d3..4e3f04e0f389 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/tir/schedule/analysis/verify.cc @@ -68,7 +68,7 @@ class SRefTreeVerifier : public StmtVisitor { << GetRef(block) << "\nIts parent is supposed to be:\n" << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" << (sref->parent ? Optional(GetRef(sref->parent->stmt)) - : Optional(NullOpt)); + : Optional(std::nullopt)); ancestors_.push_back(sref.operator->()); if (block->init.defined()) { ++init_block_depth_; @@ -91,13 +91,13 @@ class SRefTreeVerifier : public StmtVisitor { << GetRef(loop); ++n_sref_visited_; const StmtSRef& sref = self_->stmt2ref.at(loop); - Optional stmt = NullOpt; + Optional stmt = std::nullopt; ICHECK(sref->parent == ancestors_.back()) << "InternalError: Parent information mismatch for ForNode:\n" << GetRef(loop) << "\nIts parent is supposed to be:\n" << GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" << (sref->parent ? Optional(GetRef(sref->parent->stmt)) - : Optional(NullOpt)); + : Optional(std::nullopt)); ancestors_.push_back(sref.operator->()); StmtVisitor::VisitStmt_(loop); ancestors_.pop_back(); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 48fac7d5eff8..edaccb51d687 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -36,7 +36,7 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa if (FindEntryFunc(mod, &gv) != nullptr) { n->func_working_on_ = gv; } else { - n->func_working_on_ = NullOpt; + n->func_working_on_ = std::nullopt; } return Schedule(std::move(n)); } @@ -913,7 +913,7 @@ void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intr /******** Schedule: Annotation ********/ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { - if (auto opt_str = ann_val.as()) { + if (auto opt_str = ann_val.try_cast()) { return *std::move(opt_str); } @@ -921,16 +921,16 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { return ann_val; } // prefer to return int/float literals for annotations - if (auto opt_intimm = ann_val.as()) { + if (auto opt_intimm = ann_val.try_cast()) { return (*std::move(opt_intimm))->value; } - if (auto opt_floatimm = ann_val.as()) { + if (auto opt_floatimm = ann_val.try_cast()) { return (*std::move(opt_floatimm))->value; } if (const auto* expr = ann_val.as()) { ICHECK(!expr->IsInstance()) - << "TypeError: runtime::String is expected, but gets StringImm"; + << "TypeError: String is expected, but gets StringImm"; auto res_expr = this->Get(GetRef(expr)); // prefer to return int/float literals for annotations if (auto opt_intimm = res_expr.as()) { @@ -941,7 +941,7 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { } return res_expr; } - if (const auto* arr = ann_val.as()) { + if (const auto* arr = ann_val.as()) { Array result; result.reserve(arr->size()); for (size_t i = 0; i < arr->size(); i++) { @@ -949,14 +949,14 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { } return std::move(result); } - if (const auto* dict = ann_val.as()) { + if (const auto* dict = ann_val.as()) { Map result; for (auto it = dict->begin(); it != dict->end(); ++it) { const auto& key = it->first; auto value = CheckAndGetAnnotationValue(it->second); if (const StringImmNode* imm = key.as()) { result.Set(imm->value, value); - } else if (auto opt_str = key.as()) { + } else if (auto opt_str = key.try_cast()) { result.Set(opt_str.value(), value); } else { LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm"; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index c57c6044f5f1..b00d2069ed17 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -63,7 +63,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } - Optional trace() const override { return NullOpt; } + Optional trace() const override { return std::nullopt; } Optional func_working_on() const final { return func_working_on_; } void WorkOn(const String& func_name) final; Schedule Copy() override; @@ -88,14 +88,14 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) override; + Optional decision = std::nullopt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = NullOpt) override; + Optional> decision = std::nullopt) override; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, - Optional> decision = NullOpt) override; + Optional> decision = std::nullopt) override; LoopRV SampleComputeLocation(const BlockRV& block_rv, - Optional decision = NullOpt) override; + Optional decision = std::nullopt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const Optional& func_name) override; Array GetLoops(const BlockRV& block_rv) override; diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index 8e1bce285d06..7fd43c9242f0 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) inputs.push_back(String("None")); } else if (obj.as() || obj.as()) { inputs.push_back(String("_")); - } else if (const auto* str_obj = obj.as()) { + } else if (const auto* str_obj = obj.as()) { inputs.push_back(String('"' + std::string(str_obj->data) + '"')); } else if (obj.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin) { inputs.push_back(obj); @@ -100,8 +100,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(InstructionNode); TVM_REGISTER_NODE_TYPE(InstructionKindNode); -TVM_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get); -TVM_REGISTER_GLOBAL("tir.schedule.Instruction") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.Instruction") .set_body_typed([](InstructionKind kind, Array inputs, Array attrs, Array outputs) -> Instruction { return Instruction(kind, inputs, attrs, outputs); diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 9f24ee1a3e8b..cbd5185ff8f1 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -190,7 +190,7 @@ class PythonAPICall { * \brief Constructor * \param method_name The name of the schedule API to be called */ - explicit PythonAPICall(String method_name) : method_name_(method_name), output_(NullOpt) {} + explicit PythonAPICall(String method_name) : method_name_(method_name), output_(std::nullopt) {} /*! \brief Add an integer input */ inline void Input(String arg_name, int arg); /*! \brief Add an integer input */ @@ -272,7 +272,7 @@ template struct _IsTVMArray : std::false_type {}; template -struct _IsTVMArray> : std::true_type {}; +struct _IsTVMArray> : std::true_type {}; template struct _IsSingleObject @@ -409,7 +409,7 @@ TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs(const } else if (is_single_obj) { return {rv}; } else if (is_array) { - return rv.cast>(); + return rv.cast>(); } } @@ -420,12 +420,12 @@ inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) { os << "None"; } else if (const auto* str = obj.as()) { os << str->data; - } else if (const auto opt_int_imm = obj.as()) { + } else if (const auto opt_int_imm = obj.try_cast()) { os << (*opt_int_imm)->value; - } else if (const auto opt_float_imm = obj.as()) { + } else if (const auto opt_float_imm = obj.try_cast()) { os.precision(17); os << (*opt_float_imm)->value; - } else if (const auto* array = obj.as()) { + } else if (const auto* array = obj.as()) { os << '['; bool is_first = true; for (Any e : *array) { @@ -437,7 +437,7 @@ inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) { AsPythonString(e, os); } os << ']'; - } else if (const auto* dict = obj.as()) { + } else if (const auto* dict = obj.as()) { os << '{'; bool is_first = true; std::vector> dict_items; diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 97154d189cb1..0e2a055d7afe 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -173,7 +173,7 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { private: static bool IsValidAnnotation(const Block& block, const Any& anno_value) { - return anno_value.as>>().has_value(); + return anno_value.try_cast>>().has_value(); } IRModule mod_; @@ -347,7 +347,7 @@ void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_i const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, BufferIndexType::kWrite); - DataType target_dtype(runtime::StringToDLDataType(dtype)); + DataType target_dtype(StringToDLDataType(dtype)); // Step 1. If `dtype` equals the original data type, just return. if (buffer->dtype == target_dtype) { diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 60ac7bf422e7..a5ec9d436b17 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -288,7 +288,7 @@ BlockRealize GenerateInner(bool is_write_reduction, Block block) { BlockNode* n = block.CopyOnWrite(); n->iter_vars = iter_vars; - n->init = NullOpt; + n->init = std::nullopt; if (is_write_reduction) { Array reads; reads.reserve(block->writes.size() + block->reads.size()); @@ -343,7 +343,7 @@ Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize /*writes=*/inner_block->writes, /*name_hint=*/block_name, /*body=*/block_init, - /*init=*/NullOpt)); + /*init=*/std::nullopt)); // Step 3. Create the loop nest on top of the block for (const ForNode* loop : loops) { bool is_init_loop = false; @@ -549,7 +549,7 @@ BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, block_subst->init.defined() // ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops, block_subst->name_hint + "_init") - : Optional(NullOpt))); + : Optional(std::nullopt))); } StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) { @@ -654,7 +654,7 @@ BlockRealize BlockizeBlocks(const ScheduleState& self, const Array& bl /*writes=*/UnionRegions(write_regions), /*name_hint=*/outer_block_name, /*body=*/SeqStmt(seq_body), - /*init=*/Optional(NullOpt))); + /*init=*/Optional(std::nullopt))); } class BlockizeRewriter : public StmtMutator { @@ -743,7 +743,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int bool preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed BlockRealize block_realize{nullptr}; - Optional old_block = NullOpt; + Optional old_block = std::nullopt; if (sref->stmt->IsInstance()) { block_realize = GetBlockRealize(self, sref); old_block = block_realize->block; diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/tir/schedule/primitive/cache_index.cc index 58bcd368c880..2e94b2050496 100644 --- a/src/tir/schedule/primitive/cache_index.cc +++ b/src/tir/schedule/primitive/cache_index.cc @@ -315,7 +315,7 @@ Array MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) { /*name_hint=*/"index_" + std::to_string(expr_index), /*body=*/ BufferStore(info->cache_buffer[expr_index], new_expr, access_indices), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/{}); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index a0536741c4f7..1b2a3a1cb478 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -88,7 +88,7 @@ struct CacheStageInfo { /*! \brief Return the buffer region related with the buffer */ Optional GetBufferRegionFromBuffer(const Array& buffer_regions, const Buffer& buffer) { - Optional res = NullOpt; + Optional res = std::nullopt; for (const auto& region : buffer_regions) { if (region->buffer.same_as(buffer)) { ICHECK(!res.defined()); @@ -204,7 +204,7 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI /*body=*/ BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices), write_access_indices), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*buf_doms=*/{}); @@ -304,7 +304,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, /*body=*/ BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices), write_access_indices), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/{}); @@ -503,7 +503,7 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { * \param scope_sref The scope block where the write is considered * \param buffer The queried buffer * \return The sref of the only writer of the input buffer in the given scope, - * or `NullOpt` if no block writes it in the scope. + * or `std::nullopt` if no block writes it in the scope. * \throw NotSingleWriteBlock if there are more than one interested block. */ Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref, @@ -511,7 +511,7 @@ Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_s BlockScope scope = self->GetBlockScope(scope_sref); auto it = scope->buffer_writers.find(buffer); if (it == scope->buffer_writers.end()) { - return NullOpt; + return std::nullopt; } else { const Array& block_srefs = it->second; ICHECK(!block_srefs.empty()); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 56d85318d7bc..0075fee18f4c 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -366,7 +366,7 @@ void RelaxBufferRegions(const Map& binding, runtime::StorageScope global_scope{runtime::StorageRank::kGlobal, ""}; // We cache the variable domains runtime::StorageRank previous_rank = runtime::StorageRank::kGlobal; - Optional> var_dom = NullOpt; + Optional> var_dom = std::nullopt; // Enumerate every buffer region for (const BufferRegion& buffer_region : buffer_regions) { const Buffer& buffer = buffer_region->buffer; diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc index 299bc9a62d5a..94db2070c709 100644 --- a/src/tir/schedule/primitive/decompose_padding.cc +++ b/src/tir/schedule/primitive/decompose_padding.cc @@ -531,7 +531,7 @@ bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref, /******** FFI ********/ -TVM_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding") .set_body_typed([](Schedule self, BlockRV block_rv, LoopRV loop_rv) { return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv)); }); diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 214390dafe2d..f1e035d92ef7 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -176,17 +176,17 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref /*iter_type=*/kThreadIndex, // /*thread_tag=*/thread_axis.value()); } else { - new_loop->thread_binding = NullOpt; + new_loop->thread_binding = std::nullopt; } self->Replace(loop_sref, For(new_loop), {}); } void Parallel(ScheduleState self, const StmtSRef& loop_sref) { - ParallelizeComputation(self, loop_sref, ForKind::kParallel, NullOpt); + ParallelizeComputation(self, loop_sref, ForKind::kParallel, std::nullopt); } void Vectorize(ScheduleState self, const StmtSRef& loop_sref) { - ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt); + ParallelizeComputation(self, loop_sref, ForKind::kVectorized, std::nullopt); } void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis) { @@ -197,7 +197,7 @@ void Unroll(ScheduleState self, const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); ObjectPtr new_loop = make_object(*loop); new_loop->kind = ForKind::kUnrolled; - new_loop->thread_binding = NullOpt; + new_loop->thread_binding = std::nullopt; self->Replace(loop_sref, For(new_loop), {}); } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 6a0bba5c811a..a455afe6b067 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -513,14 +513,14 @@ class TransformLayoutPlanner : private StmtExprVisitor { auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional { if (!info.contains_row_major_traversal || !pad_value.defined() || is_zero(padding_predicate)) { - return NullOpt; + return std::nullopt; } BufferStoreReplacer replacer(info, new_buffer, padding_predicate, inverse, pad_value, &new_block_to_old, analyzer); Stmt stmt = replacer(info.dependent_loopnest.back()->body); if (!replacer.is_all_stores_replaced()) { - return NullOpt; + return std::nullopt; } ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); @@ -707,7 +707,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { * * Used to fill the `WriteInfo::innermost_block_realize` field.. */ - Optional innermost_block_realize_{NullOpt}; + Optional innermost_block_realize_{std::nullopt}; /*! \brief The buffer to be replaced */ Buffer old_buffer_; @@ -1176,7 +1176,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); - Optional opt_inverse = NullOpt; + Optional opt_inverse = std::nullopt; PrimExpr padding_predicate = Bool(false); if (!assume_injective_transform) { std::tie(opt_inverse, padding_predicate) = [&]() { @@ -1209,7 +1209,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ GlobalVar g_var; const auto* old_func = GetRootPrimFunc(self->mod, scope_block, &g_var); IRModuleNode* new_mod = self->mod.CopyOnWrite(); - MapObj* new_map = new_mod->functions.CopyOnWrite(); + ffi::MapObj* new_map = new_mod->functions.CopyOnWrite(); Map new_buffer_map; for (auto [var, buffer] : old_func->buffer_map) { @@ -1533,10 +1533,10 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer GlobalVar g_var; GetRootPrimFunc(self->mod, scope_block, &g_var); IRModuleNode* new_mod = self->mod.CopyOnWrite(); - MapObj* new_map = new_mod->functions.CopyOnWrite(); + ffi::MapObj* new_map = new_mod->functions.CopyOnWrite(); PrimFunc ref_new_func = Downcast(std::move(new_map->at(g_var))); PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); - MapObj* new_buffer_map = new_func->buffer_map.CopyOnWrite(); + ffi::MapObj* new_buffer_map = new_func->buffer_map.CopyOnWrite(); for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) { if ((*it).second.same_as(old_buffer)) { (*it).second = new_buffer; diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index b8f4dfd58c2b..d112560a1fee 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -77,14 +77,14 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { /*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ class IterMapSimplifyBlockBinding : public StmtExprMutator { public: - explicit IterMapSimplifyBlockBinding(MapObj* opaque_blocks, Map loop_var2extent, + explicit IterMapSimplifyBlockBinding(ffi::MapObj* opaque_blocks, Map loop_var2extent, bool preserve_unit_iters) : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent), preserve_unit_iters_(preserve_unit_iters) {} - static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, MapObj* opaque_blocks, - bool preserve_unit_iters) { + static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, + ffi::MapObj* opaque_blocks, bool preserve_unit_iters) { Map loop_var2extent; for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(sref); @@ -132,7 +132,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { } /*! \brief The reuse mapping */ - MapObj* opaque_blocks_; + ffi::MapObj* opaque_blocks_; /*! \brief The range of loops */ Map loop_var2extent_; /*! \brief Internal analyzer */ @@ -164,7 +164,7 @@ class BlockPropertyError : public ScheduleError { throw BlockPropertyError(state_->mod, GetRef(op)); } Optional high_exclusive = - top_->parent ? GetRef(top_->parent) : Optional(NullOpt); + top_->parent ? GetRef(top_->parent) : Optional(std::nullopt); CheckPartialAffineBinding(state_, GetRef(op), high_exclusive); } } @@ -427,7 +427,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array if (v.same_as(loop->loop_var)) { return substitute_value; } else { - return NullOpt; + return std::nullopt; } }, &opaque_block_reuse)(std::move(new_stmt)); @@ -933,7 +933,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs, bool preser return substitute_value[i]; } } - return NullOpt; + return std::nullopt; }; new_stmt = SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt)); @@ -994,7 +994,7 @@ std::pair GetBoundaryOfReorderRange( // Case 1. If `v` corresponds to a block, stop traversal. if (v->stmt->IsInstance()) { if (scope_block_visited) { - throw LoopsNotAChainError(self->mod, NullOpt, + throw LoopsNotAChainError(self->mod, std::nullopt, LoopsNotAChainError::ProblemKind::kNotUnderAScope); } scope_block_visited = true; diff --git a/src/tir/schedule/primitive/pad_einsum.cc b/src/tir/schedule/primitive/pad_einsum.cc index 5bf59005bbe7..5b724b6bd295 100644 --- a/src/tir/schedule/primitive/pad_einsum.cc +++ b/src/tir/schedule/primitive/pad_einsum.cc @@ -27,7 +27,7 @@ namespace tir { /*! * \brief Check if buffer indices are all Vars and expr * \param buffer_access The BufferLoad or BufferStore - * \return The indices if the indices are all Vars, otherwise NullOpt + * \return The indices if the indices are all Vars, otherwise std::nullopt */ Optional> CheckTrivialBufferIndices(const Array& buffer_access) { Array indices; @@ -37,7 +37,7 @@ Optional> CheckTrivialBufferIndices(const Array& buffer_acc } const VarNode* var = index.as(); if (var == nullptr) { - return NullOpt; + return std::nullopt; } indices.push_back(GetRef(var)); } @@ -49,7 +49,7 @@ Optional> CheckTrivialBufferAccess(const BufferRegion& buffer_region) indices.reserve(buffer_region->region.size()); for (const Range& range : buffer_region->region) { if (!tir::is_one(range->extent)) { - return NullOpt; + return std::nullopt; } if (range->min->IsInstance()) { continue; @@ -57,7 +57,7 @@ Optional> CheckTrivialBufferAccess(const BufferRegion& buffer_region) if (const auto* var = range->min.as()) { indices.push_back(GetRef(var)); } else { - return NullOpt; + return std::nullopt; } } return indices; diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc index d482263bbcc8..9fdb322a4996 100644 --- a/src/tir/schedule/primitive/read_write_at.cc +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -311,7 +311,7 @@ struct ReadWriteAtImpl { /*writes=*/{BufferRegion(copy_to, domain)}, /*name_hint=*/name_hint, // /*body=*/std::move(stmt), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/annotations_)); diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 126832cc85fb..326d373d6e70 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -63,7 +63,7 @@ class DecomposeReductionBlockReplacer : public StmtMutator { if (block == old_reduction_block_.get()) { ObjectPtr p_new_block = CopyOnWrite(block); p_new_block->name_hint = p_new_block->name_hint + "_update"; - p_new_block->init = NullOpt; + p_new_block->init = std::nullopt; // Add write regions back to read regions in update block. Array new_reads; std::unordered_set read_bufs; @@ -412,7 +412,7 @@ struct ReducerRegistry { identity_getter = std::move(identity_getter) // ](Array values) -> Optional { if (static_cast(values.size()) != n_buffers) { - return NullOpt; + return std::nullopt; } Array lhs; Array rhs; @@ -747,7 +747,7 @@ class BaseBlockCreator { Optional CreateBlockInit(bool has_reduce_iter) { if (!has_reduce_iter) { - return NullOpt; + return std::nullopt; } Array inits; @@ -1344,7 +1344,7 @@ TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits); /******** FFI ********/ -TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.RegisterReducer") .set_body_typed([](int n_buffers, ffi::Function combiner_getter, ffi::Function identity_getter) { ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc index c01d6c568fcd..19375f7235fc 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -191,7 +191,7 @@ class RollingBufferInfoCollector { stride = 1; } else if (is_const_int(bound->min)) { // If the bound is an int, we can't roll over it - iter_var = NullOpt; + iter_var = std::nullopt; } else { // If all of the above matches fail, we're in unknown behaviour return false; @@ -202,9 +202,9 @@ class RollingBufferInfoCollector { bound_overlap = extent - stride; // Since Pass CompactBufferAllocation will be responsible for compacting the buffer // allocation region, there is no need to roll over the axis where the overlap is not - // positive, so reset iter_var to NullOpt. + // positive, so reset iter_var to std::nullopt. if (bound_overlap <= 0) { - iter_var = NullOpt; + iter_var = std::nullopt; } } bound_iter_vars.push_back(iter_var); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 29e40d4003d6..8dc1dcf8dbb2 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -43,35 +43,35 @@ TVM_REGISTER_NODE_TYPE(BlockRVNode); TVM_REGISTER_NODE_TYPE(LoopRVNode); TVM_REGISTER_OBJECT_TYPE(ScheduleNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // .set_body_method(&ScheduleNode::mod); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // .set_body_method(&ScheduleNode::func_working_on); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // .set_body_method(&ScheduleNode::Seed); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // .set_body_method(&ScheduleNode::ForkSeed); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") // +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") // .set_body_method(&ScheduleNode::WorkOn); /**************** (FFI) Constructor ****************/ -TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); -TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); -TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Concrete(mod, debug_mask, seed, static_cast(error_render_level), enable_check); }); -TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TracedSchedule") .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Traced(mod, seed, debug_mask, @@ -81,7 +81,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") /******** (FFI) Lookup random variables ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGet") .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef { if (auto loop_rv = obj.as()) { return self->Get(loop_rv.value()); @@ -96,7 +96,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") << ". Its value is: " << obj; throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") .set_body_typed([](Schedule self, ObjectRef obj) -> Optional { if (auto loop_rv = obj.as()) { return self->GetSRef(loop_rv.value()); @@ -110,7 +110,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef") LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") .set_body_typed([](Schedule self, ObjectRef obj) -> void { if (auto loop_rv = obj.as()) { return self->RemoveRV(loop_rv.value()); @@ -126,18 +126,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") }); /******** (FFI) Sampling ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") .set_body_method(&ScheduleNode::SamplePerfectTile); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePartitionedTile") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePartitionedTile") .set_body_method(&ScheduleNode::SamplePartitionedTile); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") .set_body_method(&ScheduleNode::SampleComputeLocation); /******** (FFI) Get blocks & loops ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock").set_body_method(&ScheduleNode::GetBlock); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops").set_body_method(&ScheduleNode::GetLoops); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock").set_body_method(&ScheduleNode::GetBlock); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops").set_body_method(&ScheduleNode::GetLoops); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") .set_body_typed([](Schedule self, ObjectRef rv) { if (auto block_rv = rv.as()) { return self->GetChildBlocks(block_rv.value()); @@ -149,22 +149,22 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") << ". Its value is: " << rv; throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") .set_body_method(&ScheduleNode::GetProducers); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") .set_body_method(&ScheduleNode::GetConsumers); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") .set_body_method(&ScheduleNode::GetOutputBlocks); /******** (FFI) Transform loops ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleLoopPartition") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleLoopPartition") .set_body_method(&ScheduleNode::LoopPartition); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder").set_body_method(&ScheduleNode::Reorder); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReorder").set_body_method(&ScheduleNode::Reorder); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReorderBlockIterVar") .set_body_method(&ScheduleNode::ReorderBlockIterVar); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { if (auto loop_rv = rv.as()) { return self->AddUnitLoop(loop_rv.value()); @@ -177,48 +177,50 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop") } }); /******** (FFI) Manipulate ForKind ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel").set_body_method(&ScheduleNode::Parallel); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize").set_body_method(&ScheduleNode::Vectorize); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleParallel").set_body_method(&ScheduleNode::Parallel); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize").set_body_method(&ScheduleNode::Vectorize); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method(&ScheduleNode::Bind); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method(&ScheduleNode::Unroll); /******** (FFI) Insert cache stages ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead").set_body_method(&ScheduleNode::CacheRead); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite").set_body_method(&ScheduleNode::CacheWrite); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead").set_body_method(&ScheduleNode::CacheRead); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") + .set_body_method(&ScheduleNode::CacheWrite); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") .set_body_method(&ScheduleNode::ReindexCacheRead); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") .set_body_method(&ScheduleNode::ReindexCacheWrite); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") .set_body_method(&ScheduleNode::CacheInplace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex").set_body_method(&ScheduleNode::CacheIndex); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex") + .set_body_method(&ScheduleNode::CacheIndex); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) { return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); }); /******** (FFI) Data movement ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt").set_body_method(&ScheduleNode::WriteAt); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt").set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt").set_body_method(&ScheduleNode::ComputeAt); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt").set_body_method(&ScheduleNode::ComputeAt); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt") .set_body_method(&ScheduleNode::ReverseComputeAt); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") .set_body_method(&ScheduleNode::ReverseComputeInline); /******** (FFI) Reduction ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction") .set_body_method(&ScheduleNode::DecomposeReduction); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor").set_body_method(&ScheduleNode::RFactor); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor").set_body_method(&ScheduleNode::RFactor); /******** (FFI) Block annotation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope").set_body_method(&ScheduleNode::SetScope); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope").set_body_method(&ScheduleNode::SetScope); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType") .set_body_method(&ScheduleNode::UnsafeSetDType); /******** (FFI) Blockize & Tensorize ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") .set_body_typed([](Schedule self, ObjectRef target, bool preserve_unit_iters) { if (auto loop_rv = target.as()) { return self->Blockize(loop_rv.value(), preserve_unit_iters); @@ -227,7 +229,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") } LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { if (auto block_rv = rv.as()) { self->Tensorize(block_rv.value(), intrin, preserve_unit_iters); @@ -240,7 +242,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") }); /******** (FFI) Annotation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, const Any& ann_val) { if (auto block_rv = rv.as()) { return self->Annotate(block_rv.value(), ann_key, ann_val); @@ -252,7 +254,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") << ". Its value is: " << rv; throw; }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) { if (auto block_rv = rv.as()) { return self->Unannotate(block_rv.value(), ann_key); @@ -266,7 +268,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") }); /******** (FFI) Layout transformation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map, const Optional& pad_value, bool assume_injective_transform) { @@ -274,9 +276,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") static_cast(buffer_index_type), index_map, pad_value, assume_injective_transform); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") .set_body_method(&ScheduleNode::TransformBlockLayout); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const Array& axis_separators) { return self->SetAxisSeparator( @@ -284,19 +286,19 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator") }); /******** (FFI) Padding decomposition ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding") .set_body_method(&ScheduleNode::DecomposePadding); -TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum").set_body_method(&ScheduleNode::PadEinsum); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum").set_body_method(&ScheduleNode::PadEinsum); /******** (FFI) Buffer transformation ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer") .set_body_method(&ScheduleNode::RollingBuffer); /******** (FFI) Misc ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess") .set_body_method(&ScheduleNode::UnsafeHideBufferAccess); /******** (FFI) Annotate buffer access ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type, const IndexMap& index_map) { return self->AnnotateBufferAccess(block_rv, buffer_index, diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 43a3c9008f0d..f2c4b56121c9 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -750,7 +750,7 @@ class ChildReplacer : private StmtMutator { int n = static_cast(op->seq.size()); if (0 <= i && i < n) { const Stmt& stmt = op->seq[i]; - Optional new_stmt = NullOpt; + Optional new_stmt = std::nullopt; const StmtNode* src_stmt = this->src_stmt_; // `stmt` can be For or BlockRealize // `src_stmt` can be For or Block @@ -945,7 +945,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ } // Ensure the uniqueness of `this->mod` and `this->mod->functions` IRModuleNode* new_mod = this->mod.CopyOnWrite(); - MapObj* new_map = new_mod->functions.CopyOnWrite(); + ffi::MapObj* new_map = new_mod->functions.CopyOnWrite(); // Move out the PrimFunc where the sref belong while ensuring uniqueness PrimFunc ref_new_func = Downcast(std::move(new_map->at(g_var))); ICHECK(ref_new_func.get() == g_func); @@ -1012,20 +1012,20 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(ScheduleStateNode); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleState") .set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState { return ScheduleState(mod, debug_mask, enable_check); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") .set_body_method(&ScheduleStateNode::GetBlockScope); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace") .set_body_method(&ScheduleStateNode::Replace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef") .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional { auto it = self->stmt2ref.find(stmt.get()); - return it != self->stmt2ref.end() ? it->second : Optional(NullOpt); + return it != self->stmt2ref.end() ? it->second : Optional(std::nullopt); }); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags").set_body_typed(GetCachedFlags); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags").set_body_typed(GetCachedFlags); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index a666a1d5902e..1992f5ae8a69 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -58,7 +58,7 @@ Array TranslateInputRVs(const Array& inputs, auto f_subst_with_rv_map = [&rv_map](const Var& var) -> Optional { auto it = rv_map.find(var.get()); if (it == rv_map.end()) { - return NullOpt; + return std::nullopt; } const Object* dst = it->second; ICHECK(dst->IsInstance()) @@ -78,7 +78,7 @@ Array TranslateInputRVs(const Array& inputs, auto it = rv_map.find(input.as()); ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; result.push_back(GetRef(it->second)); - } else if (auto expr = input.as()) { // RV: Expr + } else if (auto expr = input.try_cast()) { // RV: Expr result.push_back(Substitute(expr.value(), f_subst_with_rv_map)); } else if (auto index_map = input.as()) { result.push_back(Substitute(index_map.value(), f_subst_with_rv_map)); @@ -120,16 +120,16 @@ Array TranslateInputRVs( LOG(FATAL) << "IndexError: Random variable is not defined " << input; throw; } - } else if (const auto* str_obj = input.as()) { + } else if (const auto* str_obj = input.as()) { // Case 2. string => "content" results.push_back(String('"' + std::string(str_obj->data) + '"')); } else if (input.as() || input.as()) { // Case 3. integer or floating-point number results.push_back(input); - } else if (input.as()) { + } else if (input.as()) { // Case 4: array results.push_back(TranslateInputRVs(Downcast>(Any(input)), rv_names)); - } else if (input.as()) { + } else if (input.as()) { // Case 5: dict results.push_back(input); } else if (input.as()) { @@ -139,7 +139,7 @@ Array TranslateInputRVs( if (auto it = rv_names.find(var); it != rv_names.end()) { return it->second; } - return NullOpt; + return std::nullopt; }); results.push_back(index_map); } else { @@ -166,12 +166,12 @@ Array TranslateInputRVs(const Array& inputs, continue; } // Case 4. array - if (input.as()) { + if (input.as()) { results.push_back(TranslateInputRVs(Downcast>(input), named_rvs)); continue; } // Case 5. dict - if (input.as()) { + if (input.as()) { results.push_back(input); continue; } @@ -190,7 +190,7 @@ Array TranslateInputRVs(const Array& inputs, if (it != named_rvs.end()) { return Downcast(it->second); } - return NullOpt; + return std::nullopt; }); results.push_back(index_map); continue; @@ -236,7 +236,7 @@ Array TranslateAddOutputRVs( ICHECK(!rv_names->count(output.cast())) << "ValueError: The random variable has been produced once: " << rv_names->at(output.cast()); - String result{ObjectPtr{nullptr}}; + String result{ffi::ObjectPtr{nullptr}}; if (output == nullptr) { result = "_"; } else if (output.as()) { @@ -280,7 +280,7 @@ void TraceNode::Append(Instruction inst, Any decision) { Optional TraceNode::Pop() { if (insts.empty()) { - return NullOpt; + return std::nullopt; } Instruction inst = insts.back(); insts.pop_back(); @@ -359,7 +359,7 @@ Array TraceNode::AsPython(bool remove_postproc) const { Array attrs; attrs.reserve(inst->attrs.size()); for (const Any& obj : inst->attrs) { - if (const auto* str = obj.as()) { + if (const auto* str = obj.as()) { attrs.push_back(String('"' + std::string(str->data) + '"')); } else { attrs.push_back(obj); @@ -379,10 +379,10 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Array json_decisions{nullptr}; // Parse `json` into `json_insts` and `json_decisions` try { - const ArrayObj* arr = json.as(); + const ffi::ArrayObj* arr = json.as(); ICHECK(arr && arr->size() == 2); - const auto* arr0 = arr->at(0).as(); - const auto* arr1 = arr->at(1).as(); + const auto* arr0 = arr->at(0).as(); + const auto* arr1 = arr->at(1).as(); ICHECK(arr0 && arr1); json_insts = GetRef>(arr0); json_decisions = GetRef>(arr1); @@ -398,9 +398,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { int index = -1; Any decision{nullptr}; try { - const ArrayObj* arr = decision_entry.as(); + const ffi::ArrayObj* arr = decision_entry.as(); ICHECK(arr && arr->size() == 2); - auto arr0 = arr->at(0).as(); + auto arr0 = arr->at(0).try_cast(); ICHECK(arr0); index = arr0.value()->value; decision = arr->at(1); @@ -422,9 +422,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Array outputs{ObjectPtr{nullptr}}; // Parse the entry try { - const auto* arr = inst_entry.as(); + const auto* arr = inst_entry.as(); ICHECK(arr && arr->size() == 4); - const auto* arr0 = arr->at(0).as(); + const auto* arr0 = arr->at(0).as(); kind = InstructionKind::Get(arr0->data); inputs = arr->at(1).cast>(); attrs = arr->at(2).cast>(); @@ -563,13 +563,13 @@ TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); /**************** FFI ****************/ TVM_REGISTER_NODE_TYPE(TraceNode); -TVM_REGISTER_GLOBAL("tir.schedule.Trace") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.Trace") .set_body_typed([](Optional> insts, Optional> decisions) { return Trace(insts.value_or(Array()), decisions.value_or({})); }); -TVM_REGISTER_GLOBAL("tir.schedule.TraceGetDecision").set_body_method(&TraceNode::GetDecision); -TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceGetDecision").set_body_method(&TraceNode::GetDecision); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAppend") .set_body_typed([](Trace self, Instruction inst, Optional decision) { if (decision.defined()) { return self->Append(inst, decision.value()); @@ -577,14 +577,14 @@ TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend") return self->Append(inst); } }); -TVM_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method(&TraceNode::Pop); -TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method(&TraceNode::Pop); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule") .set_body_method(&TraceNode::ApplyToSchedule); -TVM_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method(&TraceNode::AsJSON); -TVM_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method(&TraceNode::AsPython); -TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision").set_body_method(&TraceNode::WithDecision); -TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); -TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method(&TraceNode::AsJSON); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method(&TraceNode::AsPython); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceWithDecision").set_body_method(&TraceNode::WithDecision); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method(&TraceNode::Simplified); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule") .set_body_typed(Trace::ApplyJSONToSchedule); } // namespace tir diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 52a012b2a321..d3e77e0e3b84 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -35,7 +35,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand if (FindEntryFunc(mod, &gv) != nullptr) { n->func_working_on_ = gv; } else { - n->func_working_on_ = NullOpt; + n->func_working_on_ = std::nullopt; } return Schedule(std::move(n)); } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 30eeccfd85c7..777f31a57bea 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -48,13 +48,14 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) final; + Optional decision = std::nullopt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, - Optional> decision = NullOpt) final; + Optional> decision = std::nullopt) final; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, int innerpart_factor, - Optional> decision = NullOpt) final; - LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional decision = NullOpt) final; + Optional> decision = std::nullopt) final; + LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = std::nullopt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const Optional& func_name) final; Array GetLoops(const BlockRV& block_rv) final; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 69686bac55fb..c0929e01a8ad 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -323,7 +323,7 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block Optional opt_tensorize_info = GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name).value()->desc, allow_padding); - if (!opt_tensorize_info) return NullOpt; + if (!opt_tensorize_info) return std::nullopt; const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); if (info->block_iter_paddings.defined()) { // We have to track whether each producer or consumer is padded. @@ -413,9 +413,9 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block int64_t total = int_block_extent->value; int64_t inner = int_desc_extent->value; ICHECK_EQ(total % inner, 0); - // Do the split. Leave the outer extent as NullOpt (unspecified) so that the split factors + // Do the split. Leave the outer extent as std::nullopt (unspecified) so that the split factors // can be used for different extents (needed during tuning). - Array split = sch->Split(loop2rv.at(block_loop_sref), {NullOpt, Integer(inner)}); + Array split = sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, Integer(inner)}); ICHECK_EQ(split.size(), 2); inner_loops.insert(sch->GetSRef(split[1]).operator->()); // The inner split will be reordered to the loop domain that is tensorized @@ -439,7 +439,7 @@ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block return reorder_suffix[0]; } -TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin); /******** BlockBufferAccessSimplifier ********/ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array* old_access_regions) { @@ -508,15 +508,15 @@ Optional NormalizePrimFunc(Schedule sch) { Array binds = GetBlockRealize(sch->state(), block_sref)->iter_values; if (loops.size() == 0) continue; if (loops.size() != binds.size()) { - return NullOpt; + return std::nullopt; } for (int i = 0, n = loops.size(); i < n; ++i) { const ForNode* loop = TVM_SREF_TO_FOR(loops[i]); if (binds[i].get() != loop->loop_var.get()) { - return NullOpt; + return std::nullopt; } if (!is_zero(loop->min)) { - return NullOpt; + return std::nullopt; } } } @@ -557,7 +557,7 @@ Optional NormalizePrimFunc(Schedule sch) { return Array{leaf_blocks, block_loops, block_iters, block_is_reduction}; } -TVM_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index c597997414a0..73d6a0d85371 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -212,7 +212,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ * TensorIntrin.register(...) beforehand * \param allow_padding Whether to allow padding when tiling * \return LoopRV corresponding to the outermost loop of a - * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found + * block tiled according to the given intrin, std::nullopt if a valid loop mapping is not found */ Optional TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, const String& intrin_name, bool allow_padding = false); diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 1910bed28796..deedfd6f68dc 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -218,12 +218,12 @@ inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { * \brief Check if an expression consists of a single variable, * or a variable plus/minus an constant integer shift * \param expr The expression to be checked - * \return The single variable in the expression, or NullOpt if the expression is neither a variable - * or a constant shift from a variable + * \return The single variable in the expression, or std::nullopt if the expression is neither a + * variable or a constant shift from a variable */ inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* constant) { if (const auto* var = expr.as()) { - *constant = NullOpt; + *constant = std::nullopt; return GetRef(var); } arith::PVar var; @@ -239,7 +239,7 @@ inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* *constant = IntImm(result->dtype, -result->value); return var.Eval(); } - return NullOpt; + return std::nullopt; } /******** Annotation ********/ @@ -249,7 +249,7 @@ inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* * \tparam TObjectRef The type of the annotation value * \param sref The sref to the block or the for loop * \param ann_key The annotation key to be looked up - * \return NullOpt if not found; otherwise the annotation value + * \return std::nullopt if not found; otherwise the annotation value */ template inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) { @@ -259,7 +259,7 @@ inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) return Downcast(ann.second); } } - return NullOpt; + return std::nullopt; } /*! @@ -267,7 +267,7 @@ inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) * \tparam TObjectRef The type of the annotation value * \param sref The sref to the block or the for loop * \param ann_key The annotation key to be looked up - * \return NullOpt if not found; otherwise the annotation value + * \return std::nullopt if not found; otherwise the annotation value */ template inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) { diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc index a81af7d7805b..f8adcf4f5010 100644 --- a/src/tir/transforms/annotate_device_regions.cc +++ b/src/tir/transforms/annotate_device_regions.cc @@ -21,8 +21,8 @@ * \file annotate_device_regions.cc * \brief Split device function from host. */ +#include #include -#include #include #include #include @@ -74,7 +74,8 @@ Pass AnnotateDeviceRegions() { return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); } -TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions); +TVM_FFI_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions") + .set_body_typed(AnnotateDeviceRegions); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index 66d0fb61661b..06d596adb44d 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -23,8 +23,8 @@ * Re-write data access to enable memory sharing when possible. */ #include +#include #include -#include #include #include #include diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 616b47f29403..15728e846224 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -23,7 +23,7 @@ // Instrument checkers for out of the bounds access. #include -#include +#include #include #include #include @@ -255,7 +255,7 @@ Pass InstrumentBoundCheckers() { return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") +TVM_FFI_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") .set_body_typed(InstrumentBoundCheckers); } // namespace transform diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 18e568c83e74..2ff7c03c6287 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -22,9 +22,9 @@ * * \file combine_context_call.cc */ +#include #include #include -#include #include #include #include @@ -112,7 +112,7 @@ Pass CombineContextCall() { return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); +TVM_FFI_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index b6c52ec1a3be..42409efb0bd1 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -29,9 +29,9 @@ #include "common_subexpr_elim.h" +#include +#include #include // For the class Pass and the class PassContext -#include -#include #include // For the analysis which gives the size of an expr #include #include @@ -655,7 +655,7 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) { } // The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it -TVM_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR").set_body_typed(CommonSubexprElimTIR); +TVM_FFI_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR").set_body_typed(CommonSubexprElimTIR); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index ba101ce4e70f..ce8aef4587dd 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -26,8 +26,8 @@ #include "common_subexpr_elim_tools.h" #include // For the arith::Analyzer::Simplify() method simplifying terms -#include // For the class Pass and the class PassContext -#include +#include +#include // For the class Pass and the class PassContext #include // For the ExprDeepEqual analysis #include #include diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 841f1d65a6f6..58014e6a406d 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -26,7 +26,7 @@ #ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_ #define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_ -#include +#include #include // For the ExprDeepEqual analysis #include #include diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index b18c07651000..c5c6accf221a 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -537,7 +537,7 @@ struct BufferAllocInfo { std::vector dim_aligns; /*! * \brief The reallocated buffer with minimal size. - * \note The value if NullOpt if the buffer do not need reallocate (e.g parameter buffer). + * \note The value if std::nullopt if the buffer do not need reallocate (e.g parameter buffer). */ Buffer new_buffer; }; @@ -756,7 +756,7 @@ Pass CompactBufferAllocation(bool is_strict) { return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation") +TVM_FFI_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation") .set_body_typed(CompactBufferAllocation); } // namespace transform diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc index ab8d98a00e0e..1b29cea2f27a 100644 --- a/src/tir/transforms/convert_blocks_to_opaque.cc +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -122,7 +122,8 @@ Pass ConvertBlocksToOpaque() { return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque").set_body_typed(ConvertBlocksToOpaque); +TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque") + .set_body_typed(ConvertBlocksToOpaque); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc index d01ae8a45113..4c992163df04 100644 --- a/src/tir/transforms/convert_for_loops_serial.cc +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -66,7 +66,7 @@ Pass ConvertForLoopsToSerial() { return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") .set_body_typed(ConvertForLoopsToSerial); } // namespace transform diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc index 5034a858130d..3b382850559a 100644 --- a/src/tir/transforms/decorate_device_scope.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -20,7 +20,7 @@ /*! * \file decorate_device_scope.cc */ -#include +#include #include #include #include @@ -44,7 +44,7 @@ Pass DecorateDeviceScope() { return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); } -TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); +TVM_FFI_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/default_gpu_schedule.cc b/src/tir/transforms/default_gpu_schedule.cc index 249003d2ad6c..398b00092d08 100644 --- a/src/tir/transforms/default_gpu_schedule.cc +++ b/src/tir/transforms/default_gpu_schedule.cc @@ -66,15 +66,15 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& block, int64_t max_thread } // schedule the fused loop if (product > max_thread_per_block * max_threadblocks) { - Array splits = - sch->Split(fused, - /*factors=*/{NullOpt, Integer(max_threadblocks), Integer(max_thread_per_block)}); + Array splits = sch->Split( + fused, + /*factors=*/{std::nullopt, Integer(max_threadblocks), Integer(max_thread_per_block)}); sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); sch->Bind(splits[2], "threadIdx.x"); } else { - Array splits = - sch->Split(fused, /*factors=*/{NullOpt, Integer(std::min(product, max_thread_per_block))}); + Array splits = sch->Split( + fused, /*factors=*/{std::nullopt, Integer(std::min(product, max_thread_per_block))}); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); } @@ -162,7 +162,7 @@ Pass DefaultGPUSchedule() { /*required=*/{}); } -TVM_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule); +TVM_FFI_REGISTER_GLOBAL("tir.transform.DefaultGPUSchedule").set_body_typed(DefaultGPUSchedule); } // namespace transform diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 052f7cf948cb..509efb8d06fd 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -25,8 +25,8 @@ * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0022-tir-non-scalar-constants.md */ #include +#include #include -#include #include #include "ir_utils.h" @@ -105,7 +105,7 @@ tvm::transform::Pass ExtractPrimFuncConstants() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ExtractPrimFuncConstants", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ExtractPrimFuncConstants") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ExtractPrimFuncConstants") .set_body_typed(ExtractPrimFuncConstants); } // namespace transform diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index a6da7f7fc407..5ea0a60ea2a8 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -279,7 +279,7 @@ Pass FlattenBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc index 86f839c4f5e2..bd33e564e5c2 100644 --- a/src/tir/transforms/force_narrow_index_to_i32.cc +++ b/src/tir/transforms/force_narrow_index_to_i32.cc @@ -86,7 +86,7 @@ Pass ForceNarrowIndexToInt32() { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ForceNarrowIndexToInt32") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ForceNarrowIndexToInt32") .set_body_typed(ForceNarrowIndexToInt32); } // namespace transform diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc index f0fc90ee3244..d1c2155fd066 100644 --- a/src/tir/transforms/hoist_expression.cc +++ b/src/tir/transforms/hoist_expression.cc @@ -21,7 +21,7 @@ * \file hoist_expression.cc */ #include -#include +#include #include #include #include @@ -552,7 +552,7 @@ Pass HoistExpression() { "tir.HoistExpression"); } -TVM_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression); +TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression); Pass HoistIfThenElse() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -587,7 +587,7 @@ Pass HoistIfThenElse() { "tir.HoistIfThenElse"); } -TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); +TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse); Pass HoistIfThenElseBasic() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -607,7 +607,7 @@ Pass HoistIfThenElseBasic() { "tir.HoistIfThenElseBasic"); } -TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); +TVM_FFI_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic); } // namespace transform diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 52e4d44b615a..6b992ce1f999 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -21,7 +21,7 @@ * \brief Inject double buffering optimization for data fetch. * \file inject_double_buffer.cc */ -#include +#include #include #include #include @@ -319,7 +319,7 @@ Pass InjectDoubleBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); } // namespace transform diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index fd1e4a54e473..00e29061ba3a 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -103,12 +103,12 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } static bool CheckAnnotation(const Any& annotation) { - if (auto* node = annotation.as()) { + if (auto* node = annotation.as()) { // Support string annotation for backward compatibility return GetRef(node) != ""; } else if (auto* node = annotation.as()) { return node->value != 0; - } else if (auto opt_val = annotation.as()) { + } else if (auto opt_val = annotation.try_cast()) { return *opt_val != 0; } else { LOG(FATAL) << "Invalid permuted layout annotation: " << annotation; @@ -295,7 +295,7 @@ Pass InjectPermutedLayout() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPermutedLayout", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectPermutedLayout").set_body_typed(InjectPermutedLayout); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPermutedLayout").set_body_typed(InjectPermutedLayout); } // namespace transform diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 5d23e854be02..04bcecac36b0 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -199,7 +199,7 @@ Pass InjectPTXAsyncCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy); } // namespace transform diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc index b4c398bd17eb..c3a6cf50b828 100644 --- a/src/tir/transforms/inject_ptx_ldg32.cc +++ b/src/tir/transforms/inject_ptx_ldg32.cc @@ -19,7 +19,7 @@ #include #include -#include +#include #include #include #include @@ -123,7 +123,7 @@ Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) { // The pass can now be invoked via the pass infrastructure, but we also add a // Python binding for it -TVM_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/inject_rolling_buffer.cc b/src/tir/transforms/inject_rolling_buffer.cc index 03f94e3e9139..ed35bdb0655f 100644 --- a/src/tir/transforms/inject_rolling_buffer.cc +++ b/src/tir/transforms/inject_rolling_buffer.cc @@ -34,7 +34,7 @@ https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836 */ #include -#include +#include #include #include @@ -315,7 +315,7 @@ Pass InjectRollingBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectRollingBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer").set_body_typed(InjectRollingBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectRollingBuffer").set_body_typed(InjectRollingBuffer); } // namespace transform diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 5e003d173d65..4f137619ea7e 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -364,7 +364,7 @@ class PipelineRewriter : public StmtExprMutator { // introduce extra lowerbound when the loop length is smaller than num stages // to ensure the epilogue interval do not overlap the prologue interval. PrimExpr epigogue_start = pipeline_loop_->min + pipeline_loop_->extent; - Optional extra_epilogue_lower_bound = NullOpt; + Optional extra_epilogue_lower_bound = std::nullopt; if (max_stage_ > 1 && !analyzer_.CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) { if (is_const_int(epigogue_start)) { epigogue_start = max(epigogue_start, pipeline_loop_->min + max_stage_); @@ -811,7 +811,7 @@ class PipelineRewriter : public StmtExprMutator { * \return The result loop. */ Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, - Optional extra_loop_lower_bound = NullOpt) { + Optional extra_loop_lower_bound = std::nullopt) { PrimExpr new_loop_var; PrimExpr extent = end - start; @@ -941,7 +941,7 @@ class PipelineRewriter : public StmtExprMutator { if (!is_unit_loop) { new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop), - NullOpt, preserved_annotations_); + std::nullopt, preserved_annotations_); } // Update producer heads in the global async states. @@ -957,7 +957,7 @@ class PipelineRewriter : public StmtExprMutator { async_states[stage_id].producer_head.value() + extent; } else { // Otherwise, invalidate the global producer head - async_states[stage_id].producer_head = NullOpt; + async_states[stage_id].producer_head = std::nullopt; } } @@ -1259,7 +1259,8 @@ Pass InjectSoftwarePipeline() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(InjectSoftwarePipeline); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline") + .set_body_typed(InjectSoftwarePipeline); } // namespace transform diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index d9fc74f8ad18..334d9594616d 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -20,7 +20,7 @@ /*! * \file inject_virtual_thread.cc */ -#include +#include #include #include #include @@ -343,7 +343,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { visit_touched_var_ = false; ICHECK_EQ(max_loop_depth_, 0); Stmt then_case = this->VisitStmt(op->then_case); - Optional else_case = NullOpt; + Optional else_case = std::nullopt; if (op->else_case) { int temp = max_loop_depth_; max_loop_depth_ = 0; @@ -525,7 +525,7 @@ Pass InjectVirtualThread() { return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); } // namespace transform diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index 14672f568549..eae2e29ef686 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -21,7 +21,7 @@ * \file inline_private_functions.cc * \brief Inline private functions to their callsite */ -#include +#include #include #include #include @@ -155,7 +155,7 @@ class PrimFuncInliner : StmtExprMutator { PrimFunc VisitFunc(PrimFunc func) { current_target_ = func->GetAttr(tvm::attr::kTarget); auto new_body = VisitStmt(func->body); - current_target_ = NullOpt; + current_target_ = std::nullopt; if (!new_body.same_as(func->body)) { func.CopyOnWrite()->body = new_body; @@ -177,13 +177,13 @@ class PrimFuncInliner : StmtExprMutator { Optional GetInlinedFunction(const EvaluateNode* eval) { auto call = eval->value.as(); - if (!call) return NullOpt; + if (!call) return std::nullopt; auto gvar = call->op.as(); - if (!gvar) return NullOpt; + if (!gvar) return std::nullopt; auto opt_callee = inlinable_funcs_.Get(gvar.value()); - if (!opt_callee) return NullOpt; + if (!opt_callee) return std::nullopt; auto callee = opt_callee.value(); bool is_same_target = [&]() -> bool { @@ -194,7 +194,7 @@ class PrimFuncInliner : StmtExprMutator { return true; } }(); - if (!is_same_target) return NullOpt; + if (!is_same_target) return std::nullopt; Stmt inlined = InlineArguments(gvar.value(), callee, call->args); return VisitStmt(inlined); @@ -252,7 +252,7 @@ class PrimFuncInliner : StmtExprMutator { */ PSet removable_funcs_; - Optional current_target_ = NullOpt; + Optional current_target_ = std::nullopt; }; } // namespace @@ -292,7 +292,8 @@ Pass InlinePrivateFunctions() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions").set_body_typed(InlinePrivateFunctions); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions") + .set_body_typed(InlinePrivateFunctions); } // namespace transform diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 3664abcf5612..0017e97beb88 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -600,7 +600,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { arith::Analyzer analyzer; PrimExpr condition = analyzer.Simplify(condition_); if (is_const_int(condition)) { - return NullOpt; + return std::nullopt; } Array equations; Array vars; @@ -644,7 +644,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { }; fvisit(condition); if (equations.empty() || vars.empty()) { - return NullOpt; + return std::nullopt; } // build dom ranges for related vars Map ranges; @@ -667,7 +667,7 @@ Optional ConditionalBoundsContext::TrySolveCondition() { arith::IntConstraints constraint(vars, ranges, equations); arith::IntConstraints result = arith::SolveInequalitiesToRange(constraint); if (!result->relations.empty()) { - return NullOpt; + return std::nullopt; } return std::move(result); } @@ -850,7 +850,7 @@ Pass ConvertSSA() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ConvertSSA").set_body_typed(ConvertSSA); +TVM_FFI_REGISTER_GLOBAL("tir.transform.ConvertSSA").set_body_typed(ConvertSSA); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index 7a9a1f59977b..b30a47c84fe9 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -183,7 +183,7 @@ Pass LiftThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.LiftThreadBinding", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LiftThreadBinding").set_body_typed(LiftThreadBinding); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LiftThreadBinding").set_body_typed(LiftThreadBinding); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 3a0b80921ff9..1adc0f6b043a 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include #include #include #include @@ -810,7 +810,7 @@ Pass LoopPartition() { return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); } // namespace transform diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index e1ec0f1572c7..c3358e1c9207 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -175,7 +175,7 @@ Pass LowerAsyncDMA() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 325d8e5bb578..31d7f91d74e9 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -312,7 +312,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // }; Array ct_buffer_regions = f_create_buffer_regions(ct_buffers); - Optional> it_buffer_regions = NullOpt; + Optional> it_buffer_regions = std::nullopt; if (it_buffers.defined()) { it_buffer_regions = f_create_buffer_regions(it_buffers.value()); } @@ -342,7 +342,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // } // Stmt 2: do in-thread reduction { - Optional new_realize = NullOpt; + Optional new_realize = std::nullopt; // If need to generate in-thread reduction, // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize // otherwise, directly remove given BlockRealize @@ -353,7 +353,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body)); - new_block->init = NullOpt; + new_block->init = std::nullopt; ObjectPtr n = make_object(*realize); n->block = Block(new_block); new_realize = BlockRealize(n); @@ -673,9 +673,9 @@ class CrossThreadReductionTransformer : public StmtMutator { Array combiner_lhs{nullptr}; Array combiner_rhs{nullptr}; std::tie(init_values, updates) = - GetInitValuesAndUpdatesFromReductionBlock(NullOpt, GetRef(block)); + GetInitValuesAndUpdatesFromReductionBlock(std::nullopt, GetRef(block)); std::tie(reducer, combiner_lhs, combiner_rhs) = - GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates); + GetReducerAndCombinerLhsRhs(std::nullopt, init_values, updates); // Condition 4. All reduction buffers should be all local or all non-local. int is_local_buf = -1; @@ -815,7 +815,7 @@ class CrossThreadReductionTransformer : public StmtMutator { Array& new_buffers = block2new_buffers_[block_stack_.back()]; Array ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true); new_buffers.insert(new_buffers.end(), ct_buffers.begin(), ct_buffers.end()); - Optional> it_buffers = NullOpt; + Optional> it_buffers = std::nullopt; if (need_in_thread_reduction) { it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false); new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end()); @@ -934,7 +934,7 @@ Pass LowerCrossThreadReduction() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction") +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction") .set_body_typed(LowerCrossThreadReduction); } // namespace transform diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 19f529103b5c..dbc529cfeabd 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -21,7 +21,7 @@ * \brief Pass for lowering custom datatypes */ -#include +#include #include #include #include @@ -249,7 +249,7 @@ Pass LowerCustomDatatypes() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); } // namespace transform diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 6eb196d2520e..2ca0e6d92f68 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -21,8 +21,8 @@ * \file lower_device_kernel_launch.cc * \brief Split device function from host. */ +#include #include -#include #include #include #include @@ -143,7 +143,7 @@ class DeviceInfoCollector : public StmtVisitor { // The extent of each thread Map thread_extent; // The amount of dynamic shared memory used - Optional dyn_shmem_size{NullOpt}; + Optional dyn_shmem_size{std::nullopt}; }; class ReturnRemover : public StmtExprMutator { @@ -195,7 +195,7 @@ class DeviceKernelMutator : public StmtExprMutator { func.CopyOnWrite()->body = body; } - current_target_ = NullOpt; + current_target_ = std::nullopt; return func; } @@ -369,7 +369,7 @@ Pass LowerDeviceKernelLaunch() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") .set_body_typed(LowerDeviceKernelLaunch); } // namespace transform diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 8951f70a6c5b..a30232b9ce80 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -22,7 +22,7 @@ * \brief Lower the special device storage access. */ #include -#include +#include #include #include #include @@ -130,7 +130,7 @@ Pass LowerDeviceStorageAccessInfo() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") .set_body_typed(LowerDeviceStorageAccessInfo); } // namespace transform diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc index 3e8fc204314d..03188fb6c907 100644 --- a/src/tir/transforms/lower_init_block.cc +++ b/src/tir/transforms/lower_init_block.cc @@ -39,7 +39,7 @@ class InitBlockLower : public StmtMutator { Stmt init = DoLowering(block->init.value(), block->iter_vars); Stmt body = VisitStmt(block->body); auto n = CopyOnWrite(block); - n->init = NullOpt; + n->init = std::nullopt; n->body = SeqStmt::Flatten(init, body); return Block(n); } @@ -79,7 +79,7 @@ Pass LowerInitBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerInitBlock", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock); } // namespace transform diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index c1b8a2e83a45..8fe9bedce9f0 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -21,7 +21,7 @@ * Lower intrinsic calls and ops to device specific ir when possible. * \file lower_intrin.cc */ -#include +#include #include #include #include @@ -386,7 +386,7 @@ Pass LowerIntrin() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - auto mtriple = target.value()->GetAttr("mtriple", ""); + auto mtriple = target.value()->GetAttr("mtriple", ""); n->body = IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); return f; @@ -394,7 +394,7 @@ Pass LowerIntrin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); } // namespace transform diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 3c2c6b67e653..6e2ea5bc14af 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -267,7 +267,7 @@ Pass LowerMatchBuffer() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerMatchBuffer", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchBuffer); } // namespace transform diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 215e4461d672..f3551987426d 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -107,7 +107,7 @@ class OpaqueBlockLower : public StmtExprMutator { } else { // Case 3. An ordinary loop body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body), - NullOpt, new_annotations); + std::nullopt, new_annotations); } // Step 5. Insert nested attrs for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) { @@ -150,9 +150,9 @@ class OpaqueBlockLower : public StmtExprMutator { PrimExpr ConvertAttrValue(const String& key, const Any& obj) { if (obj == nullptr) { return PrimExpr(); - } else if (auto expr = obj.as()) { + } else if (auto expr = obj.try_cast()) { return expr.value(); - } else if (auto str = obj.as()) { + } else if (auto str = obj.try_cast()) { return std::move(StringImm(str.value())); } else { LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj.GetTypeKey() @@ -213,7 +213,7 @@ Pass LowerOpaqueBlock() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index dde33fa2678d..0d2092338228 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -22,7 +22,7 @@ * \file lower_thread_allreduce.cc */ #include -#include +#include #include #include #include @@ -103,7 +103,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return new_buf; } - return NullOpt; + return std::nullopt; } Stmt VisitStmt_(const DeclBufferNode* op) final { @@ -294,8 +294,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {}); if (reduce_extent <= warp_size_) { - std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce( - values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq); + std::tie(reduce_results, new_alloc_bufs) = + MakeWarpAllreduce(values, types, combiner, reduce_index, reduce_extent, group_index, + mask, std::nullopt, &seq); // Broadcast the reduction result from lane 0 to all other lanes. // This avoids to emit predicated stores, as all threads are @@ -324,8 +325,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // 2. First round of allreduce. - std::tie(reduce_results, local_bufs) = MakeWarpAllreduce( - values, types, combiner, reduce_index, warp_size_, group_index, mask, NullOpt, &seq); + std::tie(reduce_results, local_bufs) = + MakeWarpAllreduce(values, types, combiner, reduce_index, warp_size_, group_index, mask, + std::nullopt, &seq); new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end()); // 3. Write allreduce results to staging buffer. @@ -807,7 +809,7 @@ Pass LowerThreadAllreduce() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 0931edcc2ec1..095bd321c937 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -21,7 +21,7 @@ * Lower TVM related builtin intrinsics such as packed call. * \file tir/transforms/lower_tvm_buildin.cc */ -#include +#include #include #include #include @@ -39,7 +39,7 @@ namespace tir { class BuiltinLower : public StmtExprMutator { public: static PrimFunc Build(PrimFunc func) { - Optional device_type = NullOpt; + Optional device_type = std::nullopt; if (auto target = func->GetAttr(tvm::attr::kTarget)) { device_type = Integer(target.value()->kind->default_device_type); } @@ -49,7 +49,8 @@ class BuiltinLower : public StmtExprMutator { return func; } - explicit BuiltinLower(Optional device_type = NullOpt) : device_type_(device_type) {} + explicit BuiltinLower(Optional device_type = std::nullopt) + : device_type_(device_type) {} // NOTE: Right now, we make the following scoping requirement // for memory allocated by the following primitives @@ -650,8 +651,8 @@ class BuiltinLower : public StmtExprMutator { // The prepration sequence to be emitted before the current statement. std::vector> prep_seq_stack_; - Optional device_type_{NullOpt}; - Optional device_id_{NullOpt}; + Optional device_type_{std::nullopt}; + Optional device_id_{std::nullopt}; bool is_precheck_{false}; @@ -672,7 +673,7 @@ Pass LowerTVMBuiltin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_vtcm_alloc.cc b/src/tir/transforms/lower_vtcm_alloc.cc index 0b5f7bf1554d..eac2a21b4917 100644 --- a/src/tir/transforms/lower_vtcm_alloc.cc +++ b/src/tir/transforms/lower_vtcm_alloc.cc @@ -72,7 +72,7 @@ Pass LowerVtcmAlloc() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc); } // namespace transform diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 4a364c0ecb8b..b1642bef3c92 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -27,7 +27,7 @@ // explaining the concept of warp shuffle. #include #include -#include +#include #include #include #include @@ -461,7 +461,7 @@ Pass LowerWarpMemory() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); } // namespace transform diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7f8dc60460b4..340e018a8db8 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,8 +20,9 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ +#include #include -#include +#include #include #include #include @@ -123,7 +124,7 @@ class SubroutineCallRewriter : public StmtExprMutator { if (rewriter.made_change_) { return stmt; } else { - return NullOpt; + return std::nullopt; } } @@ -172,21 +173,21 @@ inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) { * \param func The function to be inspected * * \returns The global_symbol to be used for the function at call - * sites, or NullOpt if the function is to remain unchanged. + * sites, or std::nullopt if the function is to remain unchanged. */ Optional RequiresPackedAPI(const PrimFunc& func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { if (CallingConv(opt.value()->value) != CallingConv::kDefault) { - return NullOpt; + return std::nullopt; } } // Internal function calls do not need the ffi::Function API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); if (!global_symbol.defined()) { - return NullOpt; + return std::nullopt; } return global_symbol; @@ -438,7 +439,9 @@ Pass MakePackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { return MakePackedAPI(); }); +TVM_FFI_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed([]() { + return MakePackedAPI(); +}); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 0cb072701cb5..a72d68972735 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -20,8 +20,8 @@ /*! * \file make_unpacked_api.cc Lower PrimFunc to a standard C function API. */ +#include #include -#include #include #include #include @@ -51,7 +51,7 @@ class SubroutineCallRewriter : public StmtExprMutator { if (rewriter.made_change_) { return stmt; } else { - return NullOpt; + return std::nullopt; } } @@ -200,7 +200,7 @@ Pass MakeUnpackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakeUnpackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI); +TVM_FFI_REGISTER_GLOBAL("tir.transform.MakeUnpackedAPI").set_body_typed(MakeUnpackedAPI); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 885d5917136d..dc9420f728be 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -275,7 +275,7 @@ Pass ManifestSharedMemoryLocalStage() { return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage") .set_body_typed(ManifestSharedMemoryLocalStage); } // namespace transform diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc index 5ca20f57aa78..2be5e148fbfe 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -118,7 +118,7 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { if (v.same_as(loop->loop_var)) { return substitute_value; } else { - return NullOpt; + return std::nullopt; } }); PrimExpr predicate = substitute_value < loop->extent; diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc index 8d576b0258f4..2ecb740ba327 100644 --- a/src/tir/transforms/memhammer_intermediate_stage.cc +++ b/src/tir/transforms/memhammer_intermediate_stage.cc @@ -395,7 +395,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String new_loop->loop_var = new_loop_vars[i]; new_loop->body = generate_body; new_loop->kind = ForKind::kSerial; - new_loop->thread_binding = NullOpt; + new_loop->thread_binding = std::nullopt; new_loop->annotations = {}; generate_body = For(new_loop); } diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc index 6d35cc5ac2d1..916c5c84e9af 100644 --- a/src/tir/transforms/memhammer_lower_auto_copy.cc +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include #include #include @@ -776,7 +776,7 @@ Pass LowerAutoCopy() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); +TVM_FFI_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h index e8dc22be4f50..46c9a97c527d 100644 --- a/src/tir/transforms/memhammer_rewrite_rule.h +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -20,7 +20,7 @@ #define TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ #include -#include +#include #include #include #include diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc index 71f91af92fea..5a0d0fa2105c 100644 --- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -42,7 +42,7 @@ std::pair> TileWmmaBlock(Stmt stmt) { arith::Analyzer analyzer; if (!analyzer.CanProveEqual(floormod(extent_last1, 16), 0) || !analyzer.CanProveEqual(floormod(extent_last2, 16), 0)) { - return std::make_pair(stmt, NullOpt); + return std::make_pair(stmt, std::nullopt); } } Var new_loop_vars[4] = { @@ -177,7 +177,7 @@ Stmt RewriteWmmaLoad(Stmt stmt) { /*6:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2], /*7:*/ StringImm(layout), })), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/ { @@ -280,7 +280,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { }), /*6:*/ new_tgt_buffer->strides[0], /*7:*/ StringImm("row_major")})), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/ { @@ -366,7 +366,7 @@ std::pair> TileMmaToGlobalBlock(Stmt stmt) { // Only tile when both extent % 8 == 0 if (!analyzer.CanProveEqual(floormod(extent_last1, 8), 0) || !analyzer.CanProveEqual(floormod(extent_last2, 8), 0)) { - return std::make_pair(stmt, NullOpt); + return std::make_pair(stmt, std::nullopt); } } Var new_loop_vars[4] = { @@ -498,7 +498,7 @@ Stmt RewriteMmaStore(Stmt stmt) { {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}), /*annotations=*/{})), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/{}, /*match_buffers=*/ { diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 85f102cb4177..52966e005aaa 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -23,7 +23,7 @@ * This pass merges multiple TIR-level dynamic or static shared memory allocations into one * allocation. */ -#include +#include #include #include #include @@ -695,7 +695,7 @@ Pass MergeSharedMemoryAllocations() { return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations") +TVM_FFI_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations") .set_body_typed(MergeSharedMemoryAllocations); } // namespace transform diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 696eae201f3c..8183b2fd8f45 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -22,7 +22,7 @@ * \brief narrow the datatype of indexing vars */ -#include +#include #include #include #include @@ -320,7 +320,7 @@ Pass NarrowDataType(int target_bits) { return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); +TVM_FFI_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 5ce8ade2085c..c141ef33c289 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -212,7 +212,7 @@ class BufferAllocationLocator : public StmtExprMutator { /*writes=*/{}, /*name_hint=*/"", /*body=*/std::move(body), - /*init=*/NullOpt, + /*init=*/std::nullopt, /*alloc_buffers=*/alloc_buffers); ObjectPtr n = CopyOnWrite(opaque_block.get()); Array> access = @@ -257,7 +257,7 @@ Pass PlanAndUpdateBufferAllocationLocation() { return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {}); } -TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation") +TVM_FFI_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation") .set_body_typed(PlanAndUpdateBufferAllocationLocation); } // namespace transform diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 7efa23bc322d..ade1aea7c941 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -109,9 +109,9 @@ transform::Pass Filter(ffi::TypedFunction fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget); -TVM_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc); -TVM_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter); +TVM_FFI_REGISTER_GLOBAL("tir.transform.BindTarget").set_body_typed(BindTarget); +TVM_FFI_REGISTER_GLOBAL("tir.transform.AnnotateEntryFunc").set_body_typed(AnnotateEntryFunc); +TVM_FFI_REGISTER_GLOBAL("tir.transform.Filter").set_body_typed(Filter); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/profile_instrumentation.cc b/src/tir/transforms/profile_instrumentation.cc index 7f6930e2e2bf..f8548ca59a7d 100644 --- a/src/tir/transforms/profile_instrumentation.cc +++ b/src/tir/transforms/profile_instrumentation.cc @@ -283,7 +283,7 @@ Pass InstrumentProfileIntrinsics() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstrumentProfileIntrinsics", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InstrumentProfileIntrinsics") +TVM_FFI_REGISTER_GLOBAL("tir.transform.InstrumentProfileIntrinsics") .set_body_typed(InstrumentProfileIntrinsics); } // namespace transform diff --git a/src/tir/transforms/reduce_branching_through_overcompute.cc b/src/tir/transforms/reduce_branching_through_overcompute.cc index 0c3f7a9ba32f..0593c4f812fe 100644 --- a/src/tir/transforms/reduce_branching_through_overcompute.cc +++ b/src/tir/transforms/reduce_branching_through_overcompute.cc @@ -169,7 +169,7 @@ Pass ReduceBranchingThroughOvercompute() { return CreatePrimFuncPass(pass_func, 0, "tir.ReduceBranchingThroughOvercompute", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute") +TVM_FFI_REGISTER_GLOBAL("tir.transform.ReduceBranchingThroughOvercompute") .set_body_typed(ReduceBranchingThroughOvercompute); } // namespace transform diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index 519a3e1f80d8..6afaa0c61583 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -20,7 +20,7 @@ /*! * \file remap_thread_axis.cc */ -#include +#include #include #include #include @@ -69,7 +69,7 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; -PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { +PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { tmap[kv.first] = kv.second; @@ -96,14 +96,14 @@ PrimFunc RemapThreadAxis(PrimFunc func, Map thread_map namespace transform { -Pass RemapThreadAxis(Map thread_map) { +Pass RemapThreadAxis(Map thread_map) { auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) { return RemapThreadAxis(std::move(f), thread_map); }; return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remove_assume.cc b/src/tir/transforms/remove_assume.cc index 928bcf02bc1b..ce7176e8cc46 100644 --- a/src/tir/transforms/remove_assume.cc +++ b/src/tir/transforms/remove_assume.cc @@ -21,7 +21,7 @@ * \file remove_store_undef.cc * \brief Remove stores of tir::builtin::undef */ -#include +#include #include #include #include @@ -61,7 +61,7 @@ Pass RemoveAssume() { return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume"); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume); } // namespace transform diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 3b418aac0cf5..49dd41ae86a6 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -22,7 +22,7 @@ * \brief Remove no op from the stmt */ #include -#include +#include #include #include #include @@ -331,7 +331,7 @@ Pass RemoveNoOp() { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); } // namespace transform diff --git a/src/tir/transforms/remove_store_undef.cc b/src/tir/transforms/remove_store_undef.cc index 6b28cb165aa9..31b4a558c600 100644 --- a/src/tir/transforms/remove_store_undef.cc +++ b/src/tir/transforms/remove_store_undef.cc @@ -21,7 +21,7 @@ * \file remove_store_undef.cc * \brief Remove stores of tir::builtin::undef */ -#include +#include #include #include #include @@ -171,7 +171,7 @@ Pass RemoveStoreUndef() { "tir.RemoveStoreUndef"); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef); } // namespace transform diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc index e8d89bfb5700..881f321bf673 100644 --- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc @@ -285,7 +285,7 @@ Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) { return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock") +TVM_FFI_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock") .set_body_typed(RemoveWeightLayoutRewriteBlock); } // namespace transform diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index 6098535fb5e8..cd1517b11c2a 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -116,7 +116,7 @@ class RenewDefMutator : public StmtExprMutator { std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); // Step 3. Visit body - Optional init = NullOpt; + Optional init = std::nullopt; if (op->init.defined()) { init = this->VisitStmt(op->init.value()); } @@ -290,7 +290,7 @@ class RenewDefMutator : public StmtExprMutator { PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); } -TVM_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs); +TVM_FFI_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc index beb5997d4982..0fb24c62500a 100644 --- a/src/tir/transforms/renormalize_split_pattern.cc +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -21,7 +21,7 @@ * \file renormalize_split_pattern.cc * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) */ -#include +#include #include #include #include @@ -205,7 +205,7 @@ Pass RenormalizeSplitPattern() { return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") +TVM_FFI_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") .set_body_typed(RenormalizeSplitPattern); } // namespace transform diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 7646d01f8e90..624e2d9921a9 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -21,7 +21,7 @@ * \file unsafe_select_rewrite.cc * \brief Rewrite uinsafe select expression. */ -#include +#include #include #include #include @@ -139,7 +139,7 @@ Pass RewriteUnsafeSelect() { return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); +TVM_FFI_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); } // namespace transform diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index f518c61bc676..82c5d4178401 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -25,7 +25,7 @@ #include "../../tir/transforms/simplify.h" #include -#include +#include #include #include #include @@ -146,7 +146,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, - Optional config_opt = NullOpt) { + Optional config_opt = std::nullopt) { auto config = config_opt.value_or(AttrsWithDefaultValues()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); @@ -327,7 +327,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (const int64_t* as_int = as_const_int(condition)) { return Bool(*as_int); } else { - return NullOpt; + return std::nullopt; } } @@ -335,7 +335,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { std::optional touch_pattern_; Map non_inlined_bindings_; - Optional current_stmt_{NullOpt}; + Optional current_stmt_{std::nullopt}; std::unordered_set used_in_buffer_def_; }; @@ -359,7 +359,7 @@ Pass Simplify() { return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } -TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); +TVM_FFI_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index d9cd6d35497c..98aea3da99d5 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -17,7 +17,7 @@ * under the License. */ -#include +#include #include #include #include @@ -47,7 +47,7 @@ Pass SkipAssert() { return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); +TVM_FFI_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); } // namespace transform diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index adf00f0b57c4..0b12bd02d482 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -21,9 +21,9 @@ * \file split_host_device.cc * \brief Split device function from host. */ +#include #include #include -#include #include #include #include @@ -168,7 +168,7 @@ Pass SplitHostDevice() { return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); +TVM_FFI_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 22c347066789..b8062e2a2f10 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -23,8 +23,8 @@ * Re-write data access to enable memory sharing when possible. */ #include +#include #include -#include #include #include #include @@ -1761,7 +1761,7 @@ Pass StorageRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } -TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); +TVM_FFI_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -1770,7 +1770,7 @@ Pass PointerValueTypeRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {}); } -TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") +TVM_FFI_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") .set_body_typed(PointerValueTypeRewrite); } // namespace transform diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index e0ae7172ad5c..3c6a6fc9be86 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -21,7 +21,7 @@ * \brief Infer TensorCore metadata from tensor intrinsic. * \file tensorcore_fragment.cc */ -#include +#include #include #include #include @@ -217,7 +217,7 @@ Pass InferFragment() { return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); +TVM_FFI_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index fd772863f780..34878d1b333d 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -20,7 +20,7 @@ /*! * \file thread_storage_sync.cc */ -#include +#include #include #include #include @@ -471,7 +471,7 @@ Pass ThreadSync(String storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); +TVM_FFI_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index 5332fcfff123..3ef35d74cf8a 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -184,7 +184,7 @@ Pass TransformMmaBufferLayout() { return CreatePrimFuncPass(pass_func, 0, "tir.TransformMmaBufferLayout", {}); } -TVM_REGISTER_GLOBAL("tir.transform.TransformMmaBufferLayout") +TVM_FFI_REGISTER_GLOBAL("tir.transform.TransformMmaBufferLayout") .set_body_typed(TransformMmaBufferLayout); } // namespace transform diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index 33d77c23353a..08fc921f4ebf 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -76,7 +76,7 @@ class ThreadBindingUnifier : public StmtExprMutator { /*min=*/IntImm(dtype, 0), // /*extent=*/IntImm(dtype, 1), // /*kind=*/ForKind::kSerial, stmt, // - /*thread_binding=*/NullOpt, // + /*thread_binding=*/std::nullopt, // /*annotation=*/std::move(annotations)); } } @@ -199,7 +199,7 @@ Pass UnifyThreadBinding() { return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding); +TVM_FFI_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding); } // namespace transform diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index a68ebe7e02ff..7218adbda216 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -23,7 +23,7 @@ */ // Unrolls the loop as in Halide pipeline. #include -#include +#include #include #include #include @@ -288,7 +288,7 @@ Pass UnrollLoop() { return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); +TVM_FFI_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); } // namespace transform diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 047e80ea0e7b..c4d2d4608044 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -21,7 +21,7 @@ * \file unsupported_dtype_legalize.cc * \brief legalize bf16/fp8 type by adding cast_to_fp32 */ -#include +#include #include #include #include @@ -758,7 +758,7 @@ Pass BF16ComputeLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize); Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { @@ -771,7 +771,7 @@ Pass BF16StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -779,13 +779,12 @@ Pass FP8ComputeLegalize(String promote_dtype_str) { if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } - return FP8ComputeLegalizer(DataType(runtime::StringToDLDataType(promote_dtype_str))) - .Legalize(f); + return FP8ComputeLegalizer(DataType(StringToDLDataType(promote_dtype_str))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); Pass FP8StorageLegalize() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { @@ -798,7 +797,7 @@ Pass FP8StorageLegalize() { return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.FP8StorageLegalize").set_body_typed(FP8StorageLegalize); +TVM_FFI_REGISTER_GLOBAL("tir.transform.FP8StorageLegalize").set_body_typed(FP8StorageLegalize); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc index 3cd33b85905b..a1195cfef81f 100644 --- a/src/tir/transforms/using_assume_to_reduce_branches.cc +++ b/src/tir/transforms/using_assume_to_reduce_branches.cc @@ -161,7 +161,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { With analyzer_context; size_t old_num_constraints{0}; size_t new_num_constraints{0}; - Optional assume{NullOpt}; + Optional assume{std::nullopt}; // Disable default-generated copy/move assignment and constructors InternalConstraintContext(const InternalConstraintContext&) = delete; @@ -381,7 +381,7 @@ Pass UseAssumeToReduceBranches() { return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {}); } -TVM_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches") +TVM_FFI_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches") .set_body_typed(UseAssumeToReduceBranches); } // namespace transform diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 58ce6d61742a..16aae03932cf 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -22,7 +22,7 @@ */ // Loop vectorizer as in Halide pipeline. #include -#include +#include #include #include #include @@ -80,8 +80,8 @@ bool EnableBufferLevelPredication(Target target) { return enable_buffer_predication.value(); } - // Use buffer-level predication by default for AArch64 SVE targets - return arith::TargetHasSVE(target); + // Use buffer-level predication by default for VLA targets + return arith::TargetHasVLA(target); } /*! @@ -769,7 +769,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitStmt(op->then_case); - Optional else_case = NullOpt; + Optional else_case = std::nullopt; if (op->else_case) { else_case = this->VisitStmt(op->else_case.value()); } @@ -972,7 +972,7 @@ class LoopVectorizer : public StmtMutator { if (!extent_as_int || extent_as_int->value < 1) { bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); - ICHECK(is_scalable_expr && arith::TargetHasSVE(target_)) + ICHECK(is_scalable_expr && arith::TargetHasVLA(target_)) << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; } ICHECK(is_zero(op->min)); @@ -1028,7 +1028,7 @@ Pass VectorizeLoop(bool enable_vectorize) { return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); } -TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); +TVM_FFI_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); } // namespace transform diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index 6d6dc4edc5f6..1ee85e7b8c95 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -21,8 +21,7 @@ * \brief Registration of broadcast operators * \file broadcast.cc */ -#include -#include +#include #include #include @@ -32,19 +31,19 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_REGISTER_GLOBAL(OpName).set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { \ - bool lhs_is_tensor = args[0].as().has_value(); \ - bool rhs_is_tensor = args[1].as().has_value(); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].cast(), args[1].cast()); \ - } \ +#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ + TVM_FFI_REGISTER_GLOBAL(OpName).set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { \ + bool lhs_is_tensor = args[0].as().has_value(); \ + bool rhs_is_tensor = args[1].as().has_value(); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].cast(), args[1].cast()); \ + } \ }); TOPI_REGISTER_BCAST_OP("topi.add", topi::add); @@ -73,9 +72,10 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); -TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = broadcast_to(args[0].cast(), args[1].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.broadcast_to") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = broadcast_to(args[0].cast(), args[1].cast>()); + }); } // namespace topi } // namespace tvm diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index e4ac103f14d6..40c8332ab725 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -355,7 +355,7 @@ Array InferEinsumShape(const std::string& subscripts, return einsum_builder.InferShape(); } -TVM_REGISTER_GLOBAL("topi.einsum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.einsum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = einsum(args[0].cast(), args[1].cast>()); }); diff --git a/src/topi/elemwise.cc b/src/topi/elemwise.cc index e3a3411a9c6c..13947abcf604 100644 --- a/src/topi/elemwise.cc +++ b/src/topi/elemwise.cc @@ -21,8 +21,7 @@ * \brief Registration of elemwise operators * \file elemwise.cc */ -#include -#include +#include #include namespace tvm { @@ -31,139 +30,140 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.acos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.acos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = acos(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.acosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.acosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = acosh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.asin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.asin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = asin(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.asinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.asinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = asinh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.atanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.atanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = atanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = exp(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.fast_exp").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_exp(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = erf(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.fast_erf").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_erf(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.tan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tan(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.cos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.cos").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cos(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.cosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.cosh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cosh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sin(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sinh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sinh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.fast_tanh").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = fast_tanh(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.atan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.atan").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = atan(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sigmoid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sigmoid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sigmoid(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.sqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sqrt(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.rsqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.rsqrt").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = rsqrt(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.log").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log2").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.log2").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log2(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.log10").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.log10").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = log10(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.identity").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.identity").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = identity(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.negative").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.negative").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = negative(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.clip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.clip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = clip(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.cast").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.cast").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = cast(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.reinterpret").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.reinterpret").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reinterpret(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = elemwise_sum(args[0].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.elemwise_sum") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = elemwise_sum(args[0].cast>()); + }); -TVM_REGISTER_GLOBAL("topi.sign").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sign").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = sign(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.full").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.full").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = full(args[0].cast>(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.full_like").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.full_like").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = full_like(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.logical_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.logical_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = logical_not(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.bitwise_not").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = bitwise_not(args[0].cast()); }); diff --git a/src/topi/nn.cc b/src/topi/nn.cc index b9eeef74d778..4b2095a53868 100644 --- a/src/topi/nn.cc +++ b/src/topi/nn.cc @@ -21,8 +21,7 @@ * \brief Registration of NN operators * \file nn.cc */ -#include -#include +#include #include #include #include @@ -45,127 +44,131 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_REGISTER_GLOBAL("topi.nn.relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = relu(args[0].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = leaky_relu(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.leaky_relu") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = leaky_relu(args[0].cast(), args[1].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.prelu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = prelu(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = pad(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.space_to_batch_nd") +TVM_FFI_REGISTER_GLOBAL("topi.nn.space_to_batch_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = space_to_batch_nd(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd") +TVM_FFI_REGISTER_GLOBAL("topi.nn.batch_to_space_nd") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = batch_to_space_nd(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.nll_loss").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.nll_loss").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nll_loss(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast()); }); /* Ops from nn/dense.h */ -TVM_REGISTER_GLOBAL("topi.nn.dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::dense(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); /* Ops from nn/bias_add.h */ -TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.bias_add").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::bias_add(args[0].cast(), args[1].cast(), args[2].cast()); }); /* Ops from nn/dilate.h */ -TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.dilate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::dilate(args[0].cast(), args[1].cast>(), args[2].cast()); }); /* Ops from nn/flatten.h */ -TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.flatten").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::flatten(args[0].cast()); }); /* Ops from nn/mapping.h */ -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") +TVM_FFI_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::scale_shift_nchw(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") +TVM_FFI_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::scale_shift_nhwc(args[0].cast(), args[1].cast(), args[2].cast()); }); /* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = - nn::pool_grad(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>(), - args[4].cast>(), static_cast(args[5].cast()), - args[6].cast(), args[7].cast(), args[8].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool_grad") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::pool_grad(args[0].cast(), args[1].cast(), + args[2].cast>(), args[3].cast>(), + args[4].cast>(), + static_cast(args[5].cast()), args[6].cast(), + args[7].cast(), args[8].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::global_pool(args[0].cast(), static_cast(args[1].cast()), - args[2].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.global_pool") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::global_pool(args[0].cast(), + static_cast(args[1].cast()), + args[2].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool1d") +TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool1d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool1d(args[0].cast(), args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") +TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool(args[0].cast(), args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") +TVM_FFI_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::adaptive_pool3d(args[0].cast(), args[1].cast>(), static_cast(args[2].cast()), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool1d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool1d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool2d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool2d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool2d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), args[6].cast(), args[7].cast(), args[8].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::pool3d(args[0].cast(), args[1].cast>(), args[2].cast>(), args[3].cast>(), args[4].cast>(), static_cast(args[5].cast()), @@ -173,45 +176,49 @@ TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body_packed([](ffi::PackedArgs args, f }); /* Ops from nn/softmax.h */ -TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::softmax(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::log_softmax(args[0].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.log_softmax") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::log_softmax(args[0].cast()); + }); -TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.lrn").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::lrn(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), args[4].cast(), args[5].cast()); }); /* Ops from nn/bnn.h */ -TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") +TVM_FFI_REGISTER_GLOBAL("topi.nn.binarize_pack") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::binarize_pack(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::binary_dense(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.binary_dense") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::binary_dense(args[0].cast(), args[1].cast()); + }); /* Ops from nn/layer_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::layer_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast>(), - args[4].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.layer_norm") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::layer_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast>(), + args[4].cast()); + }); /* Ops from nn/group_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = nn::group_norm(args[0].cast(), args[1].cast(), - args[2].cast(), args[3].cast(), args[4].cast(), - args[5].cast>(), args[6].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.nn.group_norm") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = nn::group_norm(args[0].cast(), args[1].cast(), + args[2].cast(), args[3].cast(), args[4].cast(), + args[5].cast>(), args[6].cast()); + }); /* Ops from nn/instance_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.instance_norm") +TVM_FFI_REGISTER_GLOBAL("topi.nn.instance_norm") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::instance_norm(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast(), @@ -219,7 +226,7 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm") }); /* Ops from nn/rms_norm.h */ -TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.nn.rms_norm").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = nn::rms_norm(args[0].cast(), args[1].cast(), args[2].cast>(), args[3].cast()); }); diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc index e1720cc0b6b0..f8920bdefd46 100644 --- a/src/topi/reduction.cc +++ b/src/topi/reduction.cc @@ -21,8 +21,7 @@ * \brief Registration of reduction operators * \file reduction.cc */ -#include -#include +#include #include #include @@ -32,43 +31,44 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::sum(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.min").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.min").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::min(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.max").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.max").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::max(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.argmin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.argmin").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::argmin(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.argmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.argmax").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::argmax(args[0].cast(), ArrayOrInt(args[1]), args[2].cast(), false, args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.prod").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.prod").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::prod(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.all").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.all").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::all(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.any").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.any").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::any(args[0].cast(), ArrayOrInt(args[1]), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.collapse_sum") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = topi::collapse_sum(args[0].cast(), args[1].cast>()); + }); } // namespace topi } // namespace tvm diff --git a/src/topi/transform.cc b/src/topi/transform.cc index cf86242d8491..5826fdac864f 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -21,8 +21,7 @@ * \brief Registration of transform operators * \file transform.cc */ -#include -#include +#include #include #include #include @@ -37,56 +36,58 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.expand_dims").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.expand_dims").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = expand_dims(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.transpose").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.transpose").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = transpose(args[0].cast(), args[1].cast>>()); }); -TVM_REGISTER_GLOBAL("topi.flip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.flip").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // pass empty seq_lengths tensor to reverse_sequence *rv = reverse_sequence(args[0].cast(), Tensor(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.reverse_sequence") +TVM_FFI_REGISTER_GLOBAL("topi.reverse_sequence") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reverse_sequence(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.reshape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.reshape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = reshape(args[0].cast(), args[1].cast>()); }); -TVM_REGISTER_GLOBAL("topi.sliding_window").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sliding_window(args[0].cast(), args[1].cast(), - args[2].cast>(), args[3].cast>()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.sliding_window") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = sliding_window(args[0].cast(), args[1].cast(), + args[2].cast>(), args[3].cast>()); + }); -TVM_REGISTER_GLOBAL("topi.squeeze").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.squeeze").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = squeeze(args[0].cast(), ArrayOrInt(args[1])); }); -TVM_REGISTER_GLOBAL("topi.concatenate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.concatenate").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = concatenate(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.stack").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.stack").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = stack(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.shape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.shape").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = shape(args[0].cast(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = ndarray_size(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.ndarray_size") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = ndarray_size(args[0].cast(), args[1].cast()); + }); -TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - if (args[1].as()) { +TVM_FFI_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + if (args[1].try_cast()) { *rv = split_n_sections(args[0].cast(), args[1].cast(), args[2].cast()); } else { *rv = split_indices_array(args[0].cast(), args[1].cast>(), @@ -94,13 +95,13 @@ TVM_REGISTER_GLOBAL("topi.split").set_body_packed([](ffi::PackedArgs args, ffi:: } }); -TVM_REGISTER_GLOBAL("topi.layout_transform") +TVM_FFI_REGISTER_GLOBAL("topi.layout_transform") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = layout_transform(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args.size() == 4) { auto mode = args[3].cast(); int batch_dims = args[2].cast(); @@ -115,52 +116,55 @@ TVM_REGISTER_GLOBAL("topi.take").set_body_packed([](ffi::PackedArgs args, ffi::A } }); -TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - double pad_val = args[2].cast(); - int axis = args[3].cast(); - *rv = sequence_mask(args[0].cast(), args[1].cast(), pad_val, axis); -}); +TVM_FFI_REGISTER_GLOBAL("topi.sequence_mask") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + double pad_val = args[2].cast(); + int axis = args[3].cast(); + *rv = sequence_mask(args[0].cast(), args[1].cast(), pad_val, axis); + }); -TVM_REGISTER_GLOBAL("topi.where").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.where").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = where(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.arange").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.arange").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = arange(args[0].cast(), args[1].cast(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.meshgrid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.meshgrid").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = meshgrid(args[0].cast>(), args[1].cast()); }); -TVM_REGISTER_GLOBAL("topi.repeat").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.repeat").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = repeat(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.tile").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tile").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = tile(args[0].cast(), args[1].cast>()); }); -TVM_REGISTER_GLOBAL("topi.gather").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.gather").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = gather(args[0].cast(), args[1].cast(), args[2].cast()); }); -TVM_REGISTER_GLOBAL("topi.gather_nd").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.gather_nd").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int batch_dims = args[2].cast(); *rv = gather_nd(args[0].cast(), args[1].cast(), batch_dims); }); -TVM_REGISTER_GLOBAL("topi.unravel_index").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = unravel_index(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.unravel_index") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = unravel_index(args[0].cast(), args[1].cast()); + }); -TVM_REGISTER_GLOBAL("topi.sparse_to_dense").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = sparse_to_dense(args[0].cast(), args[1].cast>(), - args[2].cast(), args[3].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.sparse_to_dense") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = sparse_to_dense(args[0].cast(), args[1].cast>(), + args[2].cast(), args[3].cast()); + }); -TVM_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { switch (args.size()) { case 2: *rv = matmul(args[0].cast(), args[1].cast()); @@ -177,7 +181,7 @@ TVM_REGISTER_GLOBAL("topi.matmul").set_body_packed([](ffi::PackedArgs args, ffi: } }); -TVM_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { if (args.size() == 2) { *rv = tensordot(args[0].cast(), args[1].cast()); } else if (args.size() == 3) { @@ -189,34 +193,36 @@ TVM_REGISTER_GLOBAL("topi.tensordot").set_body_packed([](ffi::PackedArgs args, f } }); -TVM_REGISTER_GLOBAL("topi.strided_slice").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - Tensor x = args[0].cast(); - Array begin = args[1].cast>(); - Array end = args[2].cast>(); - Array strides = args[3].cast>(); - Array axes = args[4].cast>(); - bool assume_inbound = args[6].cast(); - if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && - IsConstIntArray(x->shape)) { - Array begin_static = args[1].cast>(); - Array end_static = args[2].cast>(); - Array strides_static = args[3].cast>(); - auto slice_mode = args[5].cast(); - if (axes.size()) { - *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode); - } else { - *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); - } - } else { - if (axes.size()) { - *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound); - } else { - *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound); - } - } -}); +TVM_FFI_REGISTER_GLOBAL("topi.strided_slice") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + Tensor x = args[0].cast(); + Array begin = args[1].cast>(); + Array end = args[2].cast>(); + Array strides = args[3].cast>(); + Array axes = args[4].cast>(); + bool assume_inbound = args[6].cast(); + if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && + IsConstIntArray(x->shape)) { + Array begin_static = args[1].cast>(); + Array end_static = args[2].cast>(); + Array strides_static = args[3].cast>(); + auto slice_mode = args[5].cast(); + if (axes.size()) { + *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, + slice_mode); + } else { + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } + } else { + if (axes.size()) { + *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound); + } else { + *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound); + } + } + }); -TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice") +TVM_FFI_REGISTER_GLOBAL("topi.dynamic_strided_slice") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { te::Tensor begin = args[1].cast(); te::Tensor end = args[2].cast(); @@ -224,13 +230,13 @@ TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice") *rv = dynamic_strided_slice(args[0].cast(), begin, end, strides); }); -TVM_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice") +TVM_FFI_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice") .set_body_typed([](te::Tensor x, te::Tensor begin, te::Tensor end, te::Tensor strides, Array output_shape) { return relax::dynamic_strided_slice(x, begin, end, strides, output_shape); }); -TVM_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { int depth = args[3].cast(); int axis = args[4].cast(); DataType dtype = args[5].cast(); @@ -238,18 +244,18 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body_packed([](ffi::PackedArgs args, ffi depth, axis, dtype); }); -TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - int k1 = args[2].cast(); - int k2 = args[3].cast(); - bool super_diag_right_align = args[4].cast(); - bool sub_diag_right_align = args[5].cast(); - *rv = matrix_set_diag(args[0].cast(), args[1].cast(), k1, k2, - super_diag_right_align, sub_diag_right_align); -}); +TVM_FFI_REGISTER_GLOBAL("topi.matrix_set_diag") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + int k1 = args[2].cast(); + int k2 = args[3].cast(); + bool super_diag_right_align = args[4].cast(); + bool sub_diag_right_align = args[5].cast(); + *rv = matrix_set_diag(args[0].cast(), args[1].cast(), k1, k2, + super_diag_right_align, sub_diag_right_align); + }); -TVM_REGISTER_GLOBAL("topi.adv_index").set_body_typed([](te::Tensor x, Array indices) { - return adv_index(x, indices); -}); +TVM_FFI_REGISTER_GLOBAL("topi.adv_index") + .set_body_typed([](te::Tensor x, Array indices) { return adv_index(x, indices); }); } // namespace topi } // namespace tvm diff --git a/src/topi/utils.cc b/src/topi/utils.cc index c02744a4202d..9a668ad2ac17 100644 --- a/src/topi/utils.cc +++ b/src/topi/utils.cc @@ -22,25 +22,24 @@ * \file utils.cc */ -#include -#include +#include #include namespace tvm { namespace topi { -TVM_REGISTER_GLOBAL("topi.utils.is_empty_shape") +TVM_FFI_REGISTER_GLOBAL("topi.utils.is_empty_shape") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = topi::detail::is_empty_shape(args[0].cast>()); }); -TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw") +TVM_FFI_REGISTER_GLOBAL("topi.utils.bilinear_sample_nchw") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nchw(args[0].cast(), args[1].cast>(), args[2].cast(), args[3].cast()); }); -TVM_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc") +TVM_FFI_REGISTER_GLOBAL("topi.utils.bilinear_sample_nhwc") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { *rv = detail::bilinear_sample_nhwc(args[0].cast(), args[1].cast>(), diff --git a/src/topi/vision.cc b/src/topi/vision.cc index dca44bf86c3c..57d936268010 100644 --- a/src/topi/vision.cc +++ b/src/topi/vision.cc @@ -21,8 +21,7 @@ * \brief Registration of vision operators * \file vision.cc */ -#include -#include +#include #include namespace tvm { @@ -31,9 +30,10 @@ namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - *rv = vision::reorg(args[0].cast(), args[1].cast()); -}); +TVM_FFI_REGISTER_GLOBAL("topi.vision.reorg") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + *rv = vision::reorg(args[0].cast(), args[1].cast()); + }); } // namespace topi } // namespace tvm diff --git a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc index c23391d16846..6f9c9f0f6f7b 100644 --- a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc @@ -18,12 +18,13 @@ */ #include -#include +#include #include "../src/runtime/hexagon/hexagon_buffer.h" using namespace tvm::runtime; using namespace tvm::runtime::hexagon; +using namespace tvm::ffi; TEST(HexagonBuffer, default_scope) { Optional scope; diff --git a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc index c6ce0c72f5e4..6211bd63dfbc 100644 --- a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc @@ -23,6 +23,7 @@ using namespace tvm::runtime; using namespace tvm::runtime::hexagon; +using namespace tvm::ffi; class HexagonDeviceAPITest : public ::testing::Test { protected: diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc index e9c81fa91125..2e47473f8a17 100644 --- a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc @@ -23,6 +23,7 @@ using namespace tvm::runtime; using namespace tvm::runtime::hexagon; +using namespace tvm::ffi; class HexagonUserDMATest : public ::testing::Test { void SetUp() override { diff --git a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc index 8240241eee26..3cf008c874ab 100644 --- a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc @@ -23,6 +23,7 @@ using namespace tvm::runtime; using namespace tvm::runtime::hexagon; +using namespace tvm::ffi; class HexagonVtcmPoolTest : public ::testing::Test { void SetUp() override { diff --git a/tests/cpp-runtime/hexagon/run_all_tests.cc b/tests/cpp-runtime/hexagon/run_all_tests.cc index fa2a4aa45895..cf8160971a51 100644 --- a/tests/cpp-runtime/hexagon/run_all_tests.cc +++ b/tests/cpp-runtime/hexagon/run_all_tests.cc @@ -18,8 +18,7 @@ */ #include -#include -#include +#include #include #include @@ -38,7 +37,7 @@ namespace tvm { namespace runtime { namespace hexagon { -TVM_REGISTER_GLOBAL("hexagon.run_all_tests") +TVM_FFI_REGISTER_GLOBAL("hexagon.run_all_tests") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // gtest args are passed into this packed func as a singular string // split gtest args using delimiter and build argument vector diff --git a/tests/cpp-runtime/hexagon/run_unit_tests.cc b/tests/cpp-runtime/hexagon/run_unit_tests.cc index d9331db28bee..a4c613b41140 100644 --- a/tests/cpp-runtime/hexagon/run_unit_tests.cc +++ b/tests/cpp-runtime/hexagon/run_unit_tests.cc @@ -18,8 +18,7 @@ */ #include -#include -#include +#include #include #include @@ -80,7 +79,7 @@ class GtestPrinter : public testing::EmptyTestEventListener { std::string GetOutput() { return gtest_out_.str(); } }; -TVM_REGISTER_GLOBAL("hexagon.run_unit_tests") +TVM_FFI_REGISTER_GLOBAL("hexagon.run_unit_tests") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { // gtest args are passed into this packed func as a singular string // split gtest args using delimiter and build argument vector diff --git a/tests/cpp-runtime/opencl/aa_opencl_qcom_extn.cc b/tests/cpp-runtime/opencl/aa_opencl_qcom_extn.cc index 1f3dc2057aee..5073e49e3af0 100644 --- a/tests/cpp-runtime/opencl/aa_opencl_qcom_extn.cc +++ b/tests/cpp-runtime/opencl/aa_opencl_qcom_extn.cc @@ -21,7 +21,7 @@ // hence, crafted the filename accordingly #include -#include +#include #include "../src/runtime/opencl/opencl_common.h" diff --git a/tests/cpp-runtime/opencl/clml_memory_planner.cc b/tests/cpp-runtime/opencl/clml_memory_planner.cc index 3d4d9c41f40e..364eb5591e9c 100644 --- a/tests/cpp-runtime/opencl/clml_memory_planner.cc +++ b/tests/cpp-runtime/opencl/clml_memory_planner.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include diff --git a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc index 4aa09d63fd48..b8e5b90ece7c 100644 --- a/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc +++ b/tests/cpp-runtime/opencl/opencl_compile_to_bin.cc @@ -18,6 +18,9 @@ */ #include +#include +#include +#include #include #include @@ -27,6 +30,7 @@ using namespace tvm::runtime; using namespace tvm::runtime::cl; +using namespace tvm::ffi; namespace { // This kernel was generated by TVM for conv2d operation @@ -180,9 +184,9 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) { module.InstallKernel(m_workspace, m_workspace->GetThreadEntry(), m_kernelNames[i], e); } Timestamp comp_end = std::chrono::high_resolution_clock::now(); - auto get_pre_compiled_f = - module.GetFunction("opencl.GetPreCompiledPrograms", GetObjectPtr(&module)); - bytes = get_pre_compiled_f().cast(); + auto get_pre_compiled_f = module.GetFunction("opencl.GetPreCompiledPrograms", + tvm::ffi::GetObjectPtr(&module)); + bytes = get_pre_compiled_f().cast(); std::chrono::duration duration = std::chrono::duration_cast(comp_end - comp_start); compileFromSourceTimeMS = duration.count() * 1e-6; @@ -192,7 +196,7 @@ TEST_F(OpenCLCompileBin, SourceVsBinaryCompilationPerf) { OpenCLModuleNode module(m_dataSrc, "cl", m_fmap, std::string()); module.Init(); module.GetFunction("opencl.SetPreCompiledPrograms", - GetObjectPtr(&module))(String(bytes)); + GetObjectPtr(&module))(tvm::String(bytes)); Timestamp comp_start = std::chrono::high_resolution_clock::now(); for (size_t i = 0; i < m_kernelNames.size(); ++i) { OpenCLModuleNode::KTRefEntry e = {i, 1}; diff --git a/tests/cpp-runtime/opencl/opencl_nativeptr.cc b/tests/cpp-runtime/opencl/opencl_nativeptr.cc index 8f894c4bffca..260effadea0b 100644 --- a/tests/cpp-runtime/opencl/opencl_nativeptr.cc +++ b/tests/cpp-runtime/opencl/opencl_nativeptr.cc @@ -18,7 +18,7 @@ */ #include -#include +#include #include #include diff --git a/tests/cpp-runtime/opencl/texture_copy_test.cc b/tests/cpp-runtime/opencl/texture_copy_test.cc index 701fec4d8baf..61d9044b6d86 100644 --- a/tests/cpp-runtime/opencl/texture_copy_test.cc +++ b/tests/cpp-runtime/opencl/texture_copy_test.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include #include @@ -101,12 +101,12 @@ TEST_F(TextureCopyTest, ViewBufferAsBuffer) { DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); - auto buffer = allocator->Alloc(cl_dev, ShapeTuple(shape), {kDLFloat, 32, 1}); + auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); auto stor = Storage(buffer, allocator); - auto opencl_memobj = stor->AllocNDArrayScoped(0, ShapeTuple(shape), {kDLFloat, 32, 1}, mem_scope); + auto opencl_memobj = stor->AllocNDArrayScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, mem_scope); auto opencl_memview = - stor->AllocNDArrayScoped(0, ShapeTuple(same_shape), {kDLFloat, 32, 1}, mem_scope); + stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, mem_scope); std::random_device dev; std::mt19937 mt(dev()); @@ -158,12 +158,12 @@ TEST_F(TextureCopyTest, ViewBufferAsImage) { DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); - auto buffer = allocator->Alloc(cl_dev, ShapeTuple(shape), {kDLFloat, 32, 1}); + auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); auto stor = Storage(buffer, allocator); - auto opencl_buf_obj = stor->AllocNDArrayScoped(0, ShapeTuple(shape), {kDLFloat, 32, 1}, "global"); + auto opencl_buf_obj = stor->AllocNDArrayScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global"); auto opencl_img_obj = - stor->AllocNDArrayScoped(0, ShapeTuple(same_shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global.texture"); std::random_device dev; std::mt19937 mt(dev()); @@ -215,13 +215,13 @@ TEST_F(TextureCopyTest, ViewImageAsBuffer) { DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); - auto buffer = allocator->Alloc(cl_dev, ShapeTuple(shape), {kDLFloat, 32, 1}); + auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); auto stor = Storage(buffer, allocator); auto opencl_img_obj = - stor->AllocNDArrayScoped(0, ShapeTuple(shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocNDArrayScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global.texture"); auto opencl_buf_obj = - stor->AllocNDArrayScoped(0, ShapeTuple(same_shape), {kDLFloat, 32, 1}, "global"); + stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global"); std::random_device dev; std::mt19937 mt(dev()); @@ -273,13 +273,13 @@ TEST_F(TextureCopyTest, ViewImageAsImage) { DLDevice cl_dev = {kDLOpenCL, 0}; auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); - auto buffer = allocator->Alloc(cl_dev, ShapeTuple(shape), {kDLFloat, 32, 1}); + auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); auto stor = Storage(buffer, allocator); auto opencl_img_obj_1 = - stor->AllocNDArrayScoped(0, ShapeTuple(shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocNDArrayScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global.texture"); auto opencl_img_obj_2 = - stor->AllocNDArrayScoped(0, ShapeTuple(same_shape), {kDLFloat, 32, 1}, "global.texture"); + stor->AllocNDArrayScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global.texture"); std::random_device dev; std::mt19937 mt(dev()); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 9449787218cc..348792d6ff88 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -336,7 +336,7 @@ TEST(IRF, Substitute) { if (var.same_as(x)) { return y; } - return NullOpt; + return std::nullopt; }; BufferLoad new_buffer_load = Downcast(Substitute(buffer_load, f_subst)); ICHECK(new_buffer_load->buffer->data.same_as(y)); diff --git a/tests/cpp/llvm_codegen_registry_test.cc b/tests/cpp/llvm_codegen_registry_test.cc index 534d4c8e411b..b5cea29c6b27 100644 --- a/tests/cpp/llvm_codegen_registry_test.cc +++ b/tests/cpp/llvm_codegen_registry_test.cc @@ -21,8 +21,7 @@ #include #include -#include -#include +#include #include diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index 52bfa2cdd399..d552dae8f754 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -41,10 +41,10 @@ using namespace tvm::relax; TEST(NestedMsg, Basic) { // start with no annotation - relax::Var x("x", NullOpt), y("y", NullOpt); + relax::Var x("x", std::nullopt), y("y", std::nullopt); // constructor from array, T and nullopt. - NestedMsg msg({x, NullOpt, x}); + NestedMsg msg({x, std::nullopt, x}); EXPECT_TRUE(msg.IsNested()); EXPECT_FALSE(msg.IsLeaf()); @@ -64,11 +64,11 @@ TEST(NestedMsg, Basic) { // assignment // assign null - a0 = NullOpt; + a0 = std::nullopt; EXPECT_TRUE(a0 == nullptr); // assign array - a0 = {x, {x, NullOpt, y}}; + a0 = {x, {x, std::nullopt, y}}; EXPECT_TRUE(a0.IsNested()); auto t0 = a0.NestedArray()[1]; EXPECT_TRUE(t0.IsNested()); @@ -82,8 +82,8 @@ TEST(NestedMsg, Basic) { } TEST(NestedMsg, ForEachLeaf) { - relax::Var x("x", NullOpt), y("y", NullOpt); - NestedMsg msg = {x, {x, y}, NullOpt, {x, {x, y}}}; + relax::Var x("x", std::nullopt), y("y", std::nullopt); + NestedMsg msg = {x, {x, y}, std::nullopt, {x, {x, y}}}; int x_count = 0, y_count = 0; @@ -96,35 +96,36 @@ TEST(NestedMsg, ForEachLeaf) { } TEST(NestedMsg, Equal) { - relax::Var x("x", NullOpt), y("y", NullOpt); - relax::Var z("z", NullOpt); + relax::Var x("x", std::nullopt), y("y", std::nullopt); + relax::Var z("z", std::nullopt); auto fequal = [](Expr lhs, Expr rhs) { return lhs.same_as(rhs); }; using M = NestedMsg; - EXPECT_TRUE(Equal(M(NullOpt), M(NullOpt), fequal)); + EXPECT_TRUE(Equal(M(std::nullopt), M(std::nullopt), fequal)); EXPECT_TRUE(Equal(M(x), M(x), fequal)); EXPECT_TRUE(Equal(M({x, y}), M({x, y}), fequal)); - EXPECT_TRUE(Equal(M({x, NullOpt}), M({x, NullOpt}), fequal)); + EXPECT_TRUE(Equal(M({x, std::nullopt}), M({x, std::nullopt}), fequal)); - EXPECT_TRUE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}}), fequal)); + EXPECT_TRUE(Equal(M({x, {std::nullopt, y}}), M({x, {std::nullopt, y}}), fequal)); - EXPECT_TRUE(Equal(M({x, {NullOpt, y}, {x, z}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + EXPECT_TRUE(Equal(M({x, {std::nullopt, y}, {x, z}}), M({x, {std::nullopt, y}, {x, z}}), fequal)); // type mismatch - EXPECT_FALSE(Equal(M({x, {NullOpt, y}, x}), M({x, {NullOpt, y}, {x, z}}), fequal)); + EXPECT_FALSE(Equal(M({x, {std::nullopt, y}, x}), M({x, {std::nullopt, y}, {x, z}}), fequal)); - EXPECT_FALSE(Equal(M({x, {NullOpt, y}, {x, NullOpt}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + EXPECT_FALSE(Equal(M({x, {std::nullopt, y}, {x, std::nullopt}}), + M({x, {std::nullopt, y}, {x, z}}), fequal)); - EXPECT_FALSE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + EXPECT_FALSE(Equal(M({x, {std::nullopt, y}}), M({x, {std::nullopt, y}, {x, z}}), fequal)); - EXPECT_FALSE(Equal(M(x), M(NullOpt), fequal)); + EXPECT_FALSE(Equal(M(x), M(std::nullopt), fequal)); - EXPECT_FALSE(Equal(M(NullOpt), M(x), fequal)); + EXPECT_FALSE(Equal(M(std::nullopt), M(x), fequal)); EXPECT_FALSE(Equal(M(x), M(Array({x})), fequal)); @@ -136,7 +137,7 @@ TEST(NestedMsg, MapAndDecompose) { relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32))); relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64))); - BlockBuilder bb = BlockBuilder::Create(NullOpt); + BlockBuilder bb = BlockBuilder::Create(std::nullopt); relax::Expr t0 = bb->Normalize(Tuple({x, y})); relax::Expr t1 = bb->Normalize(Tuple({t0, x, z, t0})); @@ -158,12 +159,12 @@ TEST(NestedMsg, MapAndDecompose) { auto output2 = MapToNestedMsg(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg { const auto* prim_sinfo = sinfo.as(); - if (prim_sinfo == nullptr) return NullOpt; + if (prim_sinfo == nullptr) return std::nullopt; int bits = prim_sinfo->dtype.bits(); if (bits == 16) return c0; if (bits == 32) return c1; if (bits == 64) return c2; - return NullOpt; + return std::nullopt; }); EXPECT_TRUE(Equal(output2, expected, @@ -248,9 +249,9 @@ TEST(NestedMsg, CombineNestedMsg) { auto c1 = Integer(1); auto c2 = Integer(2); - NestedMsg lhs = {c0, {c0, c1}, NullOpt, {c0, {c1, c2}}}; - NestedMsg rhs = {c1, {c2, NullOpt}, NullOpt, {c1, {c2, c2}}}; - NestedMsg expected = {c1, {c2, c1}, NullOpt, {c1, {c2, c2}}}; + NestedMsg lhs = {c0, {c0, c1}, std::nullopt, {c0, {c1, c2}}}; + NestedMsg rhs = {c1, {c2, std::nullopt}, std::nullopt, {c1, {c2, c2}}}; + NestedMsg expected = {c1, {c2, c1}, std::nullopt, {c1, {c2, c2}}}; auto output = CombineNestedMsg(lhs, rhs, [](Integer x, Integer y) { if (x->value > y->value) return x; @@ -267,8 +268,8 @@ TEST(NestedMsg, MapNestedMsg) { auto c2 = Integer(2); auto c3 = Integer(3); - NestedMsg msg = {c0, {c0, c1}, NullOpt, {c0, {c2, c1}}}; - NestedMsg expected = {c3, {c3, NullOpt}, NullOpt, {c3, {c2, NullOpt}}}; + NestedMsg msg = {c0, {c0, c1}, std::nullopt, {c0, {c2, c1}}}; + NestedMsg expected = {c3, {c3, std::nullopt}, std::nullopt, {c3, {c2, std::nullopt}}}; auto output = MapNestedMsg(msg, [](Integer x) { if (x->value == 0) { diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc index 3f76b1d2f108..18dff9eaece9 100644 --- a/tests/cpp/object_protocol_test.cc +++ b/tests/cpp/object_protocol_test.cc @@ -18,8 +18,8 @@ */ #include +#include #include -#include #include namespace tvm { @@ -66,6 +66,7 @@ TVM_REGISTER_OBJECT_TYPE(ObjAA); TEST(ObjectHierachy, Basic) { using namespace tvm::runtime; using namespace tvm::test; + using namespace tvm::ffi; ObjectRef refA(make_object()); ICHECK_EQ(refA->type_index(), ObjA::RuntimeTypeIndex()); diff --git a/tests/cpp/runtime/memory/memory_manager_tests.cc b/tests/cpp/runtime/memory/memory_manager_tests.cc index 47146d2000fc..bfe6f9644dde 100644 --- a/tests/cpp/runtime/memory/memory_manager_tests.cc +++ b/tests/cpp/runtime/memory/memory_manager_tests.cc @@ -77,7 +77,7 @@ TEST_F(TvmVMMemoryManagerTest, NaiveEmptyBasic) { EXPECT_EQ(allocator->UsedMemory(), 0); auto dt = DataType::Float(32); size_t nbytes = 1 * 3 * 6 * 6 * dt.bytes(); - ShapeTuple shape = {1, 3, 6, 6}; + ffi::Shape shape = {1, 3, 6, 6}; { auto ndarray = allocator->Empty(shape, dt, dev); EXPECT_EQ(allocator->UsedMemory(), nbytes); @@ -92,7 +92,7 @@ TEST_F(TvmVMMemoryManagerTest, BothAllocatorsCoexists) { EXPECT_EQ(nallocator->UsedMemory(), 0); auto dt = DataType::Float(32); size_t nbytes = 1 * 3 * 6 * 6 * dt.bytes(); - ShapeTuple shape = {1, 3, 6, 6}; + ffi::Shape shape = {1, 3, 6, 6}; { auto ndarray = nallocator->Empty(shape, dt, dev); EXPECT_EQ(nallocator->UsedMemory(), nbytes); @@ -125,7 +125,7 @@ TEST_F(TvmVMMemoryManagerTest, PooledEmptyBasic) { size_t nbytes = 1 * 3 * 6 * 6 * dt.bytes(); size_t page_size = PooledAllocator::kDefaultPageSize; size_t size = ((nbytes + page_size - 1) / page_size) * page_size; - ShapeTuple shape = {1, 3, 6, 6}; + ffi::Shape shape = {1, 3, 6, 6}; { auto ndarray = allocator->Empty(shape, dt, dev); EXPECT_EQ(allocator->UsedMemory(), size); @@ -139,7 +139,7 @@ TEST_F(TvmVMMemoryManagerTest, NaiveAllocWithShape) { EXPECT_EQ(allocator->UsedMemory(), 0); auto dt = DataType::Float(32); size_t nbytes = 1 * 3 * 6 * 6 * dt.bytes(); - ShapeTuple shape = {1, 3, 6, 6}; + ffi::Shape shape = {1, 3, 6, 6}; auto buff = allocator->Alloc(dev, shape, dt); EXPECT_EQ(allocator->UsedMemory(), nbytes); allocator->Free(buff); @@ -165,7 +165,7 @@ TEST_F(TvmVMMemoryManagerTest, PooledAllocWithShape) { size_t nbytes = 1 * 3 * 6 * 6 * dt.bytes(); size_t page_size = PooledAllocator::kDefaultPageSize; size_t size = ((nbytes + page_size - 1) / page_size) * page_size; - ShapeTuple shape = {1, 3, 6, 6}; + ffi::Shape shape = {1, 3, 6, 6}; auto buff = allocator->Alloc(dev, shape, dt); EXPECT_EQ(allocator->UsedMemory(), size); allocator->Free(buff); diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index a5eeab034f82..ccb43a81d8f1 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -88,7 +88,7 @@ TEST(ScalableDataType, TestScalableDataTypeToString) { TEST(ScalableDataType, TestStringToScalableDataType) { std::string scalable_type_str = "int32xvscalex4"; - EXPECT_EQ(tvm::DataType(tvm::runtime::StringToDLDataType(scalable_type_str)), + EXPECT_EQ(tvm::DataType(tvm::ffi::StringToDLDataType(scalable_type_str)), tvm::DataType(kDLInt, 32, 4, true)); } @@ -97,7 +97,7 @@ TEST(ScalableDataType, TestInvalidStringToScalableDataType) { EXPECT_THROW( { try { - tvm::runtime::StringToDLDataType(scalable_type_str); + tvm::ffi::StringToDLDataType(scalable_type_str); } catch (const tvm::ffi::Error& e) { EXPECT_THAT(e.what(), HasSubstr("unknown dtype `int32x4xvscale`")); throw; diff --git a/tests/lint/rust_format.sh b/tests/lint/rust_format.sh deleted file mode 100755 index bed7ad976ea6..000000000000 --- a/tests/lint/rust_format.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -TVM_HOME="$(git rev-parse --show-toplevel)" -RUST_DIR="$TVM_HOME/rust" - -if [[ "$1" == "-i" ]]; then - INPLACE_FORMAT=1 - shift 1 -else - INPLACE_FORMAT=0 -fi - -cd $RUST_DIR - -if [[ ${INPLACE_FORMAT} -eq 1 ]]; then - cargo fmt -else - cargo fmt -- --check -fi diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index 42f5b0ccd0b8..733d1d13b371 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -448,7 +448,6 @@ def test_simplify_le(): ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x)) ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16) - ck.verify(x * 8 + y - z < 16, x < 2) n = te.size_var("n") ck.verify(x * 8 + y < n, x * 8 + y < n) diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index ad4abdfe2934..6954cf4e1d5c 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -391,17 +391,17 @@ class TestAddIndex(BaseCompare): TestCase(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)), TestCase(tvm.te.min(0, 1 - x * 4) + x * 4, tvm.te.min(x * 4, 1)), TestCase(tvm.te.min(2 - x * 4, 0) + x * 4, tvm.te.min(x * 4, 2)), - TestCase(x * y + x * 10, x * (y + 10)), - TestCase(y * x + x * 10, x * (y + 10)), - TestCase(y * x + 10 * x, x * (y + 10)), - TestCase(x * y + 10 * x, x * (y + 10)), + TestCase(x * y + x * 10, (y + 10) * x), + TestCase(y * x + x * 10, (y + 10) * x), + TestCase(y * x + 10 * x, (y + 10) * x), + TestCase(x * y + 10 * x, (y + 10) * x), TestCase((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y)), - TestCase(y * x + x, x * (y + 1)), - TestCase(x * y + x, x * (y + 1)), + TestCase(y * x + x, (y + 1) * x), + TestCase(x * y + x, (y + 1) * x), TestCase((x + 10) + 13, x + 23), TestCase((x + 10) + (13 + z), x + z + 23), - TestCase(x * y + 10 * x, x * (y + 10)), - TestCase(y * x + x * 3, x * (y + 3)), + TestCase(x * y + 10 * x, (y + 10) * x), + TestCase(y * x + x * 3, (y + 3) * x), TestCase(x + 3 + y, x + y + 3), TestCase((3 - y) + x, x - y + 3), # canonicalization @@ -409,10 +409,10 @@ class TestAddIndex(BaseCompare): TestCase(x + 2 + 3 + 4 + x * 3, x * 4 + 9), # DivMod rules # trunc div - TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), tmod(x, 8) * (y + 10)), + TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), (y + 10) * tmod(x, 8)), TestCase(tdiv(x, 8) * 8 + tmod(x, 8), x), # floor div - TestCase(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)), + TestCase(y * flm(x, 8) + 10 * flm(x, 8), (y + 10) * flm(x, 8)), TestCase(fld(x, 8) * 8 + flm(x, 8), x), TestCase(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)), ) @@ -436,10 +436,10 @@ class TestSubIndex(BaseCompare): TestCase(y - tvm.te.max(x, y), tvm.te.min(y - x, 0)), # mul co-efficient foldng TestCase(x - x, 0), - TestCase(x * y - x, x * (y + (-1))), - TestCase(x * y - 10 * x, x * (y + (-10))), - TestCase(y * x - x * z, x * (y - z)), - TestCase(y * x - z * x, x * (y - z)), + TestCase(x * y - x, (y + (-1)) * x), + TestCase(x * y - 10 * x, (y + (-10)) * x), + TestCase(y * x - x * z, (y - z) * x), + TestCase(y * x - z * x, (y - z) * x), TestCase(x + 10 - 20, x + (-10)), # 4-operands pattern TestCase((x + y) - (x + z), y - z), diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 3b0237740045..5a61cb8a52a9 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -113,7 +113,7 @@ def test_simplify_vscale_comparison_without_sve_target(capfd): warning_msg = ( "Warning: The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. This proof currently only supports " - "AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu" + "VLA targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu" ) capture = capfd.readouterr().err assert warning_msg in capture @@ -131,5 +131,18 @@ def test_regression_simplify_inf_recursion(): ana.rewrite_simplify(res) +def test_simplify_floor_mod_with_linear_offset(): + """ + Test that the floor_mod is simplified correctly when the offset is linear. + """ + ana = tvm.arith.Analyzer() + past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64") + expr1 = (past_decoder_sequence_length + 1) * 64 + divisor1 = (past_decoder_sequence_length + 1) * 32 + assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0) + divisor2 = 32 * (past_decoder_sequence_length + 1) + assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_gpu_codegen_allreduce.py b/tests/python/codegen/test_gpu_codegen_allreduce.py index 5a24fcb197a5..aa56411cc9e0 100644 --- a/tests/python/codegen/test_gpu_codegen_allreduce.py +++ b/tests/python/codegen/test_gpu_codegen_allreduce.py @@ -102,7 +102,7 @@ def compile_metal(src, target): if define_metal_compile_callback: if cached is None: - tvm._ffi.registry.remove_global_func(name) + tvm.ffi.registry.remove_global_func(name) else: tvm.register_func(name, cached, override=True) diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 43870044d528..2c8f185d8ecd 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -43,7 +43,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] * B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mul instructions using z registers assembly = f.get_source("asm") @@ -73,7 +75,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] + B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and add instructions using z registers assembly = f.get_source("asm") @@ -103,7 +107,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] - B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and sub instructions using z registers assembly = f.get_source("asm") @@ -134,7 +140,9 @@ def check_correct_assembly(type): B = te.placeholder(m, dtype=type, name="B") C = te.placeholder(m, dtype=type, name="C") D = te.compute((m), lambda i: A[i] * B[i] + C[i], name="D") - f = tvm.tir.build(te.create_prim_func([A, B, C, D]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C, D])) # Verify we see SVE load instructions and either mad or mla instructions using z registers assembly = f.get_source("asm") @@ -164,7 +172,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.max(A[i], B[i])) - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a max instruction, all using z registers assembly = f.get_source("asm") @@ -198,7 +208,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.min(A[i], B[i])) - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmgt + sel instructions or a min instruction, all using z registers assembly = f.get_source("asm") @@ -232,7 +244,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.div(A[i], B[i])) - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and div instructions using z registers assembly = f.get_source("asm") @@ -261,7 +275,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: tvm.te.floormod(A[i], B[i]), name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and mls instructions using z registers assembly = f.get_source("asm") @@ -291,7 +307,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] == B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpeq or cmeq instructions using z registers assembly = f.get_source("asm") @@ -321,7 +339,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] != B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and cmpgt, cmgt, cmpne or cmne instructions, all using z registers assembly = f.get_source("asm") @@ -350,7 +370,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] | B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and orr instructions using z registers assembly = f.get_source("asm") @@ -379,7 +401,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype=type, name="B") C = te.compute((m), lambda i: A[i] & B[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see SVE load instructions and and instructions using z registers assembly = f.get_source("asm") @@ -407,7 +431,9 @@ def check_correct_assembly(type): m = te.var("m") A = te.placeholder(m, dtype=type, name="A") C = te.compute((m), lambda i: ~A[i], name="C") - f = tvm.tir.build(te.create_prim_func([A, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, C])) # Verify we see SVE load instructions and eor instructions using z registers assembly = f.get_source("asm") @@ -440,7 +466,9 @@ def check_correct_assembly(type): A = te.placeholder(m, dtype=type, name="A") B = te.placeholder(m, dtype="int32", name="B") C = te.compute((m), lambda i: A[B[i]], name="C") - f = tvm.tir.build(te.create_prim_func([A, B, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, B, C])) # Verify we see gather instructions in the assembly assembly = f.get_source("asm") @@ -451,65 +479,6 @@ def check_correct_assembly(type): check_correct_assembly(type=dtype) -@pytest.mark.skipif( - llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" -) -def test_codegen_vscale(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - vscale = tvm.tir.vscale() - - @T.prim_func - def main(A: T.Buffer((5,), "int32")): - for i in range(5): - A[i] = 2 * vscale - - build_mod = tvm.tir.build(main, target=target) - llvm = build_mod.get_source() - - assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." - - -@pytest.mark.skipif( - llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" -) -def test_scalable_buffer_load_store(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def my_func(a: T.handle, b: T.handle): - A = T.match_buffer(a, (128,), "float32") - B = T.match_buffer(b, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] - - mod = tvm.tir.build(my_func, target=target) - llvm = mod.get_source("ll") - - assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." - assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." - - -@pytest.mark.skipif( - llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" -) -def test_scalable_broadcast(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def my_func(a: T.handle): - A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) - A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) - - mod = tvm.tir.build(my_func, target=target) - llvm = mod.get_source("ll") - - assert re.findall( - r"shufflevector \( insertelement \(", llvm - ), "No scalable broadcast in generated LLVM." - assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." - - @pytest.mark.skipif( llvm_version_major() < 13, reason="Function attribute vscale_range() is not supported in earlier versions of LLVM", @@ -529,7 +498,9 @@ def test_vscale_range_function_attribute(mattr, expect_attr): m = te.var("m") A = te.placeholder(m, dtype="float32", name="A") C = te.compute((m), lambda i: A[i] + 1, name="C") - f = tvm.tir.build(te.create_prim_func([A, C]), target=target) + + with tvm.target.Target(target): + f = tvm.tir.build(te.create_prim_func([A, C])) # Check if the vscale_range() attribute exists ll = f.get_source("ll") @@ -545,49 +516,5 @@ def test_vscale_range_function_attribute(mattr, expect_attr): ), f"Unexpected function attribute vscale_range() was found in generated LLVM IR" -@pytest.mark.skip( - reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", -) -def test_get_active_lane_mask(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def before(a: T.handle): - A = T.match_buffer(a, (30,), "int1") - for i in range(T.ceildiv(30, T.vscale() * 4)): - A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30) - - with tvm.target.Target(target): - out = tvm.tir.build(before) - - ll = out.get_source("ll") - assert "get.active.lane.mask" in ll - - -@pytest.mark.skip( - reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", -) -def test_predicated_scalable_buffer(): - target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" - - @T.prim_func - def before(a: T.handle, b: T.handle): - A = T.match_buffer(a, (16,), "float32") - B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): - for i_1 in T.vectorized(4 * T.vscale()): - if i_0 * 4 * T.vscale() + i_1 < 14: - B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 - - with tvm.target.Target(target): - out = tvm.tir.build(before) - - ll = out.get_source("ll") - assert "get.active.lane.mask" in ll - assert "llvm.masked.load" in ll - assert "llvm.masked.store" in ll - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_llvm_vla.py b/tests/python/codegen/test_target_codegen_llvm_vla.py new file mode 100644 index 000000000000..7ca3083dd5e3 --- /dev/null +++ b/tests/python/codegen/test_target_codegen_llvm_vla.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Codegen tests for VLA extensions +""" + +import re +import pytest + +import tvm +from tvm import te +from tvm.script import tir as T +from tvm.target.codegen import llvm_version_major + + +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_codegen_vscale(target): + vscale = tvm.tir.vscale() + + @T.prim_func + def main(A: T.Buffer((5,), "int32")): + for i in range(5): + A[i] = 2 * vscale + + with tvm.target.Target(target): + build_mod = tvm.tir.build(main) + + llvm = build_mod.get_source() + assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." + + +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_scalable_buffer_load_store(target): + @T.prim_func + def my_func(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128,), "float32") + B = T.match_buffer(b, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] + + with tvm.target.Target(target): + mod = tvm.tir.build(my_func) + + llvm = mod.get_source("ll") + assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." + assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." + + +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_scalable_broadcast(target): + @T.prim_func + def my_func(a: T.handle): + A = T.match_buffer(a, (128,), "float32") + T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) + + with tvm.target.Target(target): + mod = tvm.tir.build(my_func) + + llvm = mod.get_source("ll") + assert re.findall( + r"shufflevector \( insertelement \(", llvm + ), "No scalable broadcast in generated LLVM." + assert re.findall(r" store ", llvm), "No scalable store in generated LLVM." + + +@pytest.mark.skip( + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_get_active_lane_mask(target): + @T.prim_func + def before(a: T.handle): + A = T.match_buffer(a, (30,), "int1") + for i in range(T.ceildiv(30, T.vscale() * 4)): + A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30) + + with tvm.target.Target(target): + out = tvm.tir.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + + +@pytest.mark.skip( + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", +) +@tvm.testing.parametrize_targets( + "llvm -mtriple=aarch64-linux-gnu -mattr=+sve", + "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", +) +def test_predicated_scalable_buffer(target): + @T.prim_func + def before(a: T.handle, b: T.handle): + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): + for i_1 in T.vectorized(4 * T.vscale()): + if i_0 * 4 * T.vscale() + i_1 < 14: + B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 + + with tvm.target.Target(target): + out = tvm.tir.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + assert "llvm.masked.load" in ll + assert "llvm.masked.store" in ll + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index b661ce486981..89acf598d6e3 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -568,5 +568,60 @@ def kernel(): vulkan_codegen(mod, target) +@tvm.testing.requires_gpu +@tvm.testing.requires_vulkan +def test_unary(): + test_funcs = [ + (tvm.tir.sin, lambda x: np.sin(x)), + (tvm.tir.cos, lambda x: np.cos(x)), + (tvm.tir.tan, lambda x: np.tan(x)), + (tvm.tir.sinh, lambda x: np.sinh(x)), + (tvm.tir.cosh, lambda x: np.cosh(x)), + (tvm.tir.tanh, lambda x: np.tanh(x)), + (tvm.tir.asin, lambda x: np.arcsin(x)), + (tvm.tir.acos, lambda x: np.arccos(x)), + (tvm.tir.atan, lambda x: np.arctan(x)), + (tvm.tir.asinh, lambda x: np.arcsinh(x)), + (tvm.tir.acosh, lambda x: np.arccosh(x)), + (tvm.tir.atanh, lambda x: np.arctanh(x)), + ] + + def run_test(tvm_intrin, np_func): + m = te.var("m") + A = te.placeholder((m,), name="A", dtype="float32") + B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B") + + mod = te.create_prim_func([A, B]) + sch = tir.Schedule(mod) + + block = sch.get_block("B") + loop = sch.get_loops(block)[0] + bx, tx = sch.split(loop, factors=[None, 64]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + target = tvm.target.Target("vulkan") + dev = tvm.device(target.kind.name, 0) + func = tvm.compile(sch.mod, target=target) + + n = 16 + if tvm_intrin in [tvm.tir.asin, tvm.tir.acos]: + data = np.random.uniform(-1.0, 1.0, size=n) + elif tvm_intrin == tvm.tir.atanh: + data = np.random.uniform(-0.999, 0.999, size=n) + elif tvm_intrin == tvm.tir.acosh: + data = np.random.uniform(1.0, 5.0, size=n) + else: + data = np.random.uniform(0.1, 0.9, size=n) + + a = tvm.nd.array(data.astype(A.dtype), dev) + b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev) + func(a, b) + tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) + + for func in test_funcs: + run_test(*func) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md index 28300dfdea4e..955cd58dc2ae 100644 --- a/tests/python/contrib/test_hexagon/README_RPC.md +++ b/tests/python/contrib/test_hexagon/README_RPC.md @@ -80,7 +80,7 @@ Which eventually jumps to the following line in C++, which creates a RPC client [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129) ```cpp -TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { +TVM_FFI_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto url = args[0].cast(); int port = args[1].cast(); auto key = args[2].cast(); @@ -94,7 +94,7 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body_packed([](ffi::PackedArgs args, ffi: [https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106) ```cpp -TVM_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") +TVM_FFI_REGISTER_GLOBAL("tvm.contrib.hexagon.create_hexagon_session") .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { auto session_name = args[0].cast(); int remote_stack_size_bytes = args[1].cast(); diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index ea9feb740319..2795f5630163 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -62,7 +62,7 @@ def test_vtcm_limit(vtcm_capacity, limited): def _raises_exception(f): try: f() - except tvm._ffi.base.TVMError: + except tvm.base.TVMError: return True return False diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py index ba0287afc61a..5089336f09d3 100644 --- a/tests/python/disco/test_loader.py +++ b/tests/python/disco/test_loader.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import dlight as dl from tvm import relax as rx -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.contrib import tvmjs from tvm.runtime import ShapeTuple from tvm.runtime import disco as di diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py index 6341b7b0ae66..ae07a3b7318c 100644 --- a/tests/python/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/dlight/test_gpu_low_batch_gemv.py @@ -136,7 +136,7 @@ def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T with T.block("NT_matmul_intermediate_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) - T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) + T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]) T.writes(NT_matmul_intermediate[v0, T.int64(0), v1]) NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] @@ -240,7 +240,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float with T.block("NT_matmul_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(16) + ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) - T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) + T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size) T.reads(NT_matmul_pad_local[v0, T.int64(0), v1]) T.writes(NT_matmul[v0, T.int64(0), v1]) NT_matmul[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1] @@ -369,7 +369,7 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16" with T.block("C_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0) v1 = T.axis.spatial(T.int64(8), ax1_fused_0_ax1_fused_1_fused * T.int64(2) + ax1_fused_2) - T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8)) + T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 * T.int64(4) + ax0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size and (T.Mul(T.int64(0), T.int64(16)) + ax1_fused_0_ax1_fused_1_fused % T.int64(16)) * T.int64(2) + ax1_fused_2 < T.int64(8)) T.reads(C_pad_local[v0, v1]) T.writes(C[v0, v1]) C[v0, v1] = C_pad_local[v0, v1] @@ -516,7 +516,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo with T.block("C_pad"): v0 = T.axis.spatial(batch_size, ax0_0 * 4 + ax0) v1 = T.axis.spatial(6144, ax1_fused_0 * 64 + ax1) - T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 == 0) and ax0_0 * 4 + ax0 < batch_size) + T.where((ax0_0 - (batch_size + 3) // 4 < 0 or ax0_0 * 4 + ax0 == 0) and ax0_0 * 4 + ax0 < batch_size) T.reads(C_pad_local[v0, 0, v1]) T.writes(C[v0, 0, v1]) C[v0, 0, v1] = C_pad_local[v0, 0, v1] diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index 2fa61faf40f8..f27d9d370fce 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -695,7 +695,7 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), v0 = T.axis.spatial(T.int64(1), T.int64(0)) v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) - T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m) T.reads(matmul_pad_local[v0, v1, v2]) T.writes(matmul[v0, v1, v2]) matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2] @@ -835,7 +835,7 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) - T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len) T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2] diff --git a/tests/python/ffi/test_ndarray.py b/tests/python/ffi/test_ndarray.py index a5a6f5b07438..5b75171b55bb 100644 --- a/tests/python/ffi/test_ndarray.py +++ b/tests/python/ffi/test_ndarray.py @@ -14,6 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + +try: + import torch +except ImportError: + torch = None from tvm import ffi as tvm_ffi import numpy as np @@ -47,3 +53,24 @@ def test_shape_object(): shape3 = tvm_ffi.convert(shape) assert shape3.__tvm_ffi_object__.same_as(shape.__tvm_ffi_object__) assert isinstance(shape3, tvm_ffi.Shape) + + +@pytest.mark.skipif(torch is None, reason="Torch is not installed") +def test_ndarray_auto_dlpack(): + def check(x, y): + assert isinstance(y, tvm_ffi.NDArray) + assert y.shape == (128,) + assert y.dtype == tvm_ffi.dtype("int64") + assert y.device.device_type == tvm_ffi.Device.kDLCPU + assert y.device.device_id == 0 + x2 = torch.from_dlpack(y) + np.testing.assert_equal(x2.numpy(), x.numpy()) + + x = torch.arange(128) + fecho = tvm_ffi.get_global_func("testing.echo") + y = fecho(x) + check(x, y) + + # pass in list of tensors + y = fecho([x]) + check(x, y[0]) diff --git a/tests/python/meta_schedule/test_meta_schedule_builder.py b/tests/python/meta_schedule/test_meta_schedule_builder.py index a74ac893262f..090a393fbeeb 100644 --- a/tests/python/meta_schedule/test_meta_schedule_builder.py +++ b/tests/python/meta_schedule/test_meta_schedule_builder.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" Test Meta Schedule Builder """ +"""Test Meta Schedule Builder""" import os import sys @@ -25,7 +25,7 @@ import tvm.testing from tvm import script -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.meta_schedule.builder import ( BuilderInput, BuilderResult, diff --git a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py index 8d374c01ae19..57d9d0961088 100644 --- a/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_post_order_apply.py @@ -25,7 +25,7 @@ import tvm.testing from tvm import te from tvm.ir.module import IRModule -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.error import TVMError from tvm.meta_schedule import TuneContext from tvm.meta_schedule.schedule_rule import PyScheduleRule diff --git a/tests/python/meta_schedule/test_meta_schedule_runner.py b/tests/python/meta_schedule/test_meta_schedule_runner.py index 03ab8c58b48d..e5deefe7507c 100644 --- a/tests/python/meta_schedule/test_meta_schedule_runner.py +++ b/tests/python/meta_schedule/test_meta_schedule_runner.py @@ -25,7 +25,7 @@ import pytest import tvm import tvm.testing -from tvm._ffi import register_func +from tvm.ffi import register_func from tvm.meta_schedule.arg_info import TensorInfo from tvm.meta_schedule.builder import BuilderInput, LocalBuilder from tvm.meta_schedule.runner import ( diff --git a/tests/python/meta_schedule/test_meta_schedule_space_generator.py b/tests/python/meta_schedule/test_meta_schedule_space_generator.py index ef2be381c694..9457a9a40f00 100644 --- a/tests/python/meta_schedule/test_meta_schedule_space_generator.py +++ b/tests/python/meta_schedule/test_meta_schedule_space_generator.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" Test Meta Schedule SpaceGenerator """ +"""Test Meta Schedule SpaceGenerator""" # pylint: disable=missing-function-docstring import math @@ -22,7 +22,7 @@ import pytest import tvm import tvm.testing -from tvm._ffi.base import TVMError +from tvm.base import TVMError from tvm.meta_schedule.space_generator import ( PySpaceGenerator, ScheduleFn, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py similarity index 70% rename from tests/python/relax/test_from_exported_to_cuda.py rename to tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py index 6bb35b50b1df..3f0964cfa8ed 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/nightly/test_nnapi/test_from_exported_to_cuda.py @@ -21,6 +21,7 @@ import numpy as np import torch from torch import nn +from torch.nn import functional as F from torch.export import export from tvm.relax.frontend.torch import from_exported_program from torch.nn import Softmax, Upsample @@ -742,5 +743,332 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_leakyrelu_module(target, dev): + class LeakyReLUModule(nn.Module): + def __init__(self): + super().__init__() + self.act = nn.LeakyReLU(negative_slope=0.1) + + def forward(self, x): + return self.act(x) + + raw_data = np.random.randn(2, 3).astype(np.float32) + torch_module = LeakyReLUModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_log_softmax_module(target, dev): + class LogSoftmaxModule(nn.Module): + def __init__(self): + super().__init__() + self.logsoftmax = nn.LogSoftmax(dim=1) + + def forward(self, x): + return self.logsoftmax(x) + + raw_data = np.random.randn(4, 5).astype(np.float32) + torch_module = LogSoftmaxModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_softmax_module(target, dev): + class SoftmaxModule(nn.Module): + def __init__(self): + super().__init__() + self.softmax = nn.Softmax(dim=1) + + def forward(self, x): + return self.softmax(x) + + raw_data = np.random.randn(4, 5).astype(np.float32) + torch_module = SoftmaxModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_adaptive_avg_pool2d_module(target, dev): + class AdaptiveAvgPool2dModule(nn.Module): + def __init__(self): + super().__init__() + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + return self.pool(x) + + raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32) + torch_module = AdaptiveAvgPool2dModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_avg_pool2d_module(target, dev): + class AvgPool2dModule(nn.Module): + def __init__(self): + super().__init__() + self.pool = nn.AvgPool2d(kernel_size=2) + + def forward(self, x): + return self.pool(x) + + raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32) + torch_module = AvgPool2dModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_conv1d_module(target, dev): + class Conv1dModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=3) + + def forward(self, x): + return self.conv(x) + + raw_data = np.random.randn(2, 3, 10).astype(np.float32) + torch_module = Conv1dModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_conv2d_module(target, dev): + class Conv2dModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3) + + def forward(self, x): + return self.conv(x) + + raw_data = np.random.randn(2, 3, 10, 10).astype(np.float32) + torch_module = Conv2dModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_conv3d_module(target, dev): + class Conv3dModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv3d(in_channels=2, out_channels=3, kernel_size=3) + + def forward(self, x): + return self.conv(x) + + raw_data = np.random.randn(1, 2, 8, 8, 8).astype(np.float32) + torch_module = Conv3dModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_group_norm_module(target, dev): + class GroupNormModule(nn.Module): + def __init__(self): + super().__init__() + self.gn = nn.GroupNorm(num_groups=1, num_channels=4) + + def forward(self, x): + return self.gn(x) + + raw_data = np.random.randn(2, 4, 8, 8).astype(np.float32) + torch_module = GroupNormModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_layer_norm_module(target, dev): + class LayerNormModule(nn.Module): + def __init__(self): + super().__init__() + self.ln = nn.LayerNorm(normalized_shape=8) + + def forward(self, x): + return self.ln(x) + + raw_data = np.random.randn(2, 4, 8).astype(np.float32) + torch_module = LayerNormModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_linear_module(target, dev): + class LinearModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + raw_data = np.random.randn(4, 10).astype(np.float32) + torch_module = LinearModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_max_pool2d_module(target, dev): + class MaxPool2dModule(nn.Module): + def __init__(self): + super().__init__() + self.pool = nn.MaxPool2d(kernel_size=2) + + def forward(self, x): + return self.pool(x) + + raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32) + torch_module = MaxPool2dModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_embedding_module(target, dev): + class EmbeddingModule(nn.Module): + def __init__(self): + super().__init__() + self.embed = nn.Embedding(num_embeddings=10, embedding_dim=3) + + def forward(self, x): + return self.embed(x) + + raw_data = np.random.randint(0, 10, (2, 4)).astype(np.int64) + torch_module = EmbeddingModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_flatten_module(target, dev): + class FlattenModule(nn.Module): + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + + def forward(self, x): + return self.flatten(x) + + raw_data = np.random.randn(2, 3, 4, 5).astype(np.float32) + torch_module = FlattenModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_numel(target, dev): + class NumelModule(nn.Module): + def forward(self, x): + return torch.tensor(x.numel()) + + raw_data = np.random.randn(2, 3, 4).astype(np.float32) + torch_module = NumelModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_size(target, dev): + class SizeModule(nn.Module): + def forward(self, x): + return torch.tensor(x.size(0)) + + raw_data = np.random.randn(5, 4).astype(np.float32) + torch_module = SizeModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_tensor(target, dev): + class TensorModule(nn.Module): + def forward(self, x): + return torch.tensor([1, 2, 3]) + + raw_data = np.zeros((1,)).astype(np.float32) + torch_module = TensorModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_type(target, dev): + class TypeModule(nn.Module): + def forward(self, x): + return x.type(torch.float16) + + raw_data = np.random.randn(2, 3).astype(np.float32) + torch_module = TypeModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_float(target, dev): + class FloatModule(nn.Module): + def forward(self, x): + return x.float() + + raw_data = np.random.randn(2, 3).astype(np.float32) + torch_module = FloatModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_half(target, dev): + class HalfModule(nn.Module): + def forward(self, x): + return x.half() + + raw_data = np.random.randn(2, 3).astype(np.float32) + torch_module = HalfModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_getattr(target, dev): + class GetAttrModule(nn.Module): + def forward(self, x): + # Use getattr to call the ndimension method. + return torch.tensor(getattr(x, "ndimension")()) + + raw_data = np.random.randn(2, 3, 4).astype(np.float32) + torch_module = GetAttrModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_sym_size_int(target, dev): + class SymSizeIntModule(nn.Module): + def forward(self, x): + return torch.tensor(x.shape[1]) + + raw_data = np.random.randn(2, 3, 4).astype(np.float32) + torch_module = SymSizeIntModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_interpolate(target, dev): + class InterpolateModule(nn.Module): + def forward(self, x): + # Upsample to a fixed size. + return F.interpolate(x, size=(16, 16), mode="nearest") + + raw_data = np.random.randn(2, 3, 8, 8).astype(np.float32) + torch_module = InterpolateModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_cross_entropy_module(target, dev): + class CrossEntropyModule(nn.Module): + def __init__(self): + super().__init__() + self.criterion = nn.CrossEntropyLoss() + self.target = torch.tensor([0, 1, 2, 1]) + + def forward(self, x): + return self.criterion(x, self.target) + + raw_data = np.random.randn(4, 3).astype(np.float32) + torch_module = CrossEntropyModule().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/frontend_nn_extern_module.cc b/tests/python/relax/frontend_nn_extern_module.cc index 09adbe9780d6..1bac39b35091 100644 --- a/tests/python/relax/frontend_nn_extern_module.cc +++ b/tests/python/relax/frontend_nn_extern_module.cc @@ -21,8 +21,8 @@ * \brief Testing code to be compiled by Relax nn.SourceModule */ #include +#include #include -#include namespace { @@ -65,5 +65,5 @@ int _test_sym(DLTensor* a, DLTensor* b, DLTensor* c) { return 0; } } // namespace -TVM_DLL_EXPORT_TYPED_FUNC(ext_scalar_add, _scalar_add); -TVM_DLL_EXPORT_TYPED_FUNC(ext_test_sym, _test_sym); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(ext_scalar_add, _scalar_add); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(ext_test_sym, _test_sym); diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py index d0d3344eb61e..d3c78b3657a8 100644 --- a/tests/python/relax/test_binding_rewrite.py +++ b/tests/python/relax/test_binding_rewrite.py @@ -18,7 +18,7 @@ import pytest import tvm import tvm.testing -from tvm._ffi.base import TVMError +from tvm.base import TVMError from tvm.relax.analysis import name_to_binding from tvm.relax.binding_rewrite import DataflowBlockRewrite from tvm.relax.expr import DataflowVar, Var diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index 3deed8c2bfd3..fb1544be68a8 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -157,10 +157,6 @@ def Func1(x, y): tvm.testing.assert_allclose(opt_func(x, y), opt_func(x, y)) -@pytest.mark.skipif( - version.parse(torch_version) >= version.parse("2.6.0"), - reason="Tests not compatible with PyTorch >= 2.6", -) def test_subgraph_capture(): import torch from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs @@ -178,13 +174,13 @@ class Expected1: @R.function def subgraph_0( inp_0: R.Tensor((10, 100), dtype="float32"), - w0: R.Tensor((10, 100), dtype="float32"), w1: R.Tensor((10,), dtype="float32"), + w0: R.Tensor((10, 100), dtype="float32"), ) -> R.Tensor((10, 10), dtype="float32"): # block 0 with R.dataflow(): - lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None) - lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32") + lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(inp_0, axes=None) + lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(w0, lv, out_dtype="float32") lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1) lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2) gv: R.Tensor((10, 10), dtype="float32") = lv3 @@ -193,10 +189,7 @@ def subgraph_0( model = Input1() mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)) - binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()} - binding = {k: tvm.nd.array(v) for k, v in binding.items()} - expected = relax.transform.BindParams("subgraph_0", binding)(Expected1) - tvm.ir.assert_structural_equal(mod, expected) + tvm.ir.assert_structural_equal(mod, Expected1) def Input2(a, b): x = a / (torch.sin(a) + 1) @@ -258,27 +251,20 @@ def subgraph_0( ) -> R.Tensor((10, 10), dtype="float32"): # block 0 with R.dataflow(): - lv0 = R.add(inp_0, R.const(1, "float32")) - lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None) - lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(lv0, lv, out_dtype="float32") - lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1) - lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2) - gv: R.Tensor((10, 10), dtype="float32") = lv3 + lv: R.Tensor((10, 100), dtype="float32") = R.add(inp_0, R.const(1.0, "float32")) + lv1: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None) + lv2: R.Tensor((10, 10), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32") + lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, w1) + lv4: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv3) + gv: R.Tensor((10, 10), dtype="float32") = lv4 R.output(gv) return gv model = Input3() mod = dynamo_capture_subgraphs(model, torch.randn(10, 100), add_one=True) - binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()} - binding = {k: tvm.nd.array(v) for k, v in binding.items()} - expected = relax.transform.BindParams("subgraph_0", binding)(Expected3) - tvm.ir.assert_structural_equal(mod, expected) + tvm.ir.assert_structural_equal(mod, Expected3) -@pytest.mark.skipif( - version.parse(torch_version) >= version.parse("2.6.0"), - reason="Tests not compatible with PyTorch >= 2.6", -) def verify_dynamo_model(torch_model, input_info, binding, expected): import torch import torch._dynamo as dynamo diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f28ab6236df5..4c965bb6ffa8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -17,6 +17,7 @@ import operator import pytest import torch +from torch import nn from torch.nn import Module from torch.export import export @@ -513,6 +514,53 @@ def main( verify_model(MinModel(), example_args, {}, expected_min) + # relu6 + class ReLU6_1(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + return self.relu6(x) + + class ReLU6_2(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.relu6(x) + + class ReLU6_3(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.relu6_(x) + + @tvm.script.ir_module + class expected_relu6_1: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + x, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(6.0)) + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected_relu6_2: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu6(x) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(ReLU6_1(), example_args, {}, expected_relu6_1) + verify_model(ReLU6_2(), example_args, {}, expected_relu6_2) + verify_model(ReLU6_3(), example_args, {}, expected_relu6_2) + def test_hardtanh(): class Hardtanh(torch.nn.Module): @@ -1210,6 +1258,38 @@ def main( verify_model(model, example_args, binding, expected1) +def test_adaptive_avgpool1d(): + class AdaptiveAvgPool1d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool1d(output_size=5) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool1d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool1d(input, output_size=5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d( + input_1, output_size=[5], layout="NCW" + ) + gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) + verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): @@ -1243,6 +1323,38 @@ def main( verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) +def test_adaptive_avgpool3d(): + class AdaptiveAvgPool3d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d([4, 4, 4]) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool3d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool3d(input, [4, 4, 4]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.adaptive_avg_pool3d( + input_1, output_size=[4, 4, 4], layout="NCDHW", out_layout="NCDHW" + ) + gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) + verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1) + + def test_addmm(): class Addmm1(Module): def __init__(self): @@ -1302,6 +1414,102 @@ def main( verify_model(Addmm2(), example_args, {}, expected2) +def test_avg_pool1d(): + class AvgPool1d1(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool1d(kernel_size=1) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.avg_pool1d( + input_1, + pool_size=[1], + strides=[1], + dilation=[1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCW", + out_layout="NCW", + ) + gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class AvgPool1d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1, ceil_mode=True) + + def forward(self, input): + return self.pool(input) + + class AvgPool1d3(Module): + def forward(self, input): + return torch.nn.functional.avg_pool1d( + input, kernel_size=3, stride=2, padding=1, ceil_mode=True + ) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool1d( + input_1, + pool_size=[3], + strides=[2], + dilation=[1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=True, + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + class AvgPool1d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool1d(input, kernel_size=2, stride=2, padding=0) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) + verify_model(AvgPool1d1(), example_args, {}, expected1) + verify_model(AvgPool1d2(), example_args, {}, expected2) + verify_model(AvgPool1d3(), example_args, {}, expected2) + verify_model(AvgPool1d4(), example_args, {}, expected3) + + def test_avg_pool2d(): class AvgPool2d1(Module): def __init__(self): @@ -1395,6 +1603,102 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): verify_model(AvgPool2d4(), example_args, {}, expected3) +def test_avg_pool3d(): + class AvgPool3d1(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool3d(kernel_size=1) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.avg_pool3d( + input_1, + pool_size=[1, 1, 1], + strides=[1, 1, 1], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class AvgPool3d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True) + + def forward(self, input): + return self.pool(input) + + class AvgPool3d3(Module): + def forward(self, input): + return torch.nn.functional.avg_pool3d( + input, kernel_size=3, stride=2, padding=1, ceil_mode=True + ) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool3d( + input_1, + pool_size=[3, 3, 3], + strides=[2, 2, 2], + dilation=[1, 1, 1], + padding=[1, 1, 1, 1, 1, 1], + ceil_mode=True, + count_include_pad=True, + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class AvgPool3d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool3d(input, kernel_size=[2, 1, 2], stride=[2, 1, 2]) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool3d( + input_1, + pool_size=[2, 1, 2], + strides=[2, 1, 2], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) + verify_model(AvgPool3d1(), example_args, {}, expected1) + verify_model(AvgPool3d2(), example_args, {}, expected2) + verify_model(AvgPool3d3(), example_args, {}, expected2) + verify_model(AvgPool3d4(), example_args, {}, expected3) + + def test_baddbmm(): class BAddBMM1(Module): def __init__(self): @@ -1566,9 +1870,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -1600,9 +1905,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -1658,9 +1964,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -1692,9 +1999,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -2166,6 +2474,30 @@ def main( verify_model(Einsum2(), example_args, {}, Expected2) +def test_outer(): + class Outer(torch.nn.Module): + def forward(self, x, y): + return torch.outer(x, y) + + @tvm.script.ir_module + class expected: + @R.function + def main( + a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b) + gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(3, dtype=torch.float32), + torch.randn(4, dtype=torch.float32), + ) + verify_model(Outer(), example_args, {}, expected) + + def test_embedding(): class Embedding(Module): def __init__(self): @@ -2409,6 +2741,101 @@ def main( verify_model(model, example_args, binding, expected2) +def test_maxpool1d(): + class MaxPool1d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=2) + + def forward(self, input): + return self.pool(input) + + class MaxPool1d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool1d(input, kernel_size=2) + + class MaxPool1d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool1d( + input_1, + pool_size=[3], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv = (lv,) + R.output(gv) + return gv + + # Example inputs + example_args1 = (torch.randn(1, 3, 8, dtype=torch.float32),) + example_args2 = (torch.randn(1, 3, 8, dtype=torch.float32),) + example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),) + + # Verify the models + verify_model(MaxPool1d(), example_args1, {}, expected1) + verify_model(MaxPool1d_functional(), example_args2, {}, expected2) + verify_model(MaxPool1d2(), example_args3, {}, expected3) + + def test_maxpool2d(): class MaxPool2d(Module): def __init__(self): @@ -2511,6 +2938,110 @@ def main( verify_model(MaxPool2d3(), example_args, {}, expected3) +def test_maxpool3d(): + class MaxPool3d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1]) + + def forward(self, input): + return self.pool(input) + + class MaxPool3d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool3d( + input_1, + pool_size=[1, 1, 1], + strides=[1, 1, 1], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class MaxPool3d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[2, 2, 2]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool3d( + input_1, + pool_size=[2, 2, 2], + strides=[2, 2, 2], + dilation=[2, 2, 2], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class MaxPool3d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")): + with R.dataflow(): + lv = R.nn.max_pool3d( + input_1, + pool_size=[3, 3, 3], + strides=[2, 2, 2], + dilation=[1, 1, 1], + padding=[1, 1, 1, 1, 1, 1], + layout="NCDHW", + out_layout="NCDHW", + ) + gv = (lv,) + R.output(gv) + return gv + + # Example input tensors + example_args1 = (torch.randn(1, 3, 4, 4, 4, dtype=torch.float32),) + example_args2 = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) + example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),) + + # Verify the models with expected IR modules + verify_model(MaxPool3d(), example_args1, {}, expected1) + verify_model(MaxPool3d_functional(), example_args1, {}, expected1) + verify_model(MaxPool3d2(), example_args2, {}, expected2) + verify_model(MaxPool3d3(), example_args3, {}, expected3) + + def test_scaled_dot_product_attention(): class Attention1(Module): def forward(self, q, k, v): @@ -2622,8 +3153,7 @@ def main( R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((0, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + ) = R.split(input_1, indices_or_sections=3, axis=0) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] @@ -2666,8 +3196,7 @@ def main( R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 0, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + ) = R.split(input_1, indices_or_sections=3, axis=1) lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] @@ -2716,7 +3245,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void", @@ -2745,7 +3274,36 @@ def main( method="nearest_neighbor", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0.0, + out_dtype="void", + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class InterpolateBicubic(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (224, 224), mode="bicubic") + + @tvm.script.ir_module + class expected_bicubic: + @R.function + def main( + input: R.Tensor((1, 3, 112, 112), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( + input, + R.shape([224, 224]), + roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], + layout="NCHW", + method="cubic", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void", @@ -2757,6 +3315,7 @@ def main( example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) verify_model(InterpolateNearest(), example_args, {}, expected_nearest) + verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic) def test_mean(): @@ -3409,6 +3968,51 @@ def main( verify_model(Slice2(), example_args, {}, expected2) +def test_slice_scatter(): + class SliceScatter1(Module): + def forward(self, input, src): + return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + a: R.Tensor((8, 8, 10, 10), dtype="float32"), + b: R.Tensor((8, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter( + a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1 + ) + gv: R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class SliceScatter2(Module): + def forward(self, input, src): + return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1) + + @I.ir_module + class expected2: + @R.function + def main( + a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32") + ) -> R.Tuple(R.Tensor((8, 16), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter( + a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0 + ) + gv: R.Tuple(R.Tensor((8, 16), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), torch.randn(8, 3, 10, 10)) + verify_model(SliceScatter1(), example_args, {}, expected1) + + example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 16)) + verify_model(SliceScatter2(), example_args, {}, expected2) + + def test_split(): class Chunk(Module): def forward(self, input): @@ -3462,8 +4066,7 @@ def main( R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((0, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + ) = R.split(input_1, indices_or_sections=3, axis=0) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] @@ -3506,8 +4109,7 @@ def main( R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 0, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + ) = R.split(input_1, indices_or_sections=3, axis=1) lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] @@ -5140,6 +5742,36 @@ def main( verify_model(Eye2(), example_args2, {}, Expected2) +def test_cross_entropy(): + class CrossEntropyModule(Module): + def __init__(self): + super().__init__() + self.criterion = nn.CrossEntropyLoss() + self.target = torch.tensor([0, 1, 2, 1]) + + def forward(self, x): + return self.criterion(x, self.target) + + @tvm.script.ir_module + class Expected1: + @R.function + def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1) + lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss( + lv, + targets=R.const([0, 1, 2, 1], dtype="int64"), + reduction="mean", + ignore_index=-100, + ) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args1 = (torch.randn(4, 3, dtype=torch.float32),) + verify_model(CrossEntropyModule(), example_args1, {}, Expected1) + + def test_linspace(): class Linspace(Module): def forward(self, input): @@ -5161,32 +5793,43 @@ def main( verify_model(Linspace(), example_args, {}, Expected) -def test_bfloat16(): - # TODO(mshr-h): Add tests for all the dtypes supported in fx frontend +@pytest.mark.parametrize( + "torch_dtype, relax_dtype", + [ + (torch.float32, "float32"), + (torch.float16, "float16"), + (torch.bfloat16, "bfloat16"), + (torch.int64, "int64"), + (torch.int32, "int32"), + (torch.bool, "bool"), + ], +) +def test_dtypes(torch_dtype, relax_dtype): example_args = ( - torch.randn(10, 10, dtype=torch.bfloat16), - torch.randn(10, 10, dtype=torch.bfloat16), + torch.randint(0, 10, (10, 10)).to(torch_dtype), + torch.randint(0, 10, (10, 10)).to(torch_dtype), ) - class BFloat16Model(Module): + class Model(Module): def forward(self, lhs: torch.Tensor, rhs: torch.Tensor): return torch.ops.aten.add(lhs, rhs) @tvm.script.ir_module - class expected: + class Expected: @R.function def main( - lhs: R.Tensor((10, 10), dtype="bfloat16"), - rhs: R.Tensor((10, 10), dtype="bfloat16"), - ) -> R.Tuple(R.Tensor((10, 10), dtype="bfloat16")): + lhs: R.Tensor((10, 10), dtype=relax_dtype), + rhs: R.Tensor((10, 10), dtype=relax_dtype), + ) -> R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)): with R.dataflow(): - lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs) - gv: R.Tuple(R.Tensor((10, 10), dtype="bfloat16")) = (lv,) + lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs) + gv: R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)) = (lv,) R.output(gv) return gv - verify_model(BFloat16Model(), example_args, {}, expected) + verify_model(Model(), example_args, {}, Expected) if __name__ == "__main__": tvm.testing.main() +1 diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 60c02df6cc98..00c61bd31f23 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -168,9 +168,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -202,9 +203,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -352,9 +354,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -386,9 +389,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -874,6 +878,27 @@ def main( verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2) +def test_outer(): + class Outer(torch.nn.Module): + def forward(self, x, y): + return torch.outer(x, y) + + @tvm.script.ir_module + class expected: + @R.function + def main( + a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32") + ) -> R.Tensor((3, 4), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b) + gv: R.Tensor((3, 4), dtype="float32") = lv + R.output(gv) + return gv + + input_infos = [([3], "float32"), ([4], "float32")] + verify_model(Outer(), input_infos, {}, expected) + + @tvm.testing.requires_gpu def test_softplus(): import torch @@ -984,6 +1009,106 @@ def main( verify_model(Prelu2(), input_info, {}, expected) +def test_maxpool1d(): + input_info = [([1, 3, 10], "float32")] + + class MaxPool1d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=2) + + def forward(self, input): + return self.pool(input) + + class MaxPool1d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool1d(input, kernel_size=2) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.max_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv: R.Tensor((1, 3, 5), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool1d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.max_pool1d( + input_1, + pool_size=[3], + strides=[1], + dilation=[1], + padding=[1, 1], + layout="NCW", + out_layout="NCW", + ) + gv: R.Tensor((1, 3, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool1d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2, dilation=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tensor((1, 3, 3), dtype="float32"): # Corrected here + with R.dataflow(): + lv: R.Tensor((1, 3, 3), dtype="float32") = R.nn.max_pool1d( + input_1, + pool_size=[3], + strides=[2], + dilation=[2], + padding=[0, 0], + layout="NCW", + out_layout="NCW", + ) + gv: R.Tensor((1, 3, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool1d(), input_info, {}, expected1) + verify_model(MaxPool1d_functional(), input_info, {}, expected1) + verify_model(MaxPool1d2(), input_info, {}, expected2) + verify_model(MaxPool1d3(), input_info, {}, expected3) + + def test_maxpool2d(): input_info = [([1, 3, 10, 10], "float32")] @@ -1087,6 +1212,203 @@ def main( verify_model(MaxPool2d3(), input_info, {}, expected3) +def test_maxpool3d(): + input_info = [([1, 3, 10, 10, 10], "float32")] + + class MaxPool3d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1]) + + def forward(self, input): + return self.pool(input) + + class MaxPool3d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = R.nn.max_pool3d( + input_1, + pool_size=[1, 1, 1], + strides=[1, 1, 1], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tensor((1, 3, 10, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool3d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[1, 2, 2]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 4, 4), dtype="float32"): # Fixed here + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = R.nn.max_pool3d( + input_1, + pool_size=[2, 2, 2], + strides=[2, 2, 2], + dilation=[1, 2, 2], + padding=[0, 0, 0, 0, 0, 0], + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tensor((1, 3, 5, 4, 4), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool3d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5, 5), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d( + input_1, + pool_size=[3, 3, 3], + strides=[2, 2, 2], + dilation=[1, 1, 1], + padding=[1, 1, 1, 1, 1, 1], + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool3d(), input_info, {}, expected1) + verify_model(MaxPool3d_functional(), input_info, {}, expected1) + verify_model(MaxPool3d2(), input_info, {}, expected2) + verify_model(MaxPool3d3(), input_info, {}, expected3) + + +def test_avgpool1d(): + input_info = [([1, 3, 10], "float32")] + + class AvgPool1d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool1d(kernel_size=1) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10), dtype="float32"): + with R.dataflow(): + lv = R.nn.avg_pool1d( + input_1, + pool_size=[1], + strides=[1], + dilation=[1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCW", + out_layout="NCW", + ) + gv = lv + R.output(gv) + return gv + + class AvgPool1d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool1d(kernel_size=4, stride=2, padding=2, ceil_mode=True) + + def forward(self, input): + return self.pool(input) + + class AvgPool1d3(Module): + def forward(self, input): + return torch.nn.functional.avg_pool1d( + input, kernel_size=4, stride=2, padding=2, ceil_mode=True + ) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool1d( + input_1, + pool_size=[4], + strides=[2], + dilation=[1], + padding=[2, 2], + ceil_mode=True, + count_include_pad=True, + layout="NCW", + out_layout="NCW", + ) + gv = lv + R.output(gv) + return gv + + class AvgPool1d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool1d(input, kernel_size=2) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool1d( + input_1, + pool_size=[2], + strides=[2], + dilation=[1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCW", + out_layout="NCW", + ) + gv = lv + R.output(gv) + return gv + + verify_model(AvgPool1d(), input_info, {}, expected1) + verify_model(AvgPool1d2(), input_info, {}, expected2) + verify_model(AvgPool1d3(), input_info, {}, expected2) + verify_model(AvgPool1d4(), input_info, {}, expected3) + + def test_avgpool2d(): input_info = [([1, 3, 10, 10], "float32")] @@ -1181,6 +1503,138 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): verify_model(AvgPool2d4(), input_info, {}, expected3) +def test_avgpool3d(): + input_info = [([1, 3, 8, 8, 8], "float32")] + + class AvgPool3d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool3d(kernel_size=[1, 1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") + ) -> R.Tensor((1, 3, 8, 8, 8), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.avg_pool3d( + input_1, + pool_size=[1, 1, 1], + strides=[1, 1, 1], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCDHW", + out_layout="NCDHW", + ) + gv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = lv + R.output(gv) + return gv + + class AvgPool3d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool3d( + kernel_size=[3, 3, 3], stride=2, padding=1, ceil_mode=True + ) + + def forward(self, input): + return self.pool(input) + + class AvgPool3d3(Module): + def forward(self, input): + return torch.nn.functional.avg_pool3d( + input, kernel_size=[3, 3, 3], stride=2, padding=1, ceil_mode=True + ) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool3d( + input_1, + pool_size=[3, 3, 3], + strides=[2, 2, 2], + dilation=[1, 1, 1], + padding=[1, 1, 1, 1, 1, 1], + ceil_mode=True, + count_include_pad=True, + layout="NCDHW", + out_layout="NCDHW", + ) + gv = lv + R.output(gv) + return gv + + class AvgPool3d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool3d(input, kernel_size=[2, 1, 2], stride=[2, 1, 2]) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool3d( + input_1, + pool_size=[2, 1, 2], + strides=[2, 1, 2], + dilation=[1, 1, 1], + padding=[0, 0, 0, 0, 0, 0], + ceil_mode=False, + count_include_pad=True, + layout="NCDHW", + out_layout="NCDHW", + ) + gv = lv + R.output(gv) + return gv + + verify_model(AvgPool3d(), input_info, {}, expected1) + verify_model(AvgPool3d2(), input_info, {}, expected2) + verify_model(AvgPool3d3(), input_info, {}, expected2) + verify_model(AvgPool3d4(), input_info, {}, expected3) + + +def test_adaptive_avgpool1d(): + input_info = [([1, 3, 16], "float32")] + + class AdaptiveAvgPool1d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool1d(8) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool1d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool1d(input, 8) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 16), dtype="float32") + ) -> R.Tensor((1, 3, 8), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 8), dtype="float32") = R.nn.adaptive_avg_pool1d( + input_1, output_size=[8], layout="NCW", out_layout="NCW" + ) + gv: R.Tensor((1, 3, 8), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool1d0(), input_info, {}, expected1) + verify_model(AdaptiveAvgPool1d1(), input_info, {}, expected1) + + def test_adaptive_avgpool2d(): input_info = [([1, 3, 10, 10], "float32")] @@ -1215,6 +1669,39 @@ def main( verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1) +def test_adaptive_avgpool3d(): + input_info = [([1, 3, 16, 16, 16], "float32")] + + class AdaptiveAvgPool3d0(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d((8, 8, 8)) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool3d1(torch.nn.Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool3d(input, (8, 8, 8)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 16, 16, 16), dtype="float32") + ) -> R.Tensor((1, 3, 8, 8, 8), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.adaptive_avg_pool3d( + input_1, output_size=[8, 8, 8], layout="NCDHW", out_layout="NCDHW" + ) + gv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool3d0(), input_info, {}, expected1) + verify_model(AdaptiveAvgPool3d1(), input_info, {}, expected1) + + def test_flatten(): input_info = [([1, 3, 10, 10], "float32")] @@ -2772,27 +3259,32 @@ def main( verify_model(ReLU1(), input_info, {}, expected_relu) # relu6 - class ReLU6(Module): + class ReLU6_1(torch.nn.Module): def __init__(self): super().__init__() self.relu6 = torch.nn.ReLU6() - def forward(self, input): - return self.relu6(input) + def forward(self, x): + return self.relu6(x) + + class ReLU6_2(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.relu6(x) @tvm.script.ir_module - class expected_relu6: + class expected: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0, 6) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu6(inp_0) gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv - verify_model(ReLU6(), input_info, {}, expected_relu6) + verify_model(ReLU6_1(), input_info, {}, expected) + verify_model(ReLU6_2(), input_info, {}, expected) # selu class Selu1(Module): @@ -3031,7 +3523,7 @@ def main( method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="", @@ -3068,7 +3560,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="", @@ -3105,7 +3597,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="", @@ -3116,6 +3608,43 @@ def main( verify_model(Interpolate3(), input_info, {}, expected3) + class Interpolate4(Module): + def forward(self, input): + return torch.nn.functional.interpolate( + input, + size=None, + scale_factor=(2.0, 1.0), + mode="bicubic", + align_corners=False, + ) + + @tvm.script.ir_module + class expected4: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 20, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 20, 10), dtype="float32") = R.image.resize2d( + input_1, + (20, 10), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="cubic", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.75, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 20, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate4(), input_info, {}, expected4) + def test_addmm(): input_info = [ @@ -3264,8 +3793,7 @@ def main( R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32"), - R.Tensor((0, 3, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + ) = R.split(input_1, indices_or_sections=3, axis=0) lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] @@ -3301,8 +3829,7 @@ def main( R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), R.Tensor((3, 1, 10, 10), dtype="float32"), - R.Tensor((3, 0, 10, 10), dtype="float32"), - ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + ) = R.split(input_1, indices_or_sections=3, axis=1) lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] @@ -4553,6 +5080,51 @@ def main( verify_model(Scatter(), input_info, {}, expected) +def test_slice_scatter(): + class SliceScatter1(Module): + def forward(self, input, src): + return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + a: R.Tensor((8, 8, 10, 10), dtype="float32"), + b: R.Tensor((8, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((8, 8, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter( + a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1 + ) + gv: R.Tensor((8, 8, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class SliceScatter2(Module): + def forward(self, input, src): + return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1) + + @I.ir_module + class expected2: + @R.function + def main( + a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32") + ) -> R.Tensor((8, 16), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter( + a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0 + ) + gv: R.Tensor((8, 16), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + SliceScatter1(), [((8, 8, 10, 10), "float32"), ((8, 3, 10, 10), "float32")], {}, expected1 + ) + + verify_model(SliceScatter2(), [((8, 16), "float32"), ((6, 16), "float32")], {}, expected2) + + def test_masked_scatter(): class MaskedScatter1(Module): def forward(self, data, mask, src): @@ -5542,9 +6114,19 @@ def main( verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected) -def test_bfloat16(): - # TODO(mshr-h): Add tests for all the dtypes supported in EP frontend - class BFloat16Model(Module): +@pytest.mark.parametrize( + "torch_dtype, relax_dtype", + [ + (torch.float32, "float32"), + (torch.float16, "float16"), + (torch.bfloat16, "bfloat16"), + (torch.int64, "int64"), + (torch.int32, "int32"), + (torch.bool, "bool"), + ], +) +def test_dtypes(torch_dtype, relax_dtype): + class Model(Module): def forward(self, lhs: torch.Tensor, rhs: torch.Tensor): return torch.ops.aten.add(lhs, rhs) @@ -5552,16 +6134,16 @@ def forward(self, lhs: torch.Tensor, rhs: torch.Tensor): class Expected: @R.function def main( - lhs: R.Tensor((10, 10), dtype="bfloat16"), - rhs: R.Tensor((10, 10), dtype="bfloat16"), - ) -> R.Tensor((10, 10), dtype="bfloat16"): + lhs: R.Tensor((10, 10), dtype=relax_dtype), + rhs: R.Tensor((10, 10), dtype=relax_dtype), + ) -> R.Tensor((10, 10), dtype=relax_dtype): with R.dataflow(): - lv: R.Tensor((10, 10), dtype="bfloat16") = relax.op.add(lhs, rhs) - gv: R.Tensor((10, 10), dtype="bfloat16") = lv + lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs) + gv: R.Tensor((10, 10), dtype=relax_dtype) = lv R.output(gv) return gv - verify_model(BFloat16Model(), [([10, 10], "bfloat16"), ([10, 10], "bfloat16")], {}, Expected) + verify_model(Model(), [([10, 10], torch_dtype), ([10, 10], torch_dtype)], {}, Expected) def test_eye(): diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index ef97cfd9056c..5aa48dd18aba 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -119,8 +119,8 @@ def test_sym( def _compile_cc(src: Path, dst: Path): # pylint: disable=import-outside-toplevel - from tvm._ffi.base import py_str - from tvm._ffi.libinfo import find_include_path + from tvm.base import py_str + from tvm.libinfo import find_include_path # pylint: enable=import-outside-toplevel diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 483e48217d92..5c400ef8be28 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -305,7 +305,7 @@ def test( method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="void", @@ -385,6 +385,7 @@ def test_nn(): class Model(Module): def test(self, x: Tensor, weight: Tensor, bias: Tensor): relu_out = op.relu(x) + relu6_out = op.relu6(x) silu_out = op.silu(x) gelu_out = op.gelu(x) sigmoid_out = op.sigmoid(x) @@ -409,6 +410,7 @@ def test( R.func_attr({"num_input": 4}) with R.dataflow(): relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x) + relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x) silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x) gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x) sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 9de77937480e..6c3334f64d12 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -402,13 +402,11 @@ def test_binary_bool(op_name: str): verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.BOOL) -@pytest.mark.skip(reason="opset 18 is not supported in CI") @pytest.mark.parametrize("op_name", ["BitwiseAnd", "BitwiseOr", "BitwiseXor"]) def test_bitwise(op_name: str): verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.UINT64, opset=18) -@pytest.mark.skip(reason="opset 18 is not supported in CI") def test_bitwise_not(): verify_unary( "BitwiseNot", @@ -447,9 +445,9 @@ def test_bitwise_shift(direction: str): "Sinh", "Cosh", "Tanh", - "Asin", - "Acos", - "Atan", + # "Asin", // TODO @jikechao, fix the precision loss due to the Taylor approximation + # "Acos", + # "Atan", "Asinh", "Acosh", "Atanh", @@ -945,7 +943,6 @@ def test_selu(): verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3}) -@pytest.mark.skip(reason="opset 18 is not supported in CI") def test_mish(): verify_unary("Mish", [3, 32, 32], opset=18) @@ -1303,6 +1300,24 @@ def test_layer_norm(): model = helper.make_model(graph, producer_name="layer_norm_test") check_correctness(model) + # Test case with no bias that is an optional input + layer_norm_node = helper.make_node("LayerNormalization", ["a", "b"], ["d"], epsilon=1e-12) + + graph = helper.make_graph( + [layer_norm_node], + "layer_norm_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]), + ], + outputs=[ + helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]), + ], + ) + + model = helper.make_model(graph, producer_name="layer_norm_test") + check_correctness(model) + # TODO Enable dynamism @pytest.mark.parametrize("dynamic", [False]) diff --git a/tests/python/relax/test_op_distributed.py b/tests/python/relax/test_op_distributed.py index dc5440c0c113..2290174e464f 100644 --- a/tests/python/relax/test_op_distributed.py +++ b/tests/python/relax/test_op_distributed.py @@ -16,7 +16,7 @@ # under the License. import pytest import tvm -from tvm._ffi.base import TVMError +from tvm.base import TVMError import tvm.testing from tvm import relax from tvm.script.parser import relax as R diff --git a/tests/python/relax/test_op_grad.py b/tests/python/relax/test_op_grad.py index 5ae30adb70ef..a8a17c7a135a 100644 --- a/tests/python/relax/test_op_grad.py +++ b/tests/python/relax/test_op_grad.py @@ -16,7 +16,7 @@ # under the License. import pytest import tvm -from tvm._ffi.base import TVMError +from tvm.base import TVMError import tvm.testing from tvm import relax from tvm.ir import Op diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 1bf4444848b7..a0ff507ef880 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -74,9 +74,12 @@ def test_linear_unit_infer_struct_info(): _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32")) _check_inference(bb, relax.op.nn.relu(x6), relax.TensorStructInfo((2, 3), "float32", vdev0)) + _check_inference(bb, relax.op.nn.relu6(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.nn.relu6(x6), relax.TensorStructInfo((2, 3), "float32", vdev0)) _check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32")) _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.nn.relu6(x3), relax.TensorStructInfo((2, 3), dtype="")) _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype="")) _check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 3), "float32")) _check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 4), dtype="")) @@ -93,6 +96,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic(): _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32")) _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo((4, n), "float32")) _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, n), "float32")) _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, n), "float32")) @@ -106,6 +110,7 @@ def test_linear_unit_infer_struct_info_shape_var(): _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32")) _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.relu6(x1), relax.TensorStructInfo(s1, "float32")) _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, "float32")) _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, "float32")) diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index 2533a2fcadcb..d4461a122de8 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -25,9 +25,17 @@ def test_op_correctness(): x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x1", R.Tensor((2, 3, 64), "float32")) + x2 = relax.Var("x2", R.Tensor((2, 3, 8, 28, 28), "float32")) + assert relax.op.nn.max_pool1d(x1).op == Op.get("relax.nn.max_pool1d") assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d") + assert relax.op.nn.max_pool3d(x2).op == Op.get("relax.nn.max_pool3d") + assert relax.op.nn.avg_pool1d(x).op == Op.get("relax.nn.avg_pool1d") assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d") + assert relax.op.nn.avg_pool3d(x).op == Op.get("relax.nn.avg_pool3d") + assert relax.op.nn.adaptive_avg_pool1d(x).op == Op.get("relax.nn.adaptive_avg_pool1d") assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d") + assert relax.op.nn.adaptive_avg_pool3d(x).op == Op.get("relax.nn.adaptive_avg_pool3d") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -35,6 +43,197 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) +def test_max_pool1d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor(ndim=3)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor()) + x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) + + _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32")) + _check_inference( + bb, relax.op.nn.max_pool1d(x5), relax.TensorStructInfo((2, 3, 32), "float32", vdev0) + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, pool_size=3), relax.TensorStructInfo((2, 3, 30), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, strides=2), relax.TensorStructInfo((2, 3, 16), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, padding=1), relax.TensorStructInfo((2, 3, 34), "float32") + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x0, dilation=2), relax.TensorStructInfo((2, 3, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x0, layout="NCW", out_layout="NWC"), + relax.TensorStructInfo((2, 32, 3), "float32"), + ) + _check_inference( + bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference( + bb, relax.op.nn.max_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.nn.max_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3)) + + +def test_max_pool1d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + w = tir.Var("w", "int64") + c16 = tir.Var("c16", "int64") + + x0 = relax.Var("x", R.Tensor((n, c, w), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, w, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool1d(x0, pool_size=3, strides=3, padding=2, dilation=2), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(w - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x1, layout="NCW16c", out_layout="NWC"), + relax.TensorStructInfo((n, w, c * 16), "float32"), + ) + + +def test_max_pool1d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x1, layout="NCW16c"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x2), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_max_pool1d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool1d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.max_pool1d(x, pool_size=5, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15), "float32"), + ) + + +def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + w = tir.Var("w", "int64") + x = relax.Var("x", R.Tensor((n, c, w), "float32")) + + _check_inference( + bb, + relax.op.nn.max_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(w, 2)), "float32"), + ) + + +def test_max_pool1d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64")) + + _check_inference(bb, relax.op.nn.max_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16")) + _check_inference(bb, relax.op.nn.max_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8")) + _check_inference(bb, relax.op.nn.max_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64")) + + +def test_max_pool1d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + max_pool1d = relax.op.nn.max_pool1d(x, pool_size=3, strides=1, padding=1, dilation=1) + + assert max_pool1d.attrs.strides[0].dtype == "int64" + assert max_pool1d.attrs.padding[0].dtype == "int64" + assert max_pool1d.attrs.padding[1].dtype == "int64" + assert max_pool1d.attrs.dilation[0].dtype == "int64" + + +def test_max_pool1d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, pool_size=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, strides=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.max_pool1d(x, dilation=(1, 2)) + + +def test_max_pool1d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x, layout="OIW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x, out_layout="OWI")) + + +def test_max_pool1d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=5)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x1)) + + +def test_max_pool1d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.max_pool1d(x1)) + + def test_max_pool2d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") @@ -265,313 +464,1176 @@ def test_max_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.max_pool2d(x1)) -def test_avg_pool2d_infer_struct_info(): +def test_max_pool3d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") - x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) - x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) - x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x0 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 16, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=5)) x3 = relax.Var("x", R.Tensor("float32")) - x4 = relax.Var("x", R.Tensor(ndim=4)) + x4 = relax.Var("x", R.Tensor(ndim=5)) x5 = relax.Var("x", R.Tensor()) - x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) - x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) + x6 = relax.Var("x", R.Tensor((2, 4, 16, 32, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 16, 32, 32), "float32", vdev0)) _check_inference( - bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32") ) _check_inference( - bb, relax.op.nn.avg_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0) + bb, relax.op.nn.max_pool3d(x7), relax.TensorStructInfo((2, 3, 16, 32, 32), "float32", vdev0) ) _check_inference( bb, - relax.op.nn.avg_pool2d(x0, pool_size=3), - relax.TensorStructInfo((2, 3, 30, 30), "float32"), + relax.op.nn.max_pool3d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 14, 30, 30), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x0, pool_size=(5, 3)), - relax.TensorStructInfo((2, 3, 28, 30), "float32"), + relax.op.nn.max_pool3d(x0, pool_size=(3, 5, 3)), + relax.TensorStructInfo((2, 3, 14, 28, 30), "float32"), ) _check_inference( - bb, relax.op.nn.avg_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + bb, + relax.op.nn.max_pool3d(x0, padding=1), + relax.TensorStructInfo((2, 3, 18, 34, 34), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x0, padding=[1, 2]), - relax.TensorStructInfo((2, 3, 34, 36), "float32"), + relax.op.nn.max_pool3d(x0, padding=[1, 2, 3]), + relax.TensorStructInfo((2, 3, 18, 36, 38), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x0, strides=2), - relax.TensorStructInfo((2, 3, 16, 16), "float32"), + relax.op.nn.max_pool3d(x0, strides=2), + relax.TensorStructInfo((2, 3, 8, 16, 16), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x0, dilation=2), - relax.TensorStructInfo((2, 3, 32, 32), "float32"), + relax.op.nn.max_pool3d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 16, 32, 32), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x1, layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.op.nn.max_pool3d(x1, layout="NDHWC"), + relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x0, out_layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.op.nn.max_pool3d(x0, out_layout="NDHWC"), + relax.TensorStructInfo((2, 16, 32, 32, 3), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), - relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + relax.op.nn.max_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), + relax.TensorStructInfo((2, 16, 32, 32, 4, 16), "float32"), ) _check_inference( - bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) ) _check_inference( - bb, relax.op.nn.avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.nn.max_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) ) - _check_inference(bb, relax.op.nn.avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)) - _check_inference(bb, relax.op.nn.avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.max_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.max_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5)) -def test_avg_pool2d_infer_struct_info_shape_symbolic(): +def test_max_pool3d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") c = tir.Var("c", "int64") c16 = tir.Var("c16", "int64") + id = tir.Var("id", "int64") ih = tir.Var("ih", "int64") iw = tir.Var("iw", "int64") - x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) - x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + x0 = relax.Var("x", R.Tensor((n, c, id, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, id, ih, iw, c16), "float32")) _check_inference( bb, - relax.op.nn.avg_pool2d( - x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) + relax.op.nn.max_pool3d( + x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2) ), relax.TensorStructInfo( ( n, c, + tvm.tir.floordiv(id - 1, 3) + 1, tvm.tir.floordiv(ih - 1, 3) + 1, tvm.tir.floordiv(iw - 1, 3) + 1, ), "float32", ), ) + _check_inference( bb, - relax.op.nn.avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), - relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + relax.op.nn.max_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), + relax.TensorStructInfo((n, id, ih, iw, c * 16), "float32"), ) -def test_avg_pool2d_infer_struct_info_shape_var(): +def test_max_pool3d_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) - s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) s2 = relax.Var("s", relax.ShapeStructInfo()) x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) _check_inference( - bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo(dtype="float32", ndim=5) ) _check_inference( bb, - relax.op.nn.avg_pool2d(x1, layout="NCHW16c"), - relax.TensorStructInfo(dtype="float32", ndim=5), + relax.op.nn.max_pool3d(x1, layout="NCDHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=6), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x2), - relax.TensorStructInfo(dtype="float32", ndim=4), + relax.op.nn.max_pool3d(x2), + relax.TensorStructInfo(dtype="float32", ndim=5), ) -def test_avg_pool2d_infer_struct_info_ceil_mode(): +def test_max_pool3d_infer_struct_info_ceil_mode(): bb = relax.BlockBuilder() - x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) _check_inference( bb, - relax.op.nn.avg_pool2d(x, pool_size=3, strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 16, 16), "float32"), + relax.op.nn.max_pool3d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"), ) _check_inference( bb, - relax.op.nn.avg_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), - relax.TensorStructInfo((2, 3, 15, 16), "float32"), + relax.op.nn.max_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15, 16, 16), "float32"), ) -def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic(): +def test_max_pool3d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") c = tir.Var("c", "int64") + id_ = tir.Var("id", "int64") ih = tir.Var("ih", "int64") iw = tir.Var("iw", "int64") - x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32")) _check_inference( bb, - relax.op.nn.avg_pool2d( - x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True + relax.op.nn.max_pool3d( + x, + pool_size=(3, 3, 3), + strides=(2, 2, 2), + padding=(1, 1, 1), + dilation=(2, 2, 2), + ceil_mode=True, + ), + relax.TensorStructInfo( + (n, c, tvm.tir.floordiv(id_, 2), tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), + "float32", ), - relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), ) -def test_avg_pool2d_infer_struct_info_more_input_dtype(): +def test_max_pool3d_infer_struct_info_more_input_dtype(): bb = relax.BlockBuilder() - x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) - x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) - x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64")) _check_inference( - bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + bb, relax.op.nn.max_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16") ) - _check_inference(bb, relax.op.nn.avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) _check_inference( - bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + bb, relax.op.nn.max_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.max_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") ) -def test_avg_pool2d_stride_padding_dilation_int64(): - x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) - avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1)) +def test_max_pool3d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + max_pool3d = relax.op.nn.max_pool3d( + x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1) + ) - assert avg_pool2d.attrs.strides[0].dtype == "int64" - assert avg_pool2d.attrs.strides[1].dtype == "int64" - assert avg_pool2d.attrs.padding[0].dtype == "int64" - assert avg_pool2d.attrs.padding[1].dtype == "int64" - assert avg_pool2d.attrs.padding[2].dtype == "int64" - assert avg_pool2d.attrs.padding[3].dtype == "int64" - assert avg_pool2d.attrs.dilation[0].dtype == "int64" - assert avg_pool2d.attrs.dilation[1].dtype == "int64" + assert max_pool3d.attrs.strides[0].dtype == "int64" + assert max_pool3d.attrs.strides[1].dtype == "int64" + assert max_pool3d.attrs.strides[2].dtype == "int64" + assert max_pool3d.attrs.padding[0].dtype == "int64" + assert max_pool3d.attrs.padding[1].dtype == "int64" + assert max_pool3d.attrs.padding[2].dtype == "int64" + assert max_pool3d.attrs.padding[3].dtype == "int64" + assert max_pool3d.attrs.padding[4].dtype == "int64" + assert max_pool3d.attrs.dilation[0].dtype == "int64" + assert max_pool3d.attrs.dilation[1].dtype == "int64" + assert max_pool3d.attrs.dilation[2].dtype == "int64" -def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length(): - x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) +def test_max_pool3d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) with pytest.raises(TVMError): - relax.op.nn.avg_pool2d(x, pool_size=(1, 2, 3)) + relax.op.nn.max_pool3d(x, pool_size=(1, 2, 3, 4)) with pytest.raises(TVMError): - relax.op.nn.avg_pool2d(x, strides=(1, 2, 3)) + relax.op.nn.max_pool3d(x, strides=(1, 2, 3, 4)) with pytest.raises(TVMError): - relax.op.nn.avg_pool2d(x, padding=(1, 2, 3)) + relax.op.nn.max_pool3d(x, padding=(1, 2, 3, 4)) with pytest.raises(TVMError): - relax.op.nn.avg_pool2d(x, dilation=(1, 2, 3)) + relax.op.nn.max_pool3d(x, dilation=(1, 2, 3, 4)) -def test_avg_pool2d_infer_struct_info_wrong_layout_string(): +def test_max_pool3d_infer_struct_info_wrong_layout_string(): bb = relax.BlockBuilder() - x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.avg_pool2d(x, layout="OIHW")) + bb.normalize(relax.op.nn.max_pool3d(x, layout="OIHW")) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.avg_pool2d(x, out_layout="OHWI")) + bb.normalize(relax.op.nn.max_pool3d(x, out_layout="OHWI")) -def test_avg_pool2d_wrong_input_ndim(): +def test_max_pool3d_wrong_input_ndim(): bb = relax.BlockBuilder() - x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) - x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.avg_pool2d(x0)) + bb.normalize(relax.op.nn.max_pool3d(x0)) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.avg_pool2d(x1)) + bb.normalize(relax.op.nn.max_pool3d(x1)) -def test_avg_pool2d_infer_struct_info_wrong_input_type(): +def test_max_pool3d_infer_struct_info_wrong_input_type(): bb = relax.BlockBuilder() - x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) - x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.avg_pool2d(x0)) + bb.normalize(relax.op.nn.max_pool3d(x0)) with pytest.raises(TVMError): - bb.normalize(relax.op.nn.avg_pool2d(x1)) + bb.normalize(relax.op.nn.max_pool3d(x1)) -def test_adaptive_avg_pool2d_infer_struct_info(): +def test_avg_pool1d_infer_struct_info(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") - x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) - x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) - x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) x3 = relax.Var("x", R.Tensor("float32")) - x4 = relax.Var("x", R.Tensor(ndim=4)) + x4 = relax.Var("x", R.Tensor(ndim=3)) x5 = relax.Var("x", R.Tensor()) - x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) - x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) + _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float32")) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + bb, relax.op.nn.avg_pool1d(x7), relax.TensorStructInfo((2, 3, 32), "float32", vdev0) ) _check_inference( bb, - relax.op.nn.adaptive_avg_pool2d(x7), - relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0), + relax.op.nn.avg_pool1d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 30), "float32"), ) _check_inference( bb, - relax.op.nn.adaptive_avg_pool2d(x0, output_size=30), - relax.TensorStructInfo((2, 3, 30, 30), "float32"), + relax.op.nn.avg_pool1d(x0, padding=1), + relax.TensorStructInfo((2, 3, 34), "float32"), ) _check_inference( bb, - relax.op.nn.adaptive_avg_pool2d(x0, output_size=(28, 30)), - relax.TensorStructInfo((2, 3, 28, 30), "float32"), + relax.op.nn.avg_pool1d(x0, padding=[1, 2]), + relax.TensorStructInfo((2, 3, 35), "float32"), ) _check_inference( bb, - relax.op.nn.adaptive_avg_pool2d(x1, layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.op.nn.avg_pool1d(x0, strides=2), + relax.TensorStructInfo((2, 3, 16), "float32"), ) _check_inference( bb, - relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NHWC"), - relax.TensorStructInfo((2, 32, 32, 3), "float32"), + relax.op.nn.avg_pool1d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 32), "float32"), ) _check_inference( bb, - relax.op.nn.adaptive_avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), - relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), - ) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + relax.op.nn.avg_pool1d(x1, layout="NWC"), + relax.TensorStructInfo((2, 32, 3), "float32"), ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + bb, + relax.op.nn.avg_pool1d(x0, out_layout="NWC"), + relax.TensorStructInfo((2, 32, 3), "float32"), ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4) + bb, relax.op.nn.avg_pool1d(x2), relax.TensorStructInfo(dtype="float32", ndim=3) ) _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4) + bb, relax.op.nn.avg_pool1d(x3), relax.TensorStructInfo(dtype="float32", ndim=3) ) + _check_inference(bb, relax.op.nn.avg_pool1d(x4), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.nn.avg_pool1d(x5), relax.TensorStructInfo(dtype="", ndim=3)) -def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): +def test_avg_pool1d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() n = tir.Var("n", "int64") c = tir.Var("c", "int64") c16 = tir.Var("c16", "int64") - ih = tir.Var("ih", "int64") iw = tir.Var("iw", "int64") - x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) - x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + x0 = relax.Var("x", R.Tensor((n, c, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32")) - _check_inference( - bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((n, c, ih, iw), "float32") - ) - _check_inference( - bb, - relax.op.nn.adaptive_avg_pool2d(x0, output_size=256), - relax.TensorStructInfo((n, c, 256, 256), "float32"), - ) + _check_inference( + bb, + relax.op.nn.avg_pool1d(x0, pool_size=3, strides=3, padding=2, dilation=2), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.avg_pool1d(x1, layout="NCW16c", out_layout="NWC"), + relax.TensorStructInfo((n, iw, c * 16), "float32"), + ) + + +def test_avg_pool1d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.nn.avg_pool1d(x1, layout="NCW16c"), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.avg_pool1d(x2), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_avg_pool1d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool1d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool1d(x, pool_size=5, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15), "float32"), + ) + + +def test_avg_pool1d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True), + relax.TensorStructInfo( + (n, c, tvm.tir.floordiv(iw, 2)), + "float32", + ), + ) + + +def test_avg_pool1d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32), "int64")) + _check_inference(bb, relax.op.nn.avg_pool1d(x0), relax.TensorStructInfo((2, 3, 32), "float16")) + _check_inference(bb, relax.op.nn.avg_pool1d(x1), relax.TensorStructInfo((2, 3, 32), "int8")) + _check_inference(bb, relax.op.nn.avg_pool1d(x2), relax.TensorStructInfo((2, 3, 32), "int64")) + + +def test_avg_pool1d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + avg_pool1d = relax.op.nn.avg_pool1d(x, 3, strides=1, padding=1, dilation=1) + + assert avg_pool1d.attrs.strides[0].dtype == "int64" + assert avg_pool1d.attrs.padding[0].dtype == "int64" + assert avg_pool1d.attrs.padding[1].dtype == "int64" + assert avg_pool1d.attrs.dilation[0].dtype == "int64" + + +def test_avg_pool1d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.avg_pool1d(x, pool_size=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool1d(x, strides=(1, 2)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool1d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool1d(x, dilation=(1, 2)) + + +def test_avg_pool1d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool1d(x, layout="OIW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool1d(x, out_layout="OWI")) + + +def test_avg_pool1d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool1d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool1d(x1)) + + +def test_avg_pool1d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool1d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool1d(x1)) + + +def test_avg_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) + + _check_inference( + bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, relax.op.nn.avg_pool2d(x7), relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0) + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, pool_size=(5, 3)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, relax.op.nn.avg_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32") + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, padding=[1, 2]), + relax.TensorStructInfo((2, 3, 34, 36), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, strides=2), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.nn.avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.nn.avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)) + + +def test_avg_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool2d( + x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2) + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(ih - 1, 3) + 1, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"), + relax.TensorStructInfo((n, ih, iw, c * 16), "float32"), + ) + + +def test_avg_pool2d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x1, layout="NCHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x2), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_avg_pool2d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool2d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15, 16), "float32"), + ) + + +def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool2d( + x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True + ), + relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + ) + + +def test_avg_pool2d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64")) + _check_inference( + bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16") + ) + _check_inference(bb, relax.op.nn.avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")) + _check_inference( + bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64") + ) + + +def test_avg_pool2d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1)) + + assert avg_pool2d.attrs.strides[0].dtype == "int64" + assert avg_pool2d.attrs.strides[1].dtype == "int64" + assert avg_pool2d.attrs.padding[0].dtype == "int64" + assert avg_pool2d.attrs.padding[1].dtype == "int64" + assert avg_pool2d.attrs.padding[2].dtype == "int64" + assert avg_pool2d.attrs.padding[3].dtype == "int64" + assert avg_pool2d.attrs.dilation[0].dtype == "int64" + assert avg_pool2d.attrs.dilation[1].dtype == "int64" + + +def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, pool_size=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, strides=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, padding=(1, 2, 3)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool2d(x, dilation=(1, 2, 3)) + + +def test_avg_pool2d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x, out_layout="OHWI")) + + +def test_avg_pool2d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x1)) + + +def test_avg_pool2d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool2d(x1)) + + +def test_avg_pool3d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=5)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=5)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32", vdev0)) + + _check_inference( + bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float32") + ) + _check_inference( + bb, relax.op.nn.avg_pool3d(x7), relax.TensorStructInfo((2, 3, 32, 32, 32), "float32", vdev0) + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x0, pool_size=3), + relax.TensorStructInfo((2, 3, 30, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x0, pool_size=(5, 3, 3)), + relax.TensorStructInfo((2, 3, 28, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x0, padding=1), + relax.TensorStructInfo((2, 3, 34, 34, 34), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x0, padding=[1, 2, 3]), + relax.TensorStructInfo((2, 3, 34, 36, 38), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x0, strides=2), + relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x0, dilation=2), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x1, layout="NCDHW"), + relax.TensorStructInfo((2, 32, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x0, out_layout="NCDHW"), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), + relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.avg_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.nn.avg_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.nn.avg_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.nn.avg_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5)) + + +def test_avg_pool3d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + id_ = tir.Var("id", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, id_, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool3d( + x0, pool_size=(3, 3, 3), strides=(3, 3, 3), padding=(2, 2, 2), dilation=(2, 2, 2) + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(id_ - 1, 3) + 1, + tvm.tir.floordiv(ih - 1, 3) + 1, + tvm.tir.floordiv(iw - 1, 3) + 1, + ), + "float32", + ), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), + relax.TensorStructInfo((n, id_, ih, iw, c * 16), "float32"), + ) + + +def test_avg_pool3d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x1, layout="NCDHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=6), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x2), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + + +def test_avg_pool3d_infer_struct_info_ceil_mode(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool3d(x, pool_size=3, strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 16, 16, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.avg_pool3d(x, pool_size=(5, 3, 3), strides=2, ceil_mode=True), + relax.TensorStructInfo((2, 3, 15, 16, 16), "float32"), + ) + + +def test_avg_pool3d_infer_struct_info_ceil_mode_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + id_ = tir.Var("id", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32")) + + _check_inference( + bb, + relax.op.nn.avg_pool3d( + x, + pool_size=(3, 3, 3), + strides=(2, 2, 2), + padding=(1, 1, 1), + dilation=(2, 2, 2), + ceil_mode=True, + ), + relax.TensorStructInfo( + ( + n, + c, + tvm.tir.floordiv(id_, 2), + tvm.tir.floordiv(ih, 2), + tvm.tir.floordiv(iw, 2), + ), + "float32", + ), + ) + + +def test_avg_pool3d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64")) + + _check_inference( + bb, relax.op.nn.avg_pool3d(x0), relax.TensorStructInfo((2, 3, 32, 32, 32), "float16") + ) + _check_inference( + bb, relax.op.nn.avg_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.avg_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") + ) + + +def test_avg_pool3d_stride_padding_dilation_int64(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + avg_pool3d = relax.op.nn.avg_pool3d( + x, (3, 3, 3), strides=(1, 1, 1), padding=(1, 1, 1), dilation=(1, 1, 1) + ) + + assert avg_pool3d.attrs.strides[0].dtype == "int64" + assert avg_pool3d.attrs.strides[1].dtype == "int64" + assert avg_pool3d.attrs.strides[2].dtype == "int64" + assert avg_pool3d.attrs.padding[0].dtype == "int64" + assert avg_pool3d.attrs.padding[1].dtype == "int64" + assert avg_pool3d.attrs.padding[2].dtype == "int64" + assert avg_pool3d.attrs.dilation[0].dtype == "int64" + assert avg_pool3d.attrs.dilation[1].dtype == "int64" + assert avg_pool3d.attrs.dilation[2].dtype == "int64" + + +def test_avg_pool3d_wrong_pool_size_strides_padding_dilation_length(): + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + with pytest.raises(TVMError): + relax.op.nn.avg_pool3d(x, pool_size=(1, 2, 3, 4)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool3d(x, strides=(1, 2, 3, 4)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool3d(x, padding=(1, 2, 3, 4)) + with pytest.raises(TVMError): + relax.op.nn.avg_pool3d(x, dilation=(1, 2, 3, 4)) + + +def test_avg_pool3d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool3d(x, layout="OIHW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool3d(x, out_layout="OHWI")) + + +def test_avg_pool3d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool3d(x1)) + + +def test_avg_pool3d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.avg_pool3d(x1)) + + +def test_adaptive_avg_pool1d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + + x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor(ndim=3)) + x4 = relax.Var("x", R.Tensor()) + + x5 = relax.Var("x", R.Tensor((2, 3, 32), "float32", vdev0)) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0), + relax.TensorStructInfo((2, 3, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x5), + relax.TensorStructInfo((2, 3, 32), "float32", vdev0), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0, output_size=16), + relax.TensorStructInfo((2, 3, 16), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x2), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x3), + relax.TensorStructInfo(dtype="", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x4), + relax.TensorStructInfo(dtype="", ndim=3), + ) + + +def test_adaptive_avg_pool1d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + l = tir.Var("l", "int64") + + x0 = relax.Var("x", R.Tensor((n, c, l), "float32")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0), + relax.TensorStructInfo((n, c, l), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0, output_size=64), + relax.TensorStructInfo((n, c, 64), "float32"), + ) + + +def test_adaptive_avg_pool1d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x0, output_size=20), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool1d(x1), + relax.TensorStructInfo(s1, dtype="float32"), + ) + + +def test_adaptive_avg_pool1d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 64), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 64), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 64), "int64")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool1d(x0), relax.TensorStructInfo((2, 3, 64), "float16") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool1d(x1), relax.TensorStructInfo((2, 3, 64), "int8") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool1d(x2), relax.TensorStructInfo((2, 3, 64), "int64") + ) + + +def test_adaptive_avg_pool1d_wrong_output_size_ndim(): + x = relax.Var("x", R.Tensor((2, 3, 64), "float32")) + with pytest.raises(TVMError): + relax.op.nn.adaptive_avg_pool1d(x, output_size=(32, 32)) + + +def test_adaptive_avg_pool1d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 64), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x, layout="OIW")) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x, out_layout="OWI")) + + +def test_adaptive_avg_pool1d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x1)) + + +def test_adaptive_avg_pool1d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 64))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 64), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool1d(x1)) + + +def test_adaptive_avg_pool2d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32", vdev0)) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x7), + relax.TensorStructInfo((2, 3, 32, 32), "float32", vdev0), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=30), + relax.TensorStructInfo((2, 3, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=(28, 30)), + relax.TensorStructInfo((2, 3, 28, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x1, layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NHWC"), + relax.TensorStructInfo((2, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"), + relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"), + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4) + ) + + +def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((n, c, ih, iw), "float32") + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool2d(x0, output_size=256), + relax.TensorStructInfo((n, c, 256, 256), "float32"), + ) _check_inference( bb, relax.op.nn.adaptive_avg_pool2d(x0, output_size=(256, 128)), @@ -668,5 +1730,197 @@ def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1)) +def test_adaptive_avg_pool3d_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) + x1 = relax.Var("x", R.Tensor((2, 32, 32, 32, 3), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=5)) + x3 = relax.Var("x", R.Tensor("float32")) + x4 = relax.Var("x", R.Tensor(ndim=5)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 32, 16), "float32")) + x7 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32", vdev0)) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x7), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32", vdev0), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=30), + relax.TensorStructInfo((2, 3, 30, 30, 30), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=(28, 30, 32)), + relax.TensorStructInfo((2, 3, 28, 30, 32), "float32"), + ) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW"), + relax.TensorStructInfo((2, 32, 32, 32, 3), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, out_layout="NCDHW"), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x6, layout="NCDHW16c", out_layout="NDHWC16c"), + relax.TensorStructInfo((2, 32, 32, 32, 4, 16), "float32"), + ) + + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x3), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x4), relax.TensorStructInfo(dtype="", ndim=5) + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x5), relax.TensorStructInfo(dtype="", ndim=5) + ) + + +def test_adaptive_avg_pool3d_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + + n = tir.Var("n", "int64") + c = tir.Var("c", "int64") + c16 = tir.Var("c16", "int64") + d = tir.Var("d", "int64") + ih = tir.Var("ih", "int64") + iw = tir.Var("iw", "int64") + + x0 = relax.Var("x", R.Tensor((n, c, d, ih, iw), "float32")) + x1 = relax.Var("x", R.Tensor((n, c, d, ih, iw, c16), "float32")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0), + relax.TensorStructInfo((n, c, d, ih, iw), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=256), + relax.TensorStructInfo((n, c, 256, 256, 256), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=(256, 128, 64)), + relax.TensorStructInfo((n, c, 256, 128, 64), "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW16c", out_layout="NDHWC"), + relax.TensorStructInfo((n, d, ih, iw, c * 16), "float32"), + ) + + +def test_adaptive_avg_pool3d_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=5)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.nn.adaptive_avg_pool3d(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, output_size=32), + relax.TensorStructInfo(dtype="float32", ndim=5), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x1, layout="NCDHW16c"), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0, out_layout="NCDHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=6), + ) + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x2, out_layout="NCDHW16c"), + relax.TensorStructInfo(dtype="float32", ndim=6), + ) + + +def test_adaptive_avg_pool3d_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "int64")) + + _check_inference( + bb, + relax.op.nn.adaptive_avg_pool3d(x0), + relax.TensorStructInfo((2, 3, 32, 32, 32), "float16"), + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x1), relax.TensorStructInfo((2, 3, 32, 32, 32), "int8") + ) + _check_inference( + bb, relax.op.nn.adaptive_avg_pool3d(x2), relax.TensorStructInfo((2, 3, 32, 32, 32), "int64") + ) + + +def test_adaptive_avg_pool3d_wrong_output_size_ndim(): + x = relax.Var("x", R.Tensor((2, 3, 32, 32, 32), "float32")) + + with pytest.raises(TVMError): + relax.op.nn.adaptive_avg_pool3d(x, (32, 32, 32, 32)) + + +def test_adaptive_avg_pool3d_infer_struct_info_wrong_layout_string(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 28, 28, 28), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x, layout="OIDHW")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x, out_layout="OHIDW")) + + +def test_adaptive_avg_pool3d_wrong_input_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 28, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x1)) + + +def test_adaptive_avg_pool3d_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28, 28))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28, 28), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.adaptive_avg_pool3d(x1)) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 1fdded92393c..c94dd9f5789d 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -24,7 +24,7 @@ import tvm import tvm.testing from tvm import relax -from tvm._ffi.base import TVMError +from tvm.base import TVMError from tvm.script import ir as I, relax as R, tir as T exec_mode = tvm.testing.parameter("bytecode", "compiled") diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index db4130f947d1..262e37b91b1b 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -1434,7 +1434,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="void", @@ -1477,7 +1477,7 @@ def main( method="linear", coordinate_transformation_mode="half_pixel", rounding_method="round", - cubic_alpha=-0.5, + cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0, out_dtype="void", diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py index 072edea5c400..47c41ca108f9 100644 --- a/tests/python/relax/test_transform_gradient.py +++ b/tests/python/relax/test_transform_gradient.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm import relax -from tvm._ffi.base import TVMError +from tvm.base import TVMError from tvm.ir.base import assert_structural_equal from tvm.script.parser import relax as R, tir as T, ir as I diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py index f5a20b298a57..44469acdc1c0 100644 --- a/tests/python/relax/test_transform_legalize_ops_grad.py +++ b/tests/python/relax/test_transform_legalize_ops_grad.py @@ -282,7 +282,8 @@ def avg_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64 T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3]) with T.init(): T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) - T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww < T.int64(5), rxplaceholder[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww] / T.Cast("float32", T.max((T.min(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) + T.int64(3) - v_wh * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) - v_wh * T.int64(2) - T.int64(2), T.int64(0))) * (T.min(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) + T.int64(4) - v_ww * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) - v_ww * T.int64(2) - T.int64(1), T.int64(0))), T.int64(1))), T.float32(0)) + T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww < T.int64(5), rxplaceholder[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww] / T.Cast("float32", T.max((T.min(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) + T.int64(3) - v_wh * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh - T.int64(1), T.int64(0)) * T.int64(2)) * (T.min(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) + T.int64(4) - v_ww * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) - v_ww * T.int64(2) - T.int64(1), T.int64(0))), T.int64(1))), T.float32(0.0)) + @R.function def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"): cls = Expected diff --git a/tests/python/runtime/test_runtime_error.py b/tests/python/runtime/test_runtime_error.py index 6f950d4ca77c..6eb6cc9b641b 100644 --- a/tests/python/runtime/test_runtime_error.py +++ b/tests/python/runtime/test_runtime_error.py @@ -92,7 +92,7 @@ def flevel3(): @functools.lru_cache() def _has_debug_symbols(): - lib = tvm._ffi.base._LIB + lib = tvm.base._LIB headers = subprocess.check_output(["objdump", "--section-headers", lib._name], encoding="utf-8") return ".debug" in headers diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 88e9481c4c07..6711ccf92f3f 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -103,26 +103,7 @@ def check_remote(): assert f1(10) == 11 f3 = client.get_function("rpc.test.except") - with pytest.raises(tvm._ffi.base.TVMError): - f3("abc") - - f2 = client.get_function("rpc.test.strcat") - assert f2("abc", 11) == "abc:11" - - check_remote() - - -@tvm.testing.requires_rpc -def test_rpc_simple_wlog(): - server = rpc.Server(key="x1") - client = rpc.connect("127.0.0.1", server.port, key="x1", enable_logging=True) - - def check_remote(): - f1 = client.get_function("rpc.test.addone") - assert f1(10) == 11 - f3 = client.get_function("rpc.test.except") - - with pytest.raises(tvm._ffi.base.TVMError): + with pytest.raises(tvm.base.TVMError): f3("abc") f2 = client.get_function("rpc.test.strcat") @@ -428,9 +409,9 @@ def run_arr_test(): @tvm.testing.requires_rpc def test_rpc_return_remote_object(): def check(client, is_local): - make_shape = client.get_function("runtime.ShapeTuple") - get_elem = client.get_function("runtime.GetShapeTupleElem") - get_size = client.get_function("runtime.GetShapeTupleSize") + make_shape = client.get_function("ffi.Shape") + get_elem = client.get_function("testing.GetShapeElem") + get_size = client.get_function("testing.GetShapeSize") shape = make_shape(2, 3) assert shape.type_key == "runtime.RPCObjectRef" assert get_elem(shape, 0) == 2 @@ -681,7 +662,7 @@ def test_compiled_function_with_zero_arguments(call_with_unused_argument): """RPC functions do not require an argument This is a regression test. When no arguments are provided, RPC - provides NULL as the `TVMValue* args` argument to a PackedFunc. + provides NULL as the `TVMFFIAny* args` argument to a PackedFunc. However, previous implementations of `MakePackedAPI` unconditionally asserted that the `args` pointer was non-null. This assertion is now generated only when the function accepts diff --git a/tests/python/target/test_riscv_features.py b/tests/python/target/test_riscv_features.py index 765d492a2dc0..17452a86dbf7 100644 --- a/tests/python/target/test_riscv_features.py +++ b/tests/python/target/test_riscv_features.py @@ -24,20 +24,16 @@ # fmt: off min_llvm_version, tvm_target, vec_width = tvm.testing.parameters( - # generic, no-vec -> (default 256) - (-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+i,+m", 256), - (-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+64bit,+a,+c,+d,+f,+m", 256), - # generic, with-vec -> (default 256) - (-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 256), - (-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 256), - # explicit -vector-width - (-1, "llvm -device=riscv_cpu -vector-width=128 -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 128), - (-1, "llvm -device=riscv_cpu -vector-width=128 -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 128), - (-1, "llvm -device=riscv_cpu -vector-width=512 -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 512), - (-1, "llvm -device=riscv_cpu -vector-width=512 -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 512), + # generic, no vector -> (default 128) + (-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+i,+m", 128), + (-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+64bit,+a,+c,+d,+f,+m", 128), + # generic, with vector -> (default zvl128b) + (-1, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v", 128), + (-1, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v", 128), # explicit +zvlXXXb - (14, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v,+zvl64b", 64), - (14, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v,+zvl64b", 64), + (14, "llvm -device=riscv_cpu -mtriple=riscv32-linux-gnu -mcpu=generic-rv32 -mattr=+i,+m,+v,+zvl64b", 128), + (14, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v,+zvl256b", 256), + (14, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v,+zvl512b", 512), # vendor CPU (17, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=sifive-x280", 512), (18, "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=sifive-p670", 128), diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index 9925f54be4db..b0850a89b5c5 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -882,6 +882,14 @@ def te_workload(): _check_workload(te_workload, tir_workload) +def test_global_pool(): + # fix the issue-17938 + data = te.placeholder((1, 1, 32, 32), dtype="int8", name="data") + op_output = topi.nn.global_pool(data=data, pool_type="avg", layout="NCHW") + f = te.create_prim_func([data, op_output]) + assert f + + def test_nested_reduce_domain_dependency(): @T.prim_func def tir_workload( diff --git a/tests/python/te/test_te_verify_compute.py b/tests/python/te/test_te_verify_compute.py index 7ea9321e18c9..880e960da7fa 100644 --- a/tests/python/te/test_te_verify_compute.py +++ b/tests/python/te/test_te_verify_compute.py @@ -35,14 +35,14 @@ def test_verify_compute(): # Valid compute try: B = te.compute((n,), f1, name="B") - except tvm._ffi.base.TVMError as ex: + except tvm.base.TVMError as ex: assert False # # Valid compute try: B = te.compute((n,), f2, name="B") - except tvm._ffi.base.TVMError as ex: + except tvm.base.TVMError as ex: assert False # @@ -50,7 +50,7 @@ def test_verify_compute(): try: B = te.compute((n,), f3, name="B") assert False - except tvm._ffi.base.TVMError as ex: + except tvm.base.TVMError as ex: pass # @@ -58,7 +58,7 @@ def test_verify_compute(): try: B = te.compute((n,), f4, name="B") assert False - except tvm._ffi.base.TVMError as ex: + except tvm.base.TVMError as ex: pass # @@ -66,7 +66,7 @@ def test_verify_compute(): try: B0, B1 = te.compute((n,), f5, name="B") assert False - except tvm._ffi.base.TVMError as ex: + except tvm.base.TVMError as ex: pass # @@ -74,7 +74,7 @@ def test_verify_compute(): try: B0, B1 = te.compute((n,), f6, name="B") assert False - except tvm._ffi.base.TVMError as ex: + except tvm.base.TVMError as ex: pass diff --git a/tests/python/testing/test_type_annotation_checker.py b/tests/python/testing/test_type_annotation_checker.py index 9af356b97198..42ce1e103903 100644 --- a/tests/python/testing/test_type_annotation_checker.py +++ b/tests/python/testing/test_type_annotation_checker.py @@ -46,7 +46,7 @@ def str_func(x: str) -> str: [5], [], # Tuples are allowed to be used as lists, because both are - # represented in FFI as tvm::runtime::Array. + # represented in FFI as tvm::Array. (1, 2, 3), ], "negative_cases": [ diff --git a/tests/python/tir-base/test_tir_base.py b/tests/python/tir-base/test_tir_base.py index b33e846741cd..d204ebfb6084 100644 --- a/tests/python/tir-base/test_tir_base.py +++ b/tests/python/tir-base/test_tir_base.py @@ -17,7 +17,7 @@ import tvm import pytest from tvm import tir -from tvm._ffi.base import TVMError +from tvm.base import TVMError from tvm.ir.transform import PassContext import itertools import pytest diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index d2a73c12e79f..3e731f55fb0e 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -79,7 +79,7 @@ def test_unary_intrin(): (tvm.tir.atanh, lambda x: np.arctanh(x)), ] - def run_test(tvm_intrin, np_func): + def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): m = te.var( "m", ) @@ -98,10 +98,28 @@ def run_test(tvm_intrin, np_func): a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), dev) b = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) func(a, b) - tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=atol, rtol=rtol) + + # Out‐of‐bounds test for asin/acos + name = tvm_intrin.__name__ + if name in ("asin", "acos"): + # generate some values outside [-1, 1] + n = 8 + out_np = np.concatenate( + [ + np.random.uniform(1.1, 2.0, size=n // 2), + np.random.uniform(-2.0, -1.1, size=n // 2), + ] + ).astype(A.dtype) + a2 = tvm.nd.array(out_np, dev) + b2 = tvm.nd.array(np.empty_like(out_np), dev) + func(a2, b2) + # all outputs should be NaN + assert np.all(np.isnan(b2.numpy())) for func in test_funcs: - run_test(*func) + atol = rtol = 1e-3 if func[0].__name__ in ["asin", "acos", "atan"] else 1e-5 + run_test(*func, atol, rtol) def test_binary_intrin(): diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index 22344acfe1d4..f09f7417baf6 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -816,7 +816,7 @@ def before(a: T.handle): warning_msg = ( "Warning: The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. This proof currently only supports " - "AArch64 SVE targets, but the target was " + "VLA targets, but the target was " ) captured = capfd.readouterr().err assert warning_msg in captured diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 5208262221b9..e7e64d89168e 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -352,7 +352,7 @@ def test_no_normalization_without_commoning(): def func_distributivity( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - B[i1] = x * (y + z) + B[i1] = (y + z) * x B[i2] = x * y + x * z diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index 46487f969b96..da079f46e38e 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -430,7 +430,7 @@ def tvm_callback_cuda_postproc(code, _): # Restore previous postproc func to avoid impacting other tests if prev_postproc is None: - tvm._ffi.registry.remove_global_func(func_name) + tvm.ffi.registry.remove_global_func(func_name) else: tvm.register_func(func_name, prev_postproc, override=True) diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 9b61255285be..13bb1c60cb53 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -670,7 +670,7 @@ def expected(a: T.handle, b: T.handle): def test_vectorize_with_explicitly_disabled_buffer_level_predication(): - # Since the target has the SVE feature, buffer level predication is enabled + # Since the target has the VLA feature, buffer level predication is enabled # by default. However, it has been explicitly disabled by the pass context # option, so no buffer-level predicates should be added. @T.prim_func diff --git a/tests/scripts/task_java_unittest.sh b/tests/scripts/task_java_unittest.sh index 2eabac31cc28..f7c9f3c097af 100755 --- a/tests/scripts/task_java_unittest.sh +++ b/tests/scripts/task_java_unittest.sh @@ -35,16 +35,13 @@ cleanup() } trap cleanup 0 -# python3 "$SCRIPT_DIR"/test_add_cpu.py "$TEMP_DIR" -# python3 "$SCRIPT_DIR"/test_add_gpu.py "$TEMP_DIR" - -# Skip the Java RPC Unittests, see https://github.com/apache/tvm/issues/13168 -# # start rpc proxy server -# PORT=$(( ( RANDOM % 1000 ) + 9000 )) -# python3 $SCRIPT_DIR/test_rpc_proxy_server.py $PORT 30 & - -# make jvmpkg -# make jvmpkg JVM_TEST_ARGS="-DskipTests=false \ -# -Dtest.tempdir=$TEMP_DIR \ -# -Dtest.rpc.proxy.host=localhost \ -# -Dtest.rpc.proxy.port=$PORT" +make jvmpkg + +# Skip the Java Tests for now +exit 0 + +# expose tvm runtime lib to system env +export LD_LIBRARY_PATH=$CURR_DIR/../../build/:$LD_LIBRARY_PATH +python "$SCRIPT_DIR"/prepare_test_libs.py "$TEMP_DIR" +make jvmpkg JVM_TEST_ARGS="-DskipTests=false\ + -Dtest.tempdir=$TEMP_DIR" diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 3b270b21f60a..6a6a2171bd1f 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -72,9 +72,6 @@ function shard2 { echo "clang-format check..." tests/lint/git-clang-format.sh - echo "Rust check..." - tests/lint/rust_format.sh - echo "Docker check..." tests/lint/docker-format.sh } diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index b217f692b4a3..7b58658bd7c7 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -161,12 +161,6 @@ npm install npm run typedoc cd .. -# Rust doc -cd rust -# Temp disable rust doc build -# cargo doc --workspace --no-deps -cd .. - # Prepare the doc dir rm -rf _docs mv docs/_build/html _docs @@ -174,7 +168,6 @@ rm -f _docs/.buildinfo mkdir -p _docs/reference/api mv docs/doxygen/html _docs/reference/api/doxygen mv jvm/core/target/site/apidocs _docs/reference/api/javadoc -# mv rust/target/doc _docs/api/rust mv web/dist/docs _docs/reference/api/typedoc git rev-parse HEAD > _docs/commit_hash diff --git a/version.py b/version.py index db1f6cadfb8c..bf2f0343d6c7 100644 --- a/version.py +++ b/version.py @@ -20,8 +20,8 @@ This script runs and update all the locations that related to versions List of affected files: -- tvm-root/python/tvm/_ffi/libinfo.py -- tvm-root/include/tvm/runtime/c_runtime_api.h +- tvm-root/python/tvm/libinfo.py +- tvm-root/include/tvm/runtime/base.h - tvm-root/conda/recipe/meta.yaml - tvm-root/web/package.json """ @@ -170,7 +170,7 @@ def sync_version(pub_ver, local_ver, dry_run): """Synchronize version.""" # python uses the PEP-440: local version update( - os.path.join(PROJ_ROOT, "python", "tvm", "_ffi", "libinfo.py"), + os.path.join(PROJ_ROOT, "python", "tvm", "libinfo.py"), r"(?<=__version__ = \")[.0-9a-z\+]+", local_ver, dry_run, @@ -179,7 +179,7 @@ def sync_version(pub_ver, local_ver, dry_run): # Note that full git hash is already available in libtvm # C++ header update( - os.path.join(PROJ_ROOT, "include", "tvm", "runtime", "c_runtime_api.h"), + os.path.join(PROJ_ROOT, "include", "tvm", "runtime", "base.h"), r'(?<=TVM_VERSION ")[.0-9a-z\+]+', pub_ver, dry_run, diff --git a/web/.eslintignore b/web/.eslintignore index f71ee79871c4..1549e07c251e 100644 --- a/web/.eslintignore +++ b/web/.eslintignore @@ -1,2 +1,4 @@ dist debug +tvmjs_runtime_wasi.js +lib diff --git a/web/apps/node/example.js b/web/apps/node/example.js index 580bbf57ab80..62c9157c7c29 100644 --- a/web/apps/node/example.js +++ b/web/apps/node/example.js @@ -31,7 +31,7 @@ const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); tvmjs.instantiate(wasmSource, tvmjs.createPolyfillWASI()) .then((tvm) => { tvm.beginScope(); - const log_info = tvm.getGlobalFunc("testing.log_info_str"); + const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str"); log_info("hello world"); // List all the global functions from the runtime. console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index e50e6c37d34c..922b25b0d74b 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -28,12 +28,11 @@ #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY -#include +#include #include -#include -#include #include "../../src/runtime/rpc/rpc_local_session.h" @@ -59,27 +58,33 @@ TVM_DLL void TVMWasmFreeSpace(void* data); * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer 3A * \return 0 if success. */ -TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out); +TVM_DLL int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFFIObjectHandle* out); + +/*! + * \brief Get the last error message. + * \return The last error message. + */ +TVM_DLL const char* TVMFFIWasmGetLastError(); // --- APIs to be implemented by the frontend. --- + /*! - * \brief Wasm frontend packed function caller. + * \brief Wasm frontend new ffi call function caller. * + * \param self The pointer to the ffi::Function. * \param args The arguments - * \param type_codes The type codes of the arguments * \param num_args Number of arguments. - * \param ret The return value handle. - * \param resource_handle The handle additional resource handle from front-end. + * \param result The return value handle. * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. */ -extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, - void* resource_handle); - +extern int TVMFFIWasmSafeCall(void* self, const TVMFFIAny* args, int32_t num_args, + TVMFFIAny* result); /*! - * \brief Wasm frontend resource finalizer. - * \param resource_handle The pointer to the external resource. + * \brief Delete ffi::Function. + * \param self The pointer to the ffi::Function. */ -extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); +extern void TVMFFIWasmFunctionDeleter(void* self); + } // extern "C" void* TVMWasmAllocSpace(int size) { @@ -89,9 +94,14 @@ void* TVMWasmAllocSpace(int size) { void TVMWasmFreeSpace(void* arr) { delete[] static_cast(arr); } -int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) { - return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer, - out); +int TVMFFIWasmFunctionCreate(void* self, TVMFFIObjectHandle* out) { + return TVMFFIFunctionCreate(self, TVMFFIWasmSafeCall, TVMFFIWasmFunctionDeleter, out); +} + +const char* TVMFFIWasmGetLastError() { + static thread_local std::string last_error; + last_error = ::tvm::ffi::details::MoveFromSafeCallRaised().what(); + return last_error.c_str(); } namespace tvm { @@ -291,7 +301,7 @@ class AsyncLocalSession : public LocalSession { } }; -TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { +TVM_FFI_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { return CreateRPCSessionModule(std::make_shared()); }); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index b8ebadff4f5c..40dfb31ad19f 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -27,23 +27,20 @@ #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 #define TVM_FFI_USE_LIBBACKTRACE 0 +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY -#include #include -#include "src/runtime/c_runtime_api.cc" -#include "src/runtime/container.cc" #include "src/runtime/contrib/sort/sort.cc" #include "src/runtime/cpu_device_api.cc" +#include "src/runtime/device_api.cc" #include "src/runtime/file_utils.cc" #include "src/runtime/library_module.cc" #include "src/runtime/logging.cc" #include "src/runtime/module.cc" #include "src/runtime/ndarray.cc" -#include "src/runtime/object.cc" #include "src/runtime/profiling.cc" -#include "src/runtime/registry.cc" #include "src/runtime/rpc/rpc_channel.cc" #include "src/runtime/rpc/rpc_endpoint.cc" #include "src/runtime/rpc/rpc_event_impl.cc" @@ -107,45 +104,24 @@ void LogMessageImpl(const std::string& file, int lineno, int level, const std::s } // namespace detail -TVM_REGISTER_GLOBAL("testing.echo").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - *ret = args[0]; -}); - -TVM_REGISTER_GLOBAL("testing.call").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - (args[0].cast()).CallPacked(args.Slice(1), ret); -}); - -TVM_REGISTER_GLOBAL("testing.ret_string").set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - *ret = args[0].cast(); -}); - -TVM_REGISTER_GLOBAL("testing.log_info_str") +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.call") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - LOG(INFO) << args[0].cast(); + (args[0].cast()).CallPacked(args.Slice(1), ret); }); -TVM_REGISTER_GLOBAL("testing.log_fatal_str") +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.log_info_str") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - LOG(FATAL) << args[0].cast(); + LOG(INFO) << args[0].cast(); }); -TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; }); +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.add_one").set_body_typed([](int x) { return x + 1; }); -TVM_REGISTER_GLOBAL("testing.wrap_callback") +TVM_FFI_REGISTER_GLOBAL("tvmjs.testing.wrap_callback") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { ffi::Function pf = args[0].cast(); *ret = ffi::TypedFunction([pf]() { pf(); }); }); -// internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.object_use_count") - .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { - auto obj = args[0].cast(); - // subtract the current one because we always copy - // and get another value. - *ret = (obj.use_count() - 1); - }); - void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, std::string dtype) { if (format == "f32-to-bf16" && dtype == "float32") { std::vector buffer(bytes.length() / 2); @@ -167,15 +143,15 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format, } } -TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); +TVM_FFI_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage); // Concatenate n TVMArrays -TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat") +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { std::vector data; for (int i = 0; i < args.size(); ++i) { // Get i-th TVMArray - auto* arr_i = args[i].as(); + auto* arr_i = args[i].as(); ICHECK(arr_i != nullptr); for (size_t j = 0; j < arr_i->size(); ++j) { // Push back each j-th element of the i-th array @@ -220,7 +196,7 @@ NDArray ConcatEmbeddings(const std::vector& embeddings) { } // Concatenate n NDArrays -TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) { std::vector embeddings; for (int i = 0; i < args.size(); ++i) { @@ -230,5 +206,19 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings") *ret = result; }); +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyFromBytes") + .set_body_typed([](NDArray nd, TVMFFIByteArray* bytes) { + nd.CopyFromBytes(bytes->data, bytes->size); + }); + +TVM_FFI_REGISTER_GLOBAL("tvmjs.runtime.NDArrayCopyToBytes") + .set_body_typed([](NDArray nd) -> ffi::Bytes { + size_t size = GetDataSize(*(nd.operator->())); + std::string bytes; + bytes.resize(size); + nd.CopyToBytes(bytes.data(), size); + return ffi::Bytes(bytes); + }); + } // namespace runtime } // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 3d74d77f14ce..00b1db266a0b 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -26,13 +26,11 @@ #define TVM_LOG_STACK_TRACE 0 #define TVM_LOG_DEBUG 0 #define TVM_LOG_CUSTOMIZE 1 +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 1 #define DMLC_USE_LOGGING_LIBRARY -#include -#include +#include #include -#include -#include #include #include @@ -152,7 +150,10 @@ typedef dmlc::ThreadLocalStore WebGPUThreadStore; WebGPUThreadEntry::WebGPUThreadEntry() : pool(static_cast(kDLWebGPU), WebGPUDeviceAPI::Global()) {} -WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); } +WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { + static thread_local WebGPUThreadEntry inst = WebGPUThreadEntry(); + return &inst; +} class WebGPUModuleNode final : public runtime::ModuleNode { public: @@ -241,12 +242,13 @@ Module WebGPUModuleLoadBinary(void* strm) { } // for now webgpu is hosted via a vulkan module. -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); +TVM_FFI_REGISTER_GLOBAL("runtime.module.loadbinary_webgpu").set_body_typed(WebGPUModuleLoadBinary); -TVM_REGISTER_GLOBAL("device_api.webgpu").set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { - DeviceAPI* ptr = WebGPUDeviceAPI::Global(); - *rv = static_cast(ptr); -}); +TVM_FFI_REGISTER_GLOBAL("device_api.webgpu") + .set_body_packed([](ffi::PackedArgs args, ffi::Any* rv) { + DeviceAPI* ptr = WebGPUDeviceAPI::Global(); + *rv = static_cast(ptr); + }); } // namespace runtime } // namespace tvm diff --git a/web/package.json b/web/package.json index 583232d20951..b4fc25e12fcf 100644 --- a/web/package.json +++ b/web/package.json @@ -45,5 +45,9 @@ "typedoc-plugin-missing-exports": "2.0.0", "typescript": "^4.9.5", "ws": "^7.2.5" + }, + "dependencies": { + "audit": "^0.0.6", + "fix": "^0.0.6" } } diff --git a/web/src/asyncify.ts b/web/src/asyncify.ts index 703dbbf80a10..6074a559e00d 100644 --- a/web/src/asyncify.ts +++ b/web/src/asyncify.ts @@ -70,6 +70,15 @@ export class AsyncifyHandler { return this.exports.asyncify_stop_rewind !== undefined; } + /** + * Get the current asynctify state + * + * @returns The current asynctify state + */ + isNormalStackState(): boolean { + return this.state == AsyncifyStateKind.None; + } + /** * Get the current asynctify state * diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts index c4941f07d57a..c9a5e263d5f2 100644 --- a/web/src/ctypes.ts +++ b/web/src/ctypes.ts @@ -27,231 +27,165 @@ export type Pointer = number; /** A pointer offset, need to add a base address to get a valid ptr. */ export type PtrOffset = number; -// -- TVM runtime C API -- /** - * const char *TVMGetLastError(); - */ -export type FTVMGetLastError = () => Pointer; - -/** - * void TVMAPISetLastError(const char* msg); - */ -export type FTVMAPISetLastError = (msg: Pointer) => void; - -/** - * int TVMModGetFunction(TVMModuleHandle mod, - * const char* func_name, - * int query_imports, - * TVMFunctionHandle *out); - */ -export type FTVMModGetFunction = ( - mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; -/** - * int TVMModImport(TVMModuleHandle mod, - * TVMModuleHandle dep); - */ -export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; - -/** - * int TVMModFree(TVMModuleHandle mod); - */ -export type FTVMModFree = (mod: Pointer) => number; - -/** - * int TVMFuncFree(TVMFunctionHandle func); - */ -export type FTVMFuncFree = (func: Pointer) => number; - -/** - * int TVMFuncCall(TVMFunctionHandle func, - * TVMValue* arg_values, - * int* type_codes, - * int num_args, - * TVMValue* ret_val, - * int* ret_type_code); - */ -export type FTVMFuncCall = ( - func: Pointer, argValues: Pointer, typeCode: Pointer, - nargs: number, retValue: Pointer, retCode: Pointer) => number; - -/** - * int TVMCFuncSetReturn(TVMRetValueHandle ret, - * TVMValue* value, - * int* type_code, - * int num_ret); - */ -export type FTVMCFuncSetReturn = ( - ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; - -/** - * int TVMCbArgToReturn(TVMValue* value, int* code); - */ -export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; - -/** - * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + * Size of common data types. */ -export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + TVMFFIAny = 8 * 2, + DLDataType = I32, + DLDevice = I32 + I32, + ObjectHeader = 8 * 2, +} +//---------------The new TVM FFI--------------- /** - * int TVMFuncRegisterGlobal( - * const char* name, TVMFunctionHandle f, int override); - */ -export type FTVMFuncRegisterGlobal = ( - name: Pointer, f: Pointer, override: number) => number; + * Type Index in new TVM FFI. + * + * We are keeping the same style as C API here. + */ +export const enum TypeIndex { + kTVMFFINone = 0, + /*! \brief POD int value */ + kTVMFFIInt = 1, + /*! \brief POD bool value */ + kTVMFFIBool = 2, + /*! \brief POD float value */ + kTVMFFIFloat = 3, + /*! \brief Opaque pointer object */ + kTVMFFIOpaquePtr = 4, + /*! \brief DLDataType */ + kTVMFFIDataType = 5, + /*! \brief DLDevice */ + kTVMFFIDevice = 6, + /*! \brief DLTensor* */ + kTVMFFIDLTensorPtr = 7, + /*! \brief const char**/ + kTVMFFIRawStr = 8, + /*! \brief TVMFFIByteArray* */ + kTVMFFIByteArrayPtr = 9, + /*! \brief R-value reference to ObjectRef */ + kTVMFFIObjectRValueRef = 10, + /*! \brief Start of statically defined objects. */ + kTVMFFIStaticObjectBegin = 64, + /*! + * \brief Object, all objects starts with TVMFFIObject as its header. + * \note We will also add other fields + */ + kTVMFFIObject = 64, + /*! + * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ + kTVMFFIStr = 65, + /*! + * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ + kTVMFFIBytes = 66, + /*! \brief Error object. */ + kTVMFFIError = 67, + /*! \brief Function object. */ + kTVMFFIFunction = 68, + /*! \brief Array object. */ + kTVMFFIArray = 69, + /*! \brief Map object. */ + kTVMFFIMap = 70, + /*! + * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } + */ + kTVMFFIShape = 71, + /*! + * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... } + */ + kTVMFFINDArray = 72, + /*! \brief Runtime module object. */ + kTVMFFIModule = 73, +} -/** - *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); - */ -export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; +// -- TVM Wasm Auxiliary C API -- -/** - * int TVMArrayAlloc(const tvm_index_t* shape, - * int ndim, - * int dtype_code, - * int dtype_bits, - * int dtype_lanes, - * int device_type, - * int device_id, - * TVMArrayHandle* out); - */ -export type FTVMArrayAlloc = ( - shape: Pointer, ndim: number, - dtypeCode: number, dtypeBits: number, - dtypeLanes: number, deviceType: number, deviceId: number, - out: Pointer) => number; +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; -/** - * int TVMArrayFree(TVMArrayHandle handle); - */ -export type FTVMArrayFree = (handle: Pointer) => number; +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; -/** - * int TVMArrayCopyFromBytes(TVMArrayHandle handle, - * void* data, - * size_t nbytes); - */ -export type FTVMArrayCopyFromBytes = ( - handle: Pointer, data: Pointer, nbytes: number) => number; +/** const char* TVMFFIWasmGetLastError(); */ +export type FTVMFFIWasmGetLastError = () => Pointer; /** - * int TVMArrayCopyToBytes(TVMArrayHandle handle, - * void* data, - * size_t nbytes); + * int TVMFFIWasmSafeCallType(void* self, const TVMFFIAny* args, + * int32_t num_args, TVMFFIAny* result); */ -export type FTVMArrayCopyToBytes = ( - handle: Pointer, data: Pointer, nbytes: number) => number; +export type FTVMFFIWasmSafeCallType = ( + self: Pointer, args: Pointer, num_args: number, + result: Pointer) => number; /** - * int TVMArrayCopyFromTo(TVMArrayHandle from, - * TVMArrayHandle to, - * TVMStreamHandle stream); + * int TVMFFIWasmFunctionCreate(void* resource_handle, TVMFunctionHandle* out); */ -export type FTVMArrayCopyFromTo = ( - from: Pointer, to: Pointer, stream: Pointer) => number; +export type FTVMFFIWasmFunctionCreate = ( + resource_handle: Pointer, out: Pointer) => number; /** - * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + * void TVMFFIWasmFunctionDeleter(void* self); */ -export type FTVMSynchronize = ( - deviceType: number, deviceId: number, stream: Pointer) => number; +export type FTVMFFIWasmFunctionDeleter = (self: Pointer) => void; /** - * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, - * int* type_codes, - * int num_args, - * TVMValue* out_ret_value, - * int* out_ret_tcode); + * int TVMFFIObjectFree(TVMFFIObjectHandle obj); */ -export type FTVMBackendPackedCFunc = ( - argValues: Pointer, argCodes: Pointer, nargs: number, - outValue: Pointer, outCode: Pointer) => number; - +export type FTVMFFIObjectFree = (obj: Pointer) => number; /** - * int TVMObjectFree(TVMObjectHandle obj); + * int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex); */ -export type FTVMObjectFree = (obj: Pointer) => number; +export type FTVMFFITypeKeyToIndex = (type_key: Pointer, out_tindex: Pointer) => number; /** - * int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); + * int TVMFFIAnyViewToOwnedAny(const TVMFFIAny* any_view, TVMFFIAny* out); */ -export type FTVMObjectGetTypeIndex = (obj: Pointer, out_tindex: Pointer) => number; +export type FTVMFFIAnyViewToOwnedAny = (any_view: Pointer, out: Pointer) => number; /** - * int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key); + * void TVMFFIErrorSetRaisedByCStr(const char* kind, const char* message); */ -export type FTVMObjectTypeIndex2Key = (type_index: number, out_type_key: Pointer) => number; +export type FTVMFFIErrorSetRaisedByCStr = (kind: Pointer, message: Pointer) => void; /** - * int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + * int TVMFFIFunctionSetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle f, + * int override); */ -export type FTVMObjectTypeKey2Index = (type_key: Pointer, out_tindex: Pointer) => number; - -// -- TVM Wasm Auxiliary C API -- - -/** void* TVMWasmAllocSpace(int size); */ -export type FTVMWasmAllocSpace = (size: number) => Pointer; - -/** void TVMWasmFreeSpace(void* data); */ -export type FTVMWasmFreeSpace = (ptr: Pointer) => void; +export type FTVMFFIFunctionSetGlobal = (name: Pointer, f: Pointer, override: number) => number; /** - * int TVMWasmPackedCFunc(TVMValue* args, - * int* type_codes, - * int num_args, - * TVMRetValueHandle ret, - * void* resource_handle); + * int TVMFFIFunctionGetGlobal(const TVMFFIByteArray* name, TVMFFIObjectHandle* out); */ -export type FTVMWasmPackedCFunc = ( - args: Pointer, typeCodes: Pointer, nargs: number, - ret: Pointer, resourceHandle: Pointer) => number; +export type FTVMFFIFunctionGetGlobal = (name: Pointer, out: Pointer) => number; /** - * int TVMWasmFuncCreateFromCFunc(void* resource_handle, - * TVMFunctionHandle *out); + * int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args, + * TVMFFIAny* result); */ -export type FTVMWasmFuncCreateFromCFunc = ( - resource: Pointer, out: Pointer) => number; +export type FTVMFFIFunctionCall = (func: Pointer, args: Pointer, num_args: number, + result: Pointer) => number; /** - * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + * int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType* out); */ -export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; +export type FTVMFFIDataTypeFromString = (str: Pointer, out: Pointer) => number; /** - * Size of common data types. + * int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIObjectHandle* out); */ -export const enum SizeOf { - U8 = 1, - U16 = 2, - I32 = 4, - I64 = 8, - F32 = 4, - F64 = 8, - TVMValue = 8, - DLDataType = I32, - DLDevice = I32 + I32, -} +export type FTVMFFIDataTypeToString = (dtype: Pointer, out: Pointer) => number; /** - * Argument Type code in TVM FFI. + * TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); */ -export const enum ArgTypeCode { - Int = 0, - UInt = 1, - Float = 2, - TVMOpaqueHandle = 3, - Null = 4, - TVMDataType = 5, - DLDevice = 6, - TVMDLTensorHandle = 7, - TVMObjectHandle = 8, - TVMModuleHandle = 9, - TVMPackedFuncHandle = 10, - TVMStr = 11, - TVMBytes = 12, - TVMNDArrayHandle = 13, - TVMObjectRValueRefArg = 14, - TVMArgBool = 15, -} +export type FTVMFFIGetTypeInfo = (type_index: number) => Pointer; diff --git a/web/src/environment.ts b/web/src/environment.ts index 42a873f1284e..01e19a1c18f4 100644 --- a/web/src/environment.ts +++ b/web/src/environment.ts @@ -75,7 +75,7 @@ export class Environment implements LibraryProvider { * We maintain a separate table so that we can have un-limited amount * of functions that do not maps to the address space. */ - packedCFuncTable: Array = [ + packedCFuncTable: Array = [ undefined, ]; /** @@ -115,28 +115,27 @@ export class Environment implements LibraryProvider { // eslint-disable-next-line @typescript-eslint/no-unused-vars "emscripten_notify_memory_growth": (index: number): void => {} }; - const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + const wasmSafeCall: ctypes.FTVMFFIWasmSafeCallType = ( + self: Pointer, args: Pointer, - typeCodes: Pointer, - nargs: number, - ret: Pointer, - resourceHandle: Pointer + num_args: number, + result: Pointer ): number => { - const cfunc = this.packedCFuncTable[resourceHandle]; + const cfunc = this.packedCFuncTable[self]; assert(cfunc !== undefined); - return cfunc(args, typeCodes, nargs, ret, resourceHandle); + return cfunc(self, args, num_args, result); }; - const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( - resourceHandle: Pointer + const wasmFunctionDeleter: ctypes.FTVMFFIWasmFunctionDeleter = ( + self: Pointer ): void => { - this.packedCFuncTable[resourceHandle] = undefined; - this.packedCFuncTableFreeId.push(resourceHandle); + this.packedCFuncTable[self] = undefined; + this.packedCFuncTableFreeId.push(self); }; const newEnv = { - TVMWasmPackedCFunc: wasmPackedCFunc, - TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "TVMFFIWasmSafeCall": wasmSafeCall, + "TVMFFIWasmFunctionDeleter": wasmFunctionDeleter, "__console_log": (msg: string): void => { this.logger(msg); } diff --git a/web/src/memory.ts b/web/src/memory.ts index b0d4ff3bf194..850f3bd37195 100644 --- a/web/src/memory.ts +++ b/web/src/memory.ts @@ -137,16 +137,6 @@ export class Memory { result.set(this.viewU8.slice(ptr, ptr + numBytes)); return result; } - /** - * Load TVMByteArray from ptr. - * - * @param ptr The address of the header. - */ - loadTVMBytes(ptr: Pointer): Uint8Array { - const data = this.loadPointer(ptr); - const length = this.loadUSize(ptr + this.sizeofPtr()); - return this.loadRawBytes(data, length); - } /** * Load null-terminated C-string from ptr. * @param ptr The head address @@ -178,7 +168,56 @@ export class Memory { } this.viewU8.set(bytes, ptr); } - + // the following functions are related to TVM FFI + /** + * Load the object type index from the object handle. + * @param objectHandle The handle of the object. + * @returns The object type index. + */ + loadObjectTypeIndex(objectHandle: Pointer): number { + return this.loadI32(objectHandle); + } + /** + * Load the type key from the type info pointer. + * @param typeInfoPtr The pointer to the type info. + * @returns The type key. + */ + loadTypeInfoTypeKey(typeInfoPtr: Pointer): string { + const typeKeyPtr = typeInfoPtr + 2 * SizeOf.I32; + return this.loadByteArrayAsString(typeKeyPtr); + } + /** + * Load bytearray as string from ptr. + * @param byteArrayPtr The head address of the bytearray. + */ + loadByteArrayAsString(byteArrayPtr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const ptr = this.loadPointer(byteArrayPtr); + const length = this.loadUSize(byteArrayPtr + this.sizeofPtr()); + // NOTE: the views are still valid for read. + const ret = []; + for (let i = 0; i < length; i++) { + ret.push(String.fromCharCode(this.viewU8[ptr + i])); + } + return ret.join(""); + } + /** + * Load bytearray as bytes from ptr. + * @param byteArrayPtr The head address of the bytearray. + */ + loadByteArrayAsBytes(byteArrayPtr: Pointer): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const ptr = this.loadPointer(byteArrayPtr); + const length = this.loadUSize(byteArrayPtr + this.sizeofPtr()); + const result = new Uint8Array(length); + result.set(this.viewU8.slice(ptr, ptr + length)); + return result; +} + // private functions /** * Update memory view after the memory growth. */ @@ -365,6 +404,21 @@ export class CachedCallStack implements Disposable { this.viewU8.set(bytes, offset); } + /** + * Allocate a byte array for a string and return the offset of the byte array. + * @param data The string to allocate. + * @returns The offset of the byte array. + */ + allocByteArrayForString(data: string): PtrOffset { + const dataUint8: Uint8Array = StringToUint8Array(data); + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(dataUint8.length); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + this.storeRawBytes(dataOffset, dataUint8); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + return headerOffset; + } /** * Allocate then set C-String pointer to the offset. * This function will call into allocBytes to allocate necessary data. diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts index 46848a6dec1c..1e3af6f6438e 100644 --- a/web/src/rpc_server.ts +++ b/web/src/rpc_server.ts @@ -17,7 +17,7 @@ * under the License. */ -import { SizeOf, ArgTypeCode } from "./ctypes"; +import { SizeOf, TypeIndex } from "./ctypes"; import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; import { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu"; import * as compact from "./compact"; @@ -228,21 +228,16 @@ export class RPCServer { // eslint-disable-next-line @typescript-eslint/no-unused-vars const ver = Uint8ArrayToString(reader.readByteArray()); const nargs = reader.readU32(); - const tcodes = []; const args = []; for (let i = 0; i < nargs; ++i) { - tcodes.push(reader.readU32()); - } - - for (let i = 0; i < nargs; ++i) { - const tcode = tcodes[i]; - if (tcode === ArgTypeCode.TVMStr) { + const typeIndex = reader.readU32(); + if (typeIndex === TypeIndex.kTVMFFIRawStr) { const str = Uint8ArrayToString(reader.readByteArray()); args.push(str); - } else if (tcode === ArgTypeCode.TVMBytes) { + } else if (typeIndex === TypeIndex.kTVMFFIByteArrayPtr) { args.push(reader.readByteArray()); } else { - throw new Error("cannot support type code " + tcode); + throw new Error("cannot support type index " + typeIndex); } } this.onInitServer(args, header, body); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 5c47c0e7a52f..47902086f588 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -20,7 +20,7 @@ /** * TVM JS Wasm Runtime library. */ -import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes"; +import { Pointer, PtrOffset, SizeOf, TypeIndex } from "./ctypes"; import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./support"; @@ -90,8 +90,8 @@ class FFILibrary implements Disposable { checkCall(code: number): void { if (code != 0) { const msgPtr = (this.exports - .TVMGetLastError as ctypes.FTVMGetLastError)(); - throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + .TVMFFIWasmGetLastError as ctypes.FTVMFFIWasmGetLastError)(); + throw new Error(this.memory.loadCString(msgPtr)); } } @@ -153,6 +153,13 @@ class FFILibrary implements Disposable { * Manages extra runtime context for the runtime. */ class RuntimeContext implements Disposable { + functionListGlobalNamesFunctor: PackedFunc; + moduleGetFunction: PackedFunc; + moduleImport: PackedFunc; + ndarrayEmpty: PackedFunc; + ndarrayCopyFromTo: PackedFunc; + ndarrayCopyFromJSBytes: PackedFunc; + ndarrayCopyToJSBytes: PackedFunc; arrayGetItem: PackedFunc; arrayGetSize: PackedFunc; arrayMake: PackedFunc; @@ -173,10 +180,21 @@ class RuntimeContext implements Disposable { applyPresenceAndFrequencyPenalty: PackedFunc; applySoftmaxWithTemperature: PackedFunc; concatEmbeddings: PackedFunc | undefined; - + bool: PackedFunc; private autoDisposeScope: Array> = []; - constructor(getGlobalFunc: (name: string) => PackedFunc) { + constructor( + getGlobalFunc: (name: string) => PackedFunc + ) { + this.functionListGlobalNamesFunctor = getGlobalFunc( + "ffi.FunctionListGlobalNamesFunctor" + ); + this.moduleGetFunction = getGlobalFunc("runtime.ModuleGetFunction"); + this.moduleImport = getGlobalFunc("runtime.ModuleImport"); + this.ndarrayEmpty = getGlobalFunc("runtime.TVMArrayAllocWithScope"); + this.ndarrayCopyFromTo = getGlobalFunc("runtime.TVMArrayCopyFromTo"); + this.ndarrayCopyFromJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyFromBytes"); + this.ndarrayCopyToJSBytes = getGlobalFunc("tvmjs.runtime.NDArrayCopyToBytes"); this.arrayGetItem = getGlobalFunc("runtime.ArrayGetItem"); this.arrayGetSize = getGlobalFunc("runtime.ArraySize"); this.arrayMake = getGlobalFunc("runtime.Array"); @@ -189,18 +207,14 @@ class RuntimeContext implements Disposable { this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name"); - this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple"); + this.makeShapeTuple = getGlobalFunc("ffi.Shape"); this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob"); this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty"); this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature"); - try { - this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); - } catch { - // TODO: remove soon. Older artifacts do not have this, try-catch for backward compatibility. - } + this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); } dispose(): void { @@ -306,35 +320,6 @@ export class Scalar { } } -/** - * Cell holds the PackedFunc object. - */ -class PackedFuncCell implements Disposable { - private handle: Pointer; - private lib: FFILibrary; - - constructor(handle: Pointer, lib: FFILibrary) { - this.handle = handle; - this.lib = lib; - } - - dispose(): void { - if (this.handle != 0) { - this.lib.checkCall( - (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) - ); - this.handle = 0; - } - } - - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("PackedFunc has already been disposed"); - } - return this.handle; - } -} - const DeviceEnumToStr: Record = { 1: "cpu", 2: "cuda", @@ -392,7 +377,7 @@ export class DLDevice { toString(): string { return ( - DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + DeviceEnumToStr[this.deviceType] + ":" + this.deviceId.toString() ); } } @@ -444,12 +429,78 @@ export class DLDataType { } } +/** + * Generic object base + */ +export class TVMObject implements Disposable { + protected handle: Pointer; + protected lib: FFILibrary; + protected ctx: RuntimeContext; + + constructor( + handle: Pointer, + lib: FFILibrary, + ctx: RuntimeContext + ) { + this.handle = handle; + this.lib = lib; + this.ctx = ctx; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get handle of module, check it is not null. + * + * @param requireNotNull require handle is not null. + * @returns The handle. + */ + getHandle(requireNotNull = true): Pointer { + if (requireNotNull && this.handle === 0) { + throw Error("Object has already been disposed"); + } + return this.handle; + } + + /** get the type index of the object */ + typeIndex(): number { + if (this.handle === 0) { + throw Error("The current Object has already been disposed"); + } + return this.lib.memory.loadObjectTypeIndex(this.handle); + } + + /** get the type key of the object */ + typeKey(): string { + const type_index = this.typeIndex(); + const typeInfoPtr = (this.lib.exports.TVMFFIGetTypeInfo as ctypes.FTVMFFIGetTypeInfo)( + type_index + ); + return this.lib.memory.loadTypeInfoTypeKey(typeInfoPtr); + } +} + +/** + * Cell holds the PackedFunc object. + */ +class PackedFuncCell extends TVMObject { + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) { + super(handle, lib, ctx); + } +} + /** * n-dimnesional array. */ -export class NDArray implements Disposable { - /** Internal array handle. */ - private handle: Pointer; + +export class NDArray extends TVMObject { /** Number of dimensions. */ ndim: number; /** Data type of the array. */ @@ -463,16 +514,14 @@ export class NDArray implements Disposable { private byteOffset: number; private dltensor: Pointer; private dataPtr: Pointer; - private lib: FFILibrary; - private ctx: RuntimeContext; private dlDataType: DLDataType; - constructor(handle: Pointer, isView: boolean, lib: FFILibrary, ctx: RuntimeContext) { - this.handle = handle; + constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext, isView: boolean) { + // if the array is a view, we need to create a new object with a null handle + // so dispose won't trigger memory free + const objectHandle = isView ? 0 : handle; + super(objectHandle, lib, ctx); this.isView = isView; - this.lib = lib; - this.ctx = ctx; - if (this.isView) { this.dltensor = handle; } else { @@ -535,20 +584,6 @@ export class NDArray implements Disposable { /*relative_byte_offset=*/ new Scalar(0, "int"), ); } - - /** - * Get handle of ndarray, check it is not null. - * - * @param requireNotNull require handle is not null. - * @returns The handle. - */ - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("NDArray has already been disposed"); - } - return this.handle; - } - /** * Get dataPtr of NDarray * @@ -561,14 +596,6 @@ export class NDArray implements Disposable { return this.dataPtr; } - dispose(): void { - if (this.handle != 0 && !this.isView) { - this.lib.checkCall( - (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) - ); - this.handle = 0; - } - } /** * Copy data from another NDArray or javascript array. * The number of elements must match. @@ -581,13 +608,7 @@ export class NDArray implements Disposable { Int32Array | Int8Array | Uint8Array | Uint8ClampedArray ): this { if (data instanceof NDArray) { - this.lib.checkCall( - (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( - data.getHandle(), - this.getHandle(), - 0 - ) - ); + this.ctx.ndarrayCopyFromTo(data, this); return this; } else { const size = this.shape.reduce((a, b) => { @@ -639,21 +660,7 @@ export class NDArray implements Disposable { if (nbytes != data.length) { throw new Error("Expect the data's length equals nbytes=" + nbytes); } - - const stack = this.lib.getOrAllocCallStack(); - - const tempOffset = stack.allocRawBytes(nbytes); - const tempPtr = stack.ptrFromOffset(tempOffset); - this.lib.memory.storeRawBytes(tempPtr, data); - this.lib.checkCall( - (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( - this.getHandle(), - tempPtr, - nbytes - ) - ); - - this.lib.recycleCallStack(stack); + this.ctx.ndarrayCopyFromJSBytes(this, data); return this; } /** @@ -664,26 +671,7 @@ export class NDArray implements Disposable { if (this.device.deviceType != DeviceStrToEnum.cpu) { throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); } - const size = this.shape.reduce((a, b) => { - return a * b; - }, 1); - - const nbytes = this.dlDataType.numStorageBytes() * size; - const stack = this.lib.getOrAllocCallStack(); - - const tempOffset = stack.allocRawBytes(nbytes); - const tempPtr = stack.ptrFromOffset(tempOffset); - this.lib.checkCall( - (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( - this.getHandle(), - tempPtr, - nbytes - ) - ); - const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); - - this.lib.recycleCallStack(stack); - return ret; + return this.ctx.ndarrayCopyToJSBytes(this) as Uint8Array; } /** @@ -709,52 +697,22 @@ export class NDArray implements Disposable { } private getDLTensorFromArrayHandle(handle: Pointer): Pointer { - // Note: this depends on the NDArray C ABI. - // keep this function in case of ABI change. - return handle; + return handle + SizeOf.ObjectHeader; } } + /** * Runtime Module. */ -export class Module implements Disposable { - private handle: Pointer; - private lib: FFILibrary; - private makePackedFunc: (ptr: Pointer) => PackedFunc; - +export class Module extends TVMObject { constructor( handle: Pointer, lib: FFILibrary, - makePackedFunc: (ptr: Pointer) => PackedFunc + ctx: RuntimeContext, ) { - this.handle = handle; - this.lib = lib; - this.makePackedFunc = makePackedFunc; - } - - dispose(): void { - if (this.handle != 0) { - this.lib.checkCall( - (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle) - ); - this.handle = 0; - } - } - - /** - * Get handle of module, check it is not null. - * - * @param requireNotNull require handle is not null. - * @returns The handle. - */ - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("Module has already been disposed"); - } - return this.handle; + super(handle, lib, ctx); } - /** * Get a function in the module. * @param name The name of the function. @@ -762,33 +720,7 @@ export class Module implements Disposable { * @returns The result function. */ getFunction(name: string, queryImports = true): PackedFunc { - if (this.handle === 0) { - throw Error("Module has already been disposed"); - } - const stack = this.lib.getOrAllocCallStack(); - const nameOffset = stack.allocRawBytes(name.length + 1); - stack.storeRawBytes(nameOffset, StringToUint8Array(name)); - - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - - stack.commitToWasmMemory(outOffset); - - this.lib.checkCall( - (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( - this.getHandle(), - stack.ptrFromOffset(nameOffset), - queryImports ? 1 : 0, - outPtr - ) - ); - const handle = this.lib.memory.loadPointer(outPtr); - this.lib.recycleCallStack(stack); - if (handle === 0) { - throw Error("Cannot find function " + name); - } - const ret = this.makePackedFunc(handle); - return ret; + return this.ctx.moduleGetFunction(this, name, queryImports) as PackedFunc; } /** @@ -796,100 +728,16 @@ export class Module implements Disposable { * @param mod The module to be imported. */ importModule(mod: Module): void { - this.lib.checkCall( - (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( - this.getHandle(), - mod.getHandle() - ) - ); + this.ctx.moduleImport(this, mod); } } -/** - * Generic object base - */ -export class TVMObject implements Disposable { - private handle: Pointer; - private lib: FFILibrary; - protected ctx: RuntimeContext; - - constructor( - handle: Pointer, - lib: FFILibrary, - ctx: RuntimeContext - ) { - this.handle = handle; - this.lib = lib; - this.ctx = ctx; - } - - dispose(): void { - if (this.handle != 0) { - this.lib.checkCall( - (this.lib.exports.TVMObjectFree as ctypes.FTVMObjectFree)(this.handle) - ); - this.handle = 0; - } - } - - /** - * Get handle of module, check it is not null. - * - * @param requireNotNull require handle is not null. - * @returns The handle. - */ - getHandle(requireNotNull = true): Pointer { - if (requireNotNull && this.handle === 0) { - throw Error("Module has already been disposed"); - } - return this.handle; - } - - /** get the type index of the object */ - typeIndex(): number { - if (this.handle === 0) { - throw Error("The current Object has already been disposed"); - } - const stack = this.lib.getOrAllocCallStack(); - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - - this.lib.checkCall( - (this.lib.exports.TVMObjectGetTypeIndex as ctypes.FTVMObjectGetTypeIndex)( - this.getHandle(), - outPtr - ) - ); - const result = this.lib.memory.loadU32(outPtr); - this.lib.recycleCallStack(stack); - return result; - } - - /** get the type key of the object */ - typeKey(): string { - const type_index = this.typeIndex(); - const stack = this.lib.getOrAllocCallStack(); - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - this.lib.checkCall( - (this.lib.exports.TVMObjectTypeIndex2Key as ctypes.FTVMObjectTypeIndex2Key)( - type_index, - outPtr - ) - ); - const result = this.lib.memory.loadCString( - this.lib.memory.loadPointer(outPtr) - ); - this.lib.recycleCallStack(stack); - return result; - } -} /** Objectconstructor */ type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) => TVMObject; /** All possible object types. */ -type TVMObjectBase = TVMObject | NDArray | Module | PackedFunc; +type TVMObjectBase = TVMObject | PackedFunc; /** Runtime array object. */ export class TVMArray extends TVMObject { @@ -1212,38 +1060,16 @@ export class Instance implements Disposable { * @returns The name list. */ listGlobalFuncNames(): Array { - const stack = this.lib.getOrAllocCallStack(); - - const outSizeOffset = stack.allocPtrArray(2); - - const outSizePtr = stack.ptrFromOffset(outSizeOffset); - const outArrayPtr = stack.ptrFromOffset( - outSizeOffset + this.lib.sizeofPtr() - ); - - this.lib.checkCall( - (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)( - outSizePtr, - outArrayPtr - ) - ); - - const size = this.memory.loadI32(outSizePtr); - const array = this.memory.loadPointer(outArrayPtr); - const names: Array = []; - - for (let i = 0; i < size; ++i) { - names.push( - this.memory.loadCString( - this.memory.loadPointer(array + this.lib.sizeofPtr() * i) - ) - ); - } - - this.lib.recycleCallStack(stack); - return names; + return this.withNewScope(() => { + const functor = this.ctx.functionListGlobalNamesFunctor(); + const numNames = functor(new Scalar(-1, "int")) as number; + const names = new Array(numNames); + for (let i = 0; i < numNames; i++) { + names[i] = functor(new Scalar(i, "int")) as string; + } + return names; + }); } - /** * Register function to be global function in tvm runtime. * @param name The name of the function. @@ -1262,12 +1088,10 @@ export class Instance implements Disposable { const ioverride = override ? 1 : 0; const stack = this.lib.getOrAllocCallStack(); - const nameOffset = stack.allocRawBytes(name.length + 1); - stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const nameOffset = stack.allocByteArrayForString(name); stack.commitToWasmMemory(); - this.lib.checkCall( - (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + (this.lib.exports.TVMFFIFunctionSetGlobal as ctypes.FTVMFFIFunctionSetGlobal)( stack.ptrFromOffset(nameOffset), packedFunc._tvmPackedCell.getHandle(), ioverride @@ -1289,15 +1113,14 @@ export class Instance implements Disposable { private getGlobalFuncInternal(name: string, autoAttachToScope = true): PackedFunc { const stack = this.lib.getOrAllocCallStack(); - const nameOffset = stack.allocRawBytes(name.length + 1); - stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const nameOffset = stack.allocByteArrayForString(name); const outOffset = stack.allocPtrArray(1); const outPtr = stack.ptrFromOffset(outOffset); stack.commitToWasmMemory(outOffset); this.lib.checkCall( - (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)( + (this.exports.TVMFFIFunctionGetGlobal as ctypes.FTVMFFIFunctionGetGlobal)( stack.ptrFromOffset(nameOffset), outPtr ) @@ -1335,7 +1158,7 @@ export class Instance implements Disposable { private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): PackedFunc { if (this.isPackedFunc(func)) return func as PackedFunc; - const ret = this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + const ret = this.createPackedFuncFromSafeCallType(this.wrapJSFuncAsSafeCallType(func)); if (autoAttachToScope) return this.ctx.attachToCurrentScope(ret); return ret; } @@ -1603,52 +1426,6 @@ export class Instance implements Disposable { } } - /** - * Convert dtype to {@link DLDataType} - * - * @param dtype The input dtype string or DLDataType. - * @returns The converted result. - */ - toDLDataType(dtype: string | DLDataType): DLDataType { - if (dtype instanceof DLDataType) return dtype; - if (typeof dtype === "string") { - let pattern = dtype; - let code, - bits = 32, - lanes = 1; - if (pattern.substring(0, 5) === "float") { - pattern = pattern.substring(5, pattern.length); - code = DLDataTypeCode.Float; - } else if (pattern.substring(0, 3) === "int") { - pattern = pattern.substring(3, pattern.length); - code = DLDataTypeCode.Int; - } else if (pattern.substring(0, 4) === "uint") { - pattern = pattern.substring(4, pattern.length); - code = DLDataTypeCode.UInt; - } else if (pattern.substring(0, 6) === "handle") { - pattern = pattern.substring(5, pattern.length); - code = DLDataTypeCode.OpaqueHandle; - bits = 64; - } else { - throw new Error("Unknown dtype " + dtype); - } - - const arr = pattern.split("x"); - if (arr.length >= 1) { - const parsed = parseInt(arr[0]); - if (parsed + "" === arr[0]) { - bits = parsed; - } - } - if (arr.length >= 2) { - lanes = parseInt(arr[1]); - } - return new DLDataType(code, bits, lanes); - } else { - throw new Error("Unknown dtype " + dtype); - } - } - /** * Create a new {@link Scalar} that can be passed to a PackedFunc. * @param value The number value. @@ -1698,36 +1475,8 @@ export class Instance implements Disposable { dtype: string | DLDataType = "float32", dev: DLDevice = this.device("cpu", 0) ): NDArray { - dtype = this.toDLDataType(dtype); shape = typeof shape === "number" ? [shape] : shape; - - const stack = this.lib.getOrAllocCallStack(); - const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); - for (let i = 0; i < shape.length; ++i) { - stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); - } - - const outOffset = stack.allocPtrArray(1); - const outPtr = stack.ptrFromOffset(outOffset); - stack.commitToWasmMemory(outOffset); - - this.lib.checkCall( - (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)( - stack.ptrFromOffset(shapeOffset), - shape.length, - dtype.code, - dtype.bits, - dtype.lanes, - dev.deviceType, - dev.deviceId, - outPtr - ) - ); - const ret = this.ctx.attachToCurrentScope( - new NDArray(this.memory.loadPointer(outPtr), false, this.lib, this.ctx) - ); - this.lib.recycleCallStack(stack); - return ret; + return this.ctx.ndarrayEmpty(this.makeShapeTuple(shape), dtype, dev, null); } /** @@ -1936,15 +1685,13 @@ export class Instance implements Disposable { typeKey: string ): number { const stack = this.lib.getOrAllocCallStack(); - const typeKeyOffset = stack.allocRawBytes(typeKey.length + 1); - stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey)); + const typeKeyOffset = stack.allocByteArrayForString(typeKey); const outOffset = stack.allocPtrArray(1); const outPtr = stack.ptrFromOffset(outOffset); stack.commitToWasmMemory(outOffset); - this.lib.checkCall( - (this.lib.exports.TVMObjectTypeKey2Index as ctypes.FTVMObjectTypeKey2Index)( + (this.lib.exports.TVMFFITypeKeyToIndex as ctypes.FTVMFFITypeKeyToIndex)( stack.ptrFromOffset(typeKeyOffset), outPtr ) @@ -2153,6 +1900,10 @@ export class Instance implements Disposable { (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new TVMArray(handle, lib, ctx); }); + this.registerObjectConstructor("runtime.Module", + (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { + return new Module(handle, lib, ctx); + }); } /** Register global packed functions needed by the backend to the env. */ @@ -2224,8 +1975,8 @@ export class Instance implements Disposable { this.registerAsyncServerFunc("testing.asyncAddOne", addOne); } - private createPackedFuncFromCFunc( - func: ctypes.FTVMWasmPackedCFunc + private createPackedFuncFromSafeCallType( + func: ctypes.FTVMFFIWasmSafeCallType ): PackedFunc { let findex = this.env.packedCFuncTable.length; if (this.env.packedCFuncTableFreeId.length != 0) { @@ -2240,7 +1991,7 @@ export class Instance implements Disposable { const outPtr = stack.ptrFromOffset(outOffset); this.lib.checkCall( (this.exports - .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)( + .TVMFFIWasmFunctionCreate as ctypes.FTVMFFIWasmFunctionCreate)( findex, outPtr ) @@ -2256,20 +2007,19 @@ export class Instance implements Disposable { * * @parma stack The call stack * @param args The input arguments. - * @param argsValue The offset of argsValue. - * @param argsCode The offset of argsCode. + * @param packedArgs The offset of packedArgs. */ setPackedArguments( stack: CachedCallStack, args: Array, - argsValue: PtrOffset, - argsCode: PtrOffset + packedArgs: PtrOffset, ): void { for (let i = 0; i < args.length; ++i) { let val = args[i]; const tp = typeof val; - const valueOffset = argsValue + i * SizeOf.TVMValue; - const codeOffset = argsCode + i * SizeOf.I32; + const argOffset = packedArgs + i * SizeOf.TVMFFIAny; + const argTypeIndexOffset = argOffset; + const argValueOffset = argOffset + SizeOf.I32 * 2; // Convert string[] to a TVMArray of, hence treated as a TVMObject if (val instanceof Array && val.every(e => typeof e === "string")) { @@ -2278,97 +2028,100 @@ export class Instance implements Disposable { val = this.makeTVMArray(tvmStringArray); } + // clear off the extra padding valuesbefore ptr storage + stack.storeI32(argTypeIndexOffset + SizeOf.I32, 0); + stack.storeI32(argValueOffset + SizeOf.I32, 0); if (val instanceof NDArray) { if (!val.isView) { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINDArray); + stack.storePtr(argValueOffset, val.getHandle()); } else { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMDLTensorHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDLTensorPtr); + stack.storePtr(argValueOffset, val.getHandle()); } } else if (val instanceof Scalar) { if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { - stack.storeI64(valueOffset, val.value); - stack.storeI32(codeOffset, ArgTypeCode.Int); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIInt); + stack.storeI64(argValueOffset, val.value); } else if (val.dtype.startsWith("float")) { - stack.storeF64(valueOffset, val.value); - stack.storeI32(codeOffset, ArgTypeCode.Float); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat); + stack.storeF64(argValueOffset, val.value); } else { assert(val.dtype === "handle", "Expect handle"); - stack.storePtr(valueOffset, val.value); - stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIOpaquePtr); + stack.storePtr(argValueOffset, val.value); } } else if (val instanceof DLDevice) { - stack.storeI32(valueOffset, val.deviceType); - stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); - stack.storeI32(codeOffset, ArgTypeCode.DLDevice); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDevice); + stack.storeI32(argValueOffset, val.deviceType); + stack.storeI32(argValueOffset + SizeOf.I32, val.deviceId); + } else if (tp === "boolean") { + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIBool); + stack.storeI64(argValueOffset, val ? 1 : 0); } else if (tp === "number") { - stack.storeF64(valueOffset, val); - stack.storeI32(codeOffset, ArgTypeCode.Float); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat); + stack.storeF64(argValueOffset, val); // eslint-disable-next-line no-prototype-builtins } else if (tp === "function" && val.hasOwnProperty("_tvmPackedCell")) { - stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle()); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction); } else if (val === null || val === undefined) { - stack.storePtr(valueOffset, 0); - stack.storeI32(codeOffset, ArgTypeCode.Null); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINone); + stack.storePtr(argValueOffset, 0); } else if (tp === "string") { - stack.allocThenSetArgString(valueOffset, val); - stack.storeI32(codeOffset, ArgTypeCode.TVMStr); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIRawStr); + stack.allocThenSetArgString(argValueOffset, val); } else if (val instanceof Uint8Array) { - stack.allocThenSetArgBytes(valueOffset, val); - stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIByteArrayPtr); + stack.allocThenSetArgBytes(argValueOffset, val); } else if (val instanceof Function) { val = this.toPackedFuncInternal(val, false); stack.tempArgs.push(val); - stack.storePtr(valueOffset, val._tvmPackedCell.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction); + stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle()); } else if (val instanceof Module) { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); + stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIModule); + stack.storePtr(argValueOffset, val.getHandle()); } else if (val instanceof TVMObject) { - stack.storePtr(valueOffset, val.getHandle()); - stack.storeI32(codeOffset, ArgTypeCode.TVMObjectHandle); + stack.storeI32(argTypeIndexOffset, val.typeIndex()); + stack.storePtr(argValueOffset, val.getHandle()); } else { - throw new Error("Unsupported argument type " + tp); + throw new Error("Unsupported argument type " + tp + " value=`" + val.toString() + "`"); } } } - private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc { + private wrapJSFuncAsSafeCallType(func: Function): ctypes.FTVMFFIWasmSafeCallType { const lib = this.lib; return ( - argValues: Pointer, - argCodes: Pointer, - nargs: number, - ret: Pointer, // eslint-disable-next-line @typescript-eslint/no-unused-vars - _handle: Pointer + self: Pointer, + packedArgs: Pointer, + numArgs: number, + ret: Pointer ): number => { const jsArgs = []; // use scope to track js values. this.ctx.beginScope(); - for (let i = 0; i < nargs; ++i) { - const valuePtr = argValues + i * SizeOf.TVMValue; - const codePtr = argCodes + i * SizeOf.I32; - let tcode = lib.memory.loadI32(codePtr); - - if ( - tcode === ArgTypeCode.TVMObjectHandle || - tcode === ArgTypeCode.TVMObjectRValueRefArg || - tcode === ArgTypeCode.TVMPackedFuncHandle || - tcode === ArgTypeCode.TVMNDArrayHandle || - tcode === ArgTypeCode.TVMModuleHandle - ) { + for (let i = 0; i < numArgs; ++i) { + const argPtr = packedArgs + i * SizeOf.TVMFFIAny; + const typeIndex = lib.memory.loadI32(argPtr); + + if (typeIndex >= TypeIndex.kTVMFFIRawStr) { + // NOTE: the following code have limitations in asyncify mode. + // The reason is that the TVMFFIAnyViewToOwnedAny will simply + // get skipped during the rewinding process, causing memory failure + if (!this.asyncifyHandler.isNormalStackState()) { + throw Error("Cannot handle str/object argument callback in asyncify mode"); + } lib.checkCall( - (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( - valuePtr, - codePtr + (lib.exports.TVMFFIAnyViewToOwnedAny as ctypes.FTVMFFIAnyViewToOwnedAny)( + argPtr, + argPtr ) ); } - tcode = lib.memory.loadI32(codePtr); - jsArgs.push(this.retValueToJS(valuePtr, tcode, true)); + jsArgs.push(this.retValueToJS(argPtr, true)); } let rv: any; @@ -2378,12 +2131,16 @@ export class Instance implements Disposable { // error handling // store error via SetLastError this.ctx.endScope(); - const errMsg = "JSCallbackError: " + error.message; + const errKind = "JSCallbackError" + const errMsg = error.message; const stack = lib.getOrAllocCallStack(); + const errKindOffset = stack.allocRawBytes(errKind.length + 1); + stack.storeRawBytes(errKindOffset, StringToUint8Array(errKind)); const errMsgOffset = stack.allocRawBytes(errMsg.length + 1); stack.storeRawBytes(errMsgOffset, StringToUint8Array(errMsg)); stack.commitToWasmMemory(); - (this.lib.exports.TVMAPISetLastError as ctypes.FTVMAPISetLastError)( + (this.lib.exports.FTVMFFIErrorSetRaisedByCStr as ctypes.FTVMFFIErrorSetRaisedByCStr)( + stack.ptrFromOffset(errKindOffset), stack.ptrFromOffset(errMsgOffset) ); this.lib.recycleCallStack(stack); @@ -2395,18 +2152,14 @@ export class Instance implements Disposable { this.ctx.endScope(); if (rv !== undefined && rv !== null) { const stack = lib.getOrAllocCallStack(); - const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); - const codeOffset = stack.allocRawBytes(SizeOf.I32); - this.setPackedArguments(stack, [rv], valueOffset, codeOffset); - const valuePtr = stack.ptrFromOffset(valueOffset); - const codePtr = stack.ptrFromOffset(codeOffset); + const argOffset = stack.allocRawBytes(SizeOf.TVMFFIAny); + this.setPackedArguments(stack, [rv], argOffset); stack.commitToWasmMemory(); + const argPtr = stack.ptrFromOffset(argOffset); lib.checkCall( - (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)( - ret, - valuePtr, - codePtr, - 1 + (lib.exports.TVMFFIAnyViewToOwnedAny as ctypes.FTVMFFIAnyViewToOwnedAny)( + argPtr, + ret ) ); lib.recycleCallStack(stack); @@ -2416,38 +2169,25 @@ export class Instance implements Disposable { } private makePackedFunc(handle: Pointer): PackedFunc { - const cell = new PackedFuncCell(handle, this.lib); - + const cell = new PackedFuncCell(handle, this.lib, this.ctx); const packedFunc = (...args: any): any => { const stack = this.lib.getOrAllocCallStack(); - - const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); - const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); - - this.setPackedArguments(stack, args, valueOffset, tcodeOffset); - - const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); - const rcodeOffset = stack.allocRawBytes(SizeOf.I32); - const rvaluePtr = stack.ptrFromOffset(rvalueOffset); - const rcodePtr = stack.ptrFromOffset(rcodeOffset); - - // pre-store the rcode to be null, in case caller unwind - // and not have chance to reset this rcode. - stack.storeI32(rcodeOffset, ArgTypeCode.Null); + const argsOffset = stack.allocRawBytes(SizeOf.TVMFFIAny * args.length); + this.setPackedArguments(stack, args, argsOffset); + const retOffset = stack.allocRawBytes(SizeOf.TVMFFIAny); + // pre-store the result to be null + stack.storeI32(retOffset, TypeIndex.kTVMFFINone); stack.commitToWasmMemory(); - this.lib.checkCall( - (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( + (this.exports.TVMFFIFunctionCall as ctypes.FTVMFFIFunctionCall)( cell.getHandle(), - stack.ptrFromOffset(valueOffset), - stack.ptrFromOffset(tcodeOffset), + stack.ptrFromOffset(argsOffset), args.length, - rvaluePtr, - rcodePtr + stack.ptrFromOffset(retOffset) ) ); - const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr), false); + const ret = this.retValueToJS(stack.ptrFromOffset(retOffset), false); this.lib.recycleCallStack(stack); return ret; }; @@ -2463,78 +2203,91 @@ export class Instance implements Disposable { /** * Creaye return value of the packed func. The value us auto-tracked for dispose. - * @param rvaluePtr The location of rvalue - * @param tcode The type code. + * @param resultAnyPtr The location of rvalue * @param callbackArg Whether it is being used in callbackArg. * @returns The JS value. */ - private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { - switch (tcode) { - case ArgTypeCode.Int: - case ArgTypeCode.UInt: - case ArgTypeCode.TVMArgBool: - return this.memory.loadI64(rvaluePtr); - case ArgTypeCode.Float: - return this.memory.loadF64(rvaluePtr); - case ArgTypeCode.TVMOpaqueHandle: { - return this.memory.loadPointer(rvaluePtr); + private retValueToJS(resultAnyPtr: Pointer, callbackArg: boolean): any { + const typeIndex = this.memory.loadI32(resultAnyPtr); + const valuePtr = resultAnyPtr + SizeOf.I32 * 2; + switch (typeIndex) { + case TypeIndex.kTVMFFINone: return undefined; + case TypeIndex.kTVMFFIBool: + return this.memory.loadI64(valuePtr) != 0; + case TypeIndex.kTVMFFIInt: + return this.memory.loadI64(valuePtr); + case TypeIndex.kTVMFFIFloat: + return this.memory.loadF64(valuePtr); + case TypeIndex.kTVMFFIOpaquePtr: { + return this.memory.loadPointer(valuePtr); } - case ArgTypeCode.TVMNDArrayHandle: { + case TypeIndex.kTVMFFINDArray: { return this.ctx.attachToCurrentScope( - new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, this.ctx) + new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false) ); } - case ArgTypeCode.TVMDLTensorHandle: { + case TypeIndex.kTVMFFIDLTensorPtr: { assert(callbackArg); // no need to attach as we are only looking at view - return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib, this.ctx); + return new NDArray(this.memory.loadPointer(valuePtr), this.lib, this.ctx, true); } - case ArgTypeCode.TVMPackedFuncHandle: { + case TypeIndex.kTVMFFIFunction: { return this.ctx.attachToCurrentScope( - this.makePackedFunc(this.memory.loadPointer(rvaluePtr)) + this.makePackedFunc(this.memory.loadPointer(valuePtr)) ); } - case ArgTypeCode.TVMModuleHandle: { - return this.ctx.attachToCurrentScope( - new Module( - this.memory.loadPointer(rvaluePtr), - this.lib, - (ptr: Pointer) => { - return this.ctx.attachToCurrentScope(this.makePackedFunc(ptr)); - } - ) + case TypeIndex.kTVMFFIDevice: { + const deviceType = this.memory.loadI32(valuePtr); + const deviceId = this.memory.loadI32(valuePtr + SizeOf.I32); + return this.device(deviceType, deviceId); + } + case TypeIndex.kTVMFFIDataType: { + // simply return dtype as tring to keep things simple + this.lib.checkCall( + (this.lib.exports.TVMFFIDataTypeToString as ctypes.FTVMFFIDataTypeToString)(valuePtr, valuePtr) + ); + const strObjPtr = this.memory.loadPointer(valuePtr); + const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) + ); + return result; + } + case TypeIndex.kTVMFFIStr: { + const strObjPtr = this.memory.loadPointer(valuePtr); + const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(strObjPtr) ); + return result; } - case ArgTypeCode.TVMObjectHandle: { - const obj = new TVMObject( - this.memory.loadPointer(rvaluePtr), - this.lib, - this.ctx + case TypeIndex.kTVMFFIBytes: { + const bytesObjPtr = this.memory.loadPointer(valuePtr); + const result = this.memory.loadByteArrayAsBytes(bytesObjPtr + SizeOf.ObjectHeader); + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectFree as ctypes.FTVMFFIObjectFree)(bytesObjPtr) ); - const func = this.objFactory.get(obj.typeIndex()) - if (func != undefined) { - return this.ctx.attachToCurrentScope( - func(obj.getHandle(), this.lib, this.ctx) + return result; + } + default: { + if (typeIndex >= TypeIndex.kTVMFFIStaticObjectBegin) { + const obj = new TVMObject( + this.memory.loadPointer(valuePtr), + this.lib, + this.ctx ); + const func = this.objFactory.get(obj.typeIndex()) + if (func != undefined) { + return this.ctx.attachToCurrentScope( + func(obj.getHandle(), this.lib, this.ctx) + ); + } else { + return this.ctx.attachToCurrentScope(obj); + } } else { - return this.ctx.attachToCurrentScope(obj); + throw new Error("Unsupported return type code=" + typeIndex); } } - case ArgTypeCode.Null: return undefined; - case ArgTypeCode.DLDevice: { - const deviceType = this.memory.loadI32(rvaluePtr); - const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); - return this.device(deviceType, deviceId); - } - case ArgTypeCode.TVMStr: { - const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); - return ret; - } - case ArgTypeCode.TVMBytes: { - return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); - } - default: - throw new Error("Unsupported return type code=" + tcode); } } } diff --git a/web/tests/node/test_ndarray.js b/web/tests/node/test_ndarray.js index 8d369216d2d8..495d05070147 100644 --- a/web/tests/node/test_ndarray.js +++ b/web/tests/node/test_ndarray.js @@ -38,7 +38,7 @@ function testArrayCopy(dtype, arrayType) { let data = [1, 2, 3, 4, 5, 6]; let a = tvm.empty([2, 3], dtype).copyFrom(data); - assert(a.device.toString() == "cpu(0)"); + assert(a.device.toString() == "cpu:0"); assert(a.shape[0] == 2 && a.shape[1] == 3); let ret = a.toArray(); diff --git a/web/tests/node/test_object.js b/web/tests/node/test_object.js index 2423ef4ceb46..3db3bd9c8431 100644 --- a/web/tests/node/test_object.js +++ b/web/tests/node/test_object.js @@ -42,10 +42,5 @@ test("object", () => { let t1 = b.get(1); assert(t1.getHandle() == t.getHandle()); - - let ret_string = tvm.getGlobalFunc("testing.ret_string"); - let s1 = ret_string("hello"); - assert(s1 == "hello"); - ret_string.dispose(); }); }); diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js index e1d070f0e473..e2b6c7b7c9b3 100644 --- a/web/tests/node/test_packed_func.js +++ b/web/tests/node/test_packed_func.js @@ -37,7 +37,7 @@ let tvm = new tvmjs.Instance( test("GetGlobal", () => { tvm.beginScope(); let flist = tvm.listGlobalFuncNames(); - let faddOne = tvm.getGlobalFunc("testing.add_one"); + let faddOne = tvm.getGlobalFunc("tvmjs.testing.add_one"); let fecho = tvm.getGlobalFunc("testing.echo"); assert(faddOne(tvm.scalar(1, "int")) == 2); @@ -146,31 +146,6 @@ test("ExceptionPassing", () => { tvm.endScope(); }); - -test("AsyncifyFunc", async () => { - if (!tvm.asyncifyEnabled()) { - console.log("Skip asyncify tests as it is not enabled.."); - return; - } - tvm.beginScope(); - tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) { - await new Promise(resolve => setTimeout(resolve, 10)); - return x; - }); - let fecho = tvm.wrapAsyncifyPackedFunc( - tvm.getGlobalFunc("async_sleep_echo") - ); - let fcall = tvm.wrapAsyncifyPackedFunc( - tvm.getGlobalFunc("testing.call") - ); - assert((await fecho(1)) == 1); - assert((await fecho(2)) == 2); - assert((await fcall(fecho, 2) == 2)); - tvm.endScope(); - assert(fecho._tvmPackedCell.getHandle(false) == 0); - assert(fcall._tvmPackedCell.getHandle(false) == 0); -}); - test("NDArrayCbArg", () => { tvm.beginScope(); let use_count = tvm.getGlobalFunc("testing.object_use_count"); @@ -204,8 +179,32 @@ test("NDArrayCbArg", () => { test("Logging", () => { tvm.beginScope(); - const log_info = tvm.getGlobalFunc("testing.log_info_str"); + const log_info = tvm.getGlobalFunc("tvmjs.testing.log_info_str"); log_info("helow world") log_info.dispose(); tvm.endScope(); }); + +test("AsyncifyFunc", async () => { + if (!tvm.asyncifyEnabled()) { + console.log("Skip asyncify tests as it is not enabled.."); + return; + } + tvm.beginScope(); + tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) { + await new Promise(resolve => setTimeout(resolve, 10)); + return x; + }); + let fecho = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("async_sleep_echo") + ); + let fcall = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("tvmjs.testing.call") + ); + assert((await fecho(1)) == 1); + assert((await fecho(2)) == 2); + assert((await fcall(fecho, 2) == 2)); + tvm.endScope(); + assert(fecho._tvmPackedCell.getHandle(false) == 0); + assert(fcall._tvmPackedCell.getHandle(false) == 0); +}); diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py index e831afd9d3f8..8925da00a489 100644 --- a/web/tests/python/webgpu_rpc_test.py +++ b/web/tests/python/webgpu_rpc_test.py @@ -35,7 +35,6 @@ def test_rpc(): return # generate the wasm library target = tvm.target.Target("webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm") - runtime = Runtime("cpp", {"system-lib": True}) n = te.var("n") A = te.placeholder((n,), name="A") From 8aecab87ef3e816dfd3d74713c78c40f1834a6dd Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Wed, 4 Jun 2025 12:17:37 +0530 Subject: [PATCH 13/13] Minor fix --- src/runtime/metal/metal_module.mm | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 6eac3280639f..c1a5ccfdd1a3 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -267,11 +267,7 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) auto it = fmap_.find(name); if (it == fmap_.end()) { ret = ffi::Function(); -<<<<<<< HEAD - return ret; -======= return; ->>>>>>> main } const FunctionInfo& info = it->second; MetalWrappedFunc f;