From 8b5a81d834da3cd913c3bd7f040e88dee80079f7 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 13:22:25 -0400 Subject: [PATCH 01/47] trying to understand why batchnorm returns all zeros --- .../torch/exported_program_translator.py | 24 ++++++++++---- python/tvm/relax/op/nn/nn.py | 14 ++++++++ .../relax/test_from_exported_to_cuda.py | 33 +++++++++++++++++++ 3 files changed, 65 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index e8e870671402..bdd028134b7d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -57,16 +57,27 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + + print("calling batch_norm with the following parameters:") + print("x:", x) + print("weight:", weight) + print("bias:", bias) + print("running_mean:", running_mean) + print("running_var:", running_var) + print("momentum:", momentum) + print("eps:", eps) return self.block_builder.emit( relax.op.nn.batch_norm( - x, - weight, - bias, - running_mean, - running_var, - axis=1, + data=x, + gamma=weight, + beta=bias, + moving_mean=running_mean, + moving_var=running_var, + axis=1, # Always over channel epsilon=eps, + center=False, # TODO + scale=False, # TODO momentum=momentum, ) ) @@ -235,6 +246,7 @@ def create_convert_map( "linalg_vector_norm.default": self._linalg_vector_norm, # neural network "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, + "batch_norm.default": self._batch_norm_legit_no_training, # TODO keep or not? "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 5a1895cbc14f..367970b1bf76 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1486,6 +1486,20 @@ def batch_norm( result : relax.Expr The computed result. """ + print("\n!!The parameters passed to _ffi_api.batch_norm are: !!!!!!!!!") + print("data: ", data) + print("gamma: ", gamma) + print("beta: ", beta) + print("moving_mean: ", moving_mean) + print(dir(moving_mean)) # TODO find a way to print args + print("moving_mean args 0: ", moving_mean.args[0]) + print("moving_var: ", moving_var) + print("moving_var args 0: ", moving_var.args[0]) + print("axis: ", axis) + print("epsilon: ", epsilon) + print("center: ", center) + print("scale: ", scale) + print("momentum: ", momentum) return _ffi_api.batch_norm( # type: ignore data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum ) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index d39bb8e9fea3..18e48b06e67f 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,8 +15,14 @@ # specific language governing permissions and limitations # under the License. 
+# TODO remove +import sys +sys.path.append('/ssd1/htalendr/tvm/python') + + import numpy as np import torch +from torch import nn from torch.export import export import tvm @@ -88,5 +94,32 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) +@tvm.testing.parametrize_targets("cuda") + +def test_batch_norm(target, dev): + + # TODO no momentum + # raw_data = np.random.randn(1,2,1,1).astype(np.float32) + raw_data = np.array([[[[10.0]],[[20.0]]]]).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2, eps=1e-02, momentum=0.0, + affine=False, track_running_stats=True, + device=None, dtype=None).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + # TODO correct output above should be [9.95, 19.9] https://chatgpt.com/c/67cf1bc1-1934-8006-9b22-8166c46ee1bc + + # TODO with momentum + # raw_data = np.random.randn(1,4,2,2).astype(np.float32) + # torch_module0 = nn.BatchNorm2d(4, eps=1e-05, momentum=0.0, + # affine=False, track_running_stats=True, + # device=None, dtype=None).eval() + # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + + # TODO uncomment this one with simpler arguments? + # raw_data = np.random.randn(4,2,2,2).astype(np.float32) + # torch_module0 = nn.BatchNorm2d(2).eval() + # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + if __name__ == "__main__": tvm.testing.main() From 99373ae84152c8e2b5bc10c1ec48b68c8fec4dea Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Mon, 10 Mar 2025 14:56:12 -0400 Subject: [PATCH 02/47] debugging training vs non-training batch norm --- .../torch/exported_program_translator.py | 1 + python/tvm/relax/op/nn/nn.py | 18 ++++++++++-------- python/tvm/relax/transform/legalize_ops/nn.py | 2 +- python/tvm/topi/nn/batch_norm.py | 12 ++++++++++++ src/relax/op/nn/nn.cc | 1 + .../python/relax/test_from_exported_to_cuda.py | 10 ++++------ 6 files changed, 29 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index bdd028134b7d..4c25765f906b 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -247,6 +247,7 @@ def create_convert_map( # neural network "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "batch_norm.default": self._batch_norm_legit_no_training, # TODO keep or not? + "_native_batch_norm_legit_functional.default": self._batch_norm_legit_no_training, # when I don't do eval . TODO doesn't work right now! 
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 367970b1bf76..86e39903bff8 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1492,14 +1492,16 @@ def batch_norm( print("beta: ", beta) print("moving_mean: ", moving_mean) print(dir(moving_mean)) # TODO find a way to print args - print("moving_mean args 0: ", moving_mean.args[0]) - print("moving_var: ", moving_var) - print("moving_var args 0: ", moving_var.args[0]) - print("axis: ", axis) - print("epsilon: ", epsilon) - print("center: ", center) - print("scale: ", scale) - print("momentum: ", momentum) + # moving_mean.show() + # print("moving_mean handle type: ", moving_mean.handle) + # print("moving_mean handle: ", type(moving_mean.handle)) + # print("moving_var: ", moving_var) + # print("moving_var args 0: ", moving_var.args[0]) + # print("axis: ", axis) + # print("epsilon: ", epsilon) + # print("center: ", center) + # print("scale: ", scale) + # print("momentum: ", momentum) return _ffi_api.batch_norm( # type: ignore data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum ) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index d9fb4701f7e9..8389a99b919a 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -553,7 +553,7 @@ def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: scale=call.attrs.scale, # By default relax batch_norm is training mode. # To transform it to inference mode, use DecomposeOpsForInference. - training=True, + training=False, # TODO this is wrong!! 
momentum=call.attrs.momentum, ) diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py index 3181efd7daa6..fe0f77c63af4 100644 --- a/python/tvm/topi/nn/batch_norm.py +++ b/python/tvm/topi/nn/batch_norm.py @@ -111,6 +111,9 @@ def batch_norm( shape = [1] * len(data.shape) shape[axis] = data.shape[axis] + print("IN BATCH_NORM !!!!!!!!!!!!!!!!!!!!!") + print("training: ", training) + if training: reduce_axes = list(range(len(data.shape))) reduce_axes.remove(axis) @@ -125,6 +128,15 @@ def batch_norm( else: moving_mean_rs = topi.reshape(moving_mean, shape) moving_var_rs = topi.reshape(moving_var, shape) + + print("CALCULATING OUT FOR BATCH_NORM !!!!!!!!!!!!!!!!!!!!!") + print("data shape: ", data.shape) + print("data: ", data) + print("moving_mean shape ", moving_mean.shape) + print("moving_mean: ", moving_mean) + print("moving_var shape: ", moving_var.shape) + print("moving_var: ", moving_var) + out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon) if scale: diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b4668d65d399..b8f9729784b1 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -253,6 +253,7 @@ TVM_REGISTER_NODE_TYPE(BatchNormAttrs); Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale, double momentum) { + ObjectPtr attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 18e48b06e67f..b35e3b59e4b1 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -48,8 +48,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) - # TODO try pipeline below? 
- # releax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target) ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) vm = relax.VirtualMachine(ex, dev) @@ -99,11 +97,11 @@ def forward(self, x): def test_batch_norm(target, dev): # TODO no momentum - # raw_data = np.random.randn(1,2,1,1).astype(np.float32) - raw_data = np.array([[[[10.0]],[[20.0]]]]).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2, eps=1e-02, momentum=0.0, + raw_data = np.random.randn(8,8,4,4).astype(np.float32) + # raw_data = np.array([[[[10.0]],[[20.0]]]]).astype(np.float32) + torch_module0 = nn.BatchNorm2d(8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, - device=None, dtype=None).eval() + device=None, dtype=None) assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # TODO correct output above should be [9.95, 19.9] https://chatgpt.com/c/67cf1bc1-1934-8006-9b22-8166c46ee1bc From b0e11541e3aa9446b0f18f819c64913e7450d143 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 15 Mar 2025 22:29:02 -0400 Subject: [PATCH 03/47] added training in attrs --- include/tvm/relax/attrs/nn.h | 2 ++ .../tvm/relax/frontend/torch/exported_program_translator.py | 2 ++ python/tvm/relax/transform/legalize_ops/nn.py | 4 +--- src/relax/op/nn/nn.cc | 4 +++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 832934417484..4ccb1b4a9936 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -462,6 +462,7 @@ struct BatchNormAttrs : public tvm::AttrsNode { bool center; bool scale; double momentum; + bool training; TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") { TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied."); @@ -470,6 +471,7 @@ struct BatchNormAttrs : public tvm::AttrsNode { "Indicating if the beta offset will be added to the normalized tensor."); TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); TVM_ATTR_FIELD(momentum).describe("The value used for the moving_mean and moving_var update."); + TVM_ATTR_FIELD(training).describe("Whether we are training (not in eval mode)."); } }; // struct BatchNormAttrs diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2987b529dcc4..d8815d7b39be 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -277,6 +277,7 @@ def create_convert_map( # linear algebra "linalg_vector_norm.default": self._linalg_vector_norm, # neural network + # TODO figure out all calls to batchnorm HERE and in fx_translator "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "batch_norm.default": self._batch_norm_legit_no_training, # TODO keep or not? "_native_batch_norm_legit_functional.default": self._batch_norm_legit_no_training, # when I don't do eval . TODO doesn't work right now! 
@@ -436,6 +437,7 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" + print("Found a function called", func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 8389a99b919a..7b2090f7366d 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -551,9 +551,7 @@ def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: epsilon=call.attrs.epsilon, center=call.attrs.center, scale=call.attrs.scale, - # By default relax batch_norm is training mode. - # To transform it to inference mode, use DecomposeOpsForInference. - training=False, # TODO this is wrong!! + training=call.attrs.training, momentum=call.attrs.momentum, ) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b8f9729784b1..193de8227e1b 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -252,7 +252,8 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, TVM_REGISTER_NODE_TYPE(BatchNormAttrs); Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // - int axis, double epsilon, bool center, bool scale, double momentum) { + int axis, double epsilon, bool center, bool scale, double momentum, + bool training) { ObjectPtr attrs = make_object(); attrs->axis = axis; @@ -260,6 +261,7 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ attrs->center = center; attrs->scale = scale; attrs->momentum = momentum; + attrs->training = training; static const Op& op = Op::Get("relax.nn.batch_norm"); return Call(op, From dde7872064863000a5f5d83881684624ae7ab093 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 15 Mar 2025 22:30:58 -0400 Subject: [PATCH 04/47] training False --- .../frontend/torch/exported_program_translator.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index d8815d7b39be..66c6659b3435 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -57,15 +57,7 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) - - print("calling batch_norm with the following parameters:") - print("x:", x) - print("weight:", weight) - print("bias:", bias) - print("running_mean:", running_mean) - print("running_var:", running_var) - print("momentum:", momentum) - print("eps:", eps) + training = False # This method is only called for eval mode return self.block_builder.emit( relax.op.nn.batch_norm( @@ -79,6 +71,7 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: center=False, # TODO scale=False, # TODO momentum=momentum, + training=training, ) ) From dff60db0ed08d6e92b05179baebfa3dbe5f751ae Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 15 Mar 2025 22:36:20 -0400 Subject: [PATCH 05/47] training argument in nn.py --- include/tvm/relax/attrs/nn.h | 2 +- python/tvm/relax/op/nn/nn.py | 25 ++++++++----------------- 2 files 
changed, 9 insertions(+), 18 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 4ccb1b4a9936..8f63012e095a 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -471,7 +471,7 @@ struct BatchNormAttrs : public tvm::AttrsNode { "Indicating if the beta offset will be added to the normalized tensor."); TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); TVM_ATTR_FIELD(momentum).describe("The value used for the moving_mean and moving_var update."); - TVM_ATTR_FIELD(training).describe("Whether we are training (not in eval mode)."); + TVM_ATTR_FIELD(training).describe("Whether we are training (i.e., not in eval mode)."); } }; // struct BatchNormAttrs diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 86e39903bff8..907be8a3f4c0 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1393,6 +1393,7 @@ def batch_norm( center: bool = True, scale: bool = True, momentum: float = 0.1, + training: bool = True, ) -> Expr: r""" Batch normalization layer (Ioffe and Szegedy, 2014). @@ -1481,29 +1482,19 @@ def batch_norm( momentum : float The value used for the moving_mean and moving_var update. + training : bool + A boolean value to indicate whether training or in eval mode. By default. + relax batch_norm is training mode. To transform it to inference mode, + can use DecomposeOpsForInference. + + Returns ------- result : relax.Expr The computed result. """ - print("\n!!The parameters passed to _ffi_api.batch_norm are: !!!!!!!!!") - print("data: ", data) - print("gamma: ", gamma) - print("beta: ", beta) - print("moving_mean: ", moving_mean) - print(dir(moving_mean)) # TODO find a way to print args - # moving_mean.show() - # print("moving_mean handle type: ", moving_mean.handle) - # print("moving_mean handle: ", type(moving_mean.handle)) - # print("moving_var: ", moving_var) - # print("moving_var args 0: ", moving_var.args[0]) - # print("axis: ", axis) - # print("epsilon: ", epsilon) - # print("center: ", center) - # print("scale: ", scale) - # print("momentum: ", momentum) return _ffi_api.batch_norm( # type: ignore - data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum + data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum, training ) From f1986d9af19da8349dcdadfe7ebb4a8564424a2e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sat, 15 Mar 2025 22:40:41 -0400 Subject: [PATCH 06/47] little cleanup before building --- python/tvm/topi/nn/batch_norm.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py index fe0f77c63af4..fa03aed73660 100644 --- a/python/tvm/topi/nn/batch_norm.py +++ b/python/tvm/topi/nn/batch_norm.py @@ -111,9 +111,6 @@ def batch_norm( shape = [1] * len(data.shape) shape[axis] = data.shape[axis] - print("IN BATCH_NORM !!!!!!!!!!!!!!!!!!!!!") - print("training: ", training) - if training: reduce_axes = list(range(len(data.shape))) reduce_axes.remove(axis) @@ -129,14 +126,6 @@ def batch_norm( moving_mean_rs = topi.reshape(moving_mean, shape) moving_var_rs = topi.reshape(moving_var, shape) - print("CALCULATING OUT FOR BATCH_NORM !!!!!!!!!!!!!!!!!!!!!") - print("data shape: ", data.shape) - print("data: ", data) - print("moving_mean shape ", moving_mean.shape) - print("moving_mean: ", moving_mean) - print("moving_var shape: ", moving_var.shape) - print("moving_var: ", moving_var) - out = (data - 
moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon) if scale: From 1545b993870066d92de1c9052ac7a558525e809b Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 01:15:59 -0400 Subject: [PATCH 07/47] fix copy-paste errors --- src/relax/op/nn/nn.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 193de8227e1b..04b0411f2f03 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -391,7 +391,7 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call, TVM_REGISTER_OP("relax.nn.layer_norm") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") .add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoLayerNorm) @@ -503,7 +503,7 @@ InferLayoutOutput InferLayoutGroupNorm(const Call& call, TVM_REGISTER_OP("relax.nn.group_norm") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("data", "Tensor", "Input to which group_norm will be applied.") .add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoGroupNorm) From 77cc1d816851b38d8e3cf245031111aa60aa8be7 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 01:31:44 -0400 Subject: [PATCH 08/47] builds, but should probably just update nn.h instead --- src/relax/op/nn/nn.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 04b0411f2f03..86c193f01673 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -251,10 +251,9 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, /* relax.nn.batch_norm */ TVM_REGISTER_NODE_TYPE(BatchNormAttrs); -Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // - int axis, double epsilon, bool center, bool scale, double momentum, - bool training) { - +Expr batch_norm_impl(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // + int axis, double epsilon, bool center, bool scale, double momentum, + bool training) { ObjectPtr attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -269,8 +268,7 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ std::move(moving_var)}, Attrs{attrs}, {}); } - -TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); +TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm_impl); StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { Array input_sinfo = GetInputTensorStructInfo(call, ctx); From 0dbf8fe1e5236b41d1f7dc9a69f99090120da136 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 01:32:52 -0400 Subject: [PATCH 09/47] batch_norm build --- src/relax/op/nn/nn.cc | 7 +++---- src/relax/op/nn/nn.h | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 86c193f01673..826711538c68 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -251,9 +251,8 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, /* relax.nn.batch_norm */ TVM_REGISTER_NODE_TYPE(BatchNormAttrs); 
-Expr batch_norm_impl(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,  //
-                     int axis, double epsilon, bool center, bool scale, double momentum,
-                     bool training) {
+Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,  //
+                int axis, double epsilon, bool center, bool scale, double momentum, bool training) {
   ObjectPtr attrs = make_object();
   attrs->axis = axis;
   attrs->epsilon = epsilon;
@@ -268,7 +267,7 @@ Expr batch_norm_impl(Expr data, Expr gamma, Expr beta, Expr mo
                std::move(moving_var)},
               Attrs{attrs}, {});
 }
-TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm_impl);
+TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm);
 
 StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) {
   Array input_sinfo = GetInputTensorStructInfo(call, ctx);
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index a3658fed5430..28c14139b97b 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -68,7 +68,7 @@ Expr log_softmax(Expr data, int axis);
 
 /*! \brief Compute batch normalization. */
 Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,  //
-                int axis, double epsilon, bool center, bool scale, double momentum);
+                int axis, double epsilon, bool center, bool scale, double momentum, bool training);
 
 /*! \brief Compute layer normalization. */
 Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center,

From 1164d215ce894d0607b5e26495bb0029ba659971 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 16 Mar 2025 01:37:18 -0400
Subject: [PATCH 10/47] first batchnorm test passes with .eval(), but not without, and copy fails

---
 tests/python/relax/test_from_exported_to_cuda.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 76073e4c4de5..23db8cc60c94 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -15,6 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# TODO remove
+import sys
+sys.path.append('/ssd1/htalendr/tvm/python')
+
+
 import tvm
 from tvm import relax
 import tvm.testing
@@ -290,7 +295,7 @@ def test_batch_norm(target, dev):
     # raw_data = np.array([[[[10.0]],[[20.0]]]]).astype(np.float32)
     torch_module0 = nn.BatchNorm2d(8, eps=1e-02, momentum=0.0,
                                    affine=False, track_running_stats=True,
-                                   device=None, dtype=None)
+                                   device=None, dtype=None).eval()  # TODO make test pass without .eval() !
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # TODO correct output above should be [9.95, 19.9] https://chatgpt.com/c/67cf1bc1-1934-8006-9b22-8166c46ee1bc From a72ce6ecf4311f3bd94491e69dec7355b876eeb9 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 01:40:12 -0400 Subject: [PATCH 11/47] copy failing --- python/tvm/relax/frontend/torch/exported_program_translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 66c6659b3435..3bb4ff211a16 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -308,6 +308,7 @@ def create_convert_map( "cat.default": self._cat, "clamp.Tensor": self._clamp, "concat.default": self._cat, + "copy.default": self._copy_, "copy_.default": self._copy_, "cumsum.default": self._cumsum, "expand.default": self._expand, From 42728f7075931507ebdd79fdce5e07720bed27c6 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 02:32:15 -0400 Subject: [PATCH 12/47] todo --- .../torch/base_fx_graph_translator.py | 34 ++++++++++++++++++- .../torch/exported_program_translator.py | 1 + 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6bbc9d5de618..95b6d57dde11 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1084,8 +1084,40 @@ def _detach(self, node: fx.Node) -> relax.Var: # by the translator, and therefore we just return a copy of the input. return self.env[node.args[0]] + # TODO need to find correct way to implement copy.default. 
this is just one guess + def _copy(self, node: fx.Node) -> relax.Var: + # Returns a copy of the input tensor + import torch # type: ignore + + print("node.args[0]: ", node.args[0]) + print("node.args[1]: ", node.args[1]) + print("self.env[node.args[0]]", self.env[node.args[0]]) + print("self.env[node.args[1]]", self.env[node.args[1]]) + + print("type(node.args[0]): ", type(node.args[0])) + print("type(node.args[1]): ", type(node.args[1])) + print("type(self.env[node.args[0]]): ", type(self.env[node.args[0]])) + print("type(self.env[node.args[1]]): ", type(self.env[node.args[1]])) + + + x = self.env[node.args[0]] + print('A') + if len(node.args) == 2: + print('B') + if isinstance(node.args[1], torch.dtype): + print('C') + dtype = self._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + print('D') + dtype = self._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + print('E') + return x + + def _copy_(self, node: fx.Node) -> relax.Var: - # Copies the source tensor's to the destination tensor + # Copies the source tensor's into the destination tensor # In TVM, that means simply returning the source tensor return self.env[node.args[1]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3bb4ff211a16..c27b9b9852ed 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -331,6 +331,7 @@ def create_convert_map( "view.default": self._reshape, "reshape.default": self._reshape, # tensor creation + "copy.default": self._copy, "_to_copy.default": self._to_copy, "lift_fresh_copy.default": self._to_copy, "detach.default": self._detach, From 9ee0672319c19334fbd5de9c8fdaa0eeff6ed6d7 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 11:56:27 -0400 Subject: [PATCH 13/47] cleanup --- .../relax/test_from_exported_to_cuda.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 23db8cc60c94..4e43f3ee3d6d 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,11 +15,6 @@ # specific language governing permissions and limitations # under the License. -# TODO remove -import sys -sys.path.append('/ssd1/htalendr/tvm/python') - - import tvm from tvm import relax import tvm.testing @@ -289,28 +284,25 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_batch_norm(target, dev): - - # TODO no momentum + # No momentum, eval raw_data = np.random.randn(8,8,4,4).astype(np.float32) - # raw_data = np.array([[[[10.0]],[[20.0]]]]).astype(np.float32) torch_module0 = nn.BatchNorm2d(8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, - device=None, dtype=None).eval() # TODO make test pass without .eval() ! 
+ device=None, dtype=None).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # TODO correct output above should be [9.95, 19.9] https://chatgpt.com/c/67cf1bc1-1934-8006-9b22-8166c46ee1bc - # TODO with momentum - # raw_data = np.random.randn(1,4,2,2).astype(np.float32) - # torch_module0 = nn.BatchNorm2d(4, eps=1e-05, momentum=0.0, - # affine=False, track_running_stats=True, - # device=None, dtype=None).eval() - # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + # With momentum, eval + raw_data = np.random.randn(1,4,2,2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(4, eps=1e-05, momentum=0.0, + affine=False, track_running_stats=True, + device=None, dtype=None).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # TODO uncomment this one with simpler arguments? - # raw_data = np.random.randn(4,2,2,2).astype(np.float32) - # torch_module0 = nn.BatchNorm2d(2).eval() - # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + # Default args, eval + raw_data = np.random.randn(4,2,2,2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) if __name__ == "__main__": From e3f0236bd0d87f17fcd59747c643dd1e70587dd0 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 12:20:40 -0400 Subject: [PATCH 14/47] training failing --- .../torch/base_fx_graph_translator.py | 32 ------------- .../torch/exported_program_translator.py | 45 +++++++++++++++---- .../relax/test_from_exported_to_cuda.py | 7 ++- 3 files changed, 43 insertions(+), 41 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 95b6d57dde11..7aab3059e96e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1084,38 +1084,6 @@ def _detach(self, node: fx.Node) -> relax.Var: # by the translator, and therefore we just return a copy of the input. return self.env[node.args[0]] - # TODO need to find correct way to implement copy.default. 
this is just one guess - def _copy(self, node: fx.Node) -> relax.Var: - # Returns a copy of the input tensor - import torch # type: ignore - - print("node.args[0]: ", node.args[0]) - print("node.args[1]: ", node.args[1]) - print("self.env[node.args[0]]", self.env[node.args[0]]) - print("self.env[node.args[1]]", self.env[node.args[1]]) - - print("type(node.args[0]): ", type(node.args[0])) - print("type(node.args[1]): ", type(node.args[1])) - print("type(self.env[node.args[0]]): ", type(self.env[node.args[0]])) - print("type(self.env[node.args[1]]): ", type(self.env[node.args[1]])) - - - x = self.env[node.args[0]] - print('A') - if len(node.args) == 2: - print('B') - if isinstance(node.args[1], torch.dtype): - print('C') - dtype = self._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - elif "dtype" in node.kwargs: - print('D') - dtype = self._convert_data_type(node.kwargs["dtype"], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - print('E') - return x - - def _copy_(self, node: fx.Node) -> relax.Var: # Copies the source tensor's into the destination tensor # In TVM, that means simply returning the source tensor diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index c27b9b9852ed..caad4ca0b60e 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -45,7 +45,7 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: ########## Neural Network ########## - def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + def _batch_norm(self, node: fx.Node, training) -> relax.Var: import numpy as np x = self.env[node.args[0]] @@ -53,19 +53,18 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: dtype = x.struct_info.dtype weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) - running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) - running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + # running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + # running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) - training = False # This method is only called for eval mode return self.block_builder.emit( relax.op.nn.batch_norm( data=x, gamma=weight, beta=bias, - moving_mean=running_mean, - moving_var=running_var, + # moving_mean=running_mean, + # moving_var=running_var, axis=1, # Always over channel epsilon=eps, center=False, # TODO @@ -75,6 +74,37 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: ) ) + def _batch_norm_training(self, node: fx.Node) -> relax.Var: + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + momentum = 
node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) + eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + + return self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + momentum=momentum, + ) + ) + + def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + # This method should only be called for torch exported programs corresponding to eval mode + training = False + return self._batch_norm(node, training) + def _group_norm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_groups = node.args[1] @@ -273,7 +303,7 @@ def create_convert_map( # TODO figure out all calls to batchnorm HERE and in fx_translator "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "batch_norm.default": self._batch_norm_legit_no_training, # TODO keep or not? - "_native_batch_norm_legit_functional.default": self._batch_norm_legit_no_training, # when I don't do eval . TODO doesn't work right now! + "_native_batch_norm_legit_functional.default": self._batch_norm_training, # when I don't do eval . TODO doesn't work right now! "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, @@ -331,7 +361,6 @@ def create_convert_map( "view.default": self._reshape, "reshape.default": self._reshape, # tensor creation - "copy.default": self._copy, "_to_copy.default": self._to_copy, "lift_fresh_copy.default": self._to_copy, "detach.default": self._detach, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 4e43f3ee3d6d..488ddad9d676 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,6 +15,11 @@ # specific language governing permissions and limitations # under the License. 
+# TODO remove +import sys +sys.path.append('/ssd1/htalendr/tvm/python') + + import tvm from tvm import relax import tvm.testing @@ -288,7 +293,7 @@ def test_batch_norm(target, dev): raw_data = np.random.randn(8,8,4,4).astype(np.float32) torch_module0 = nn.BatchNorm2d(8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, - device=None, dtype=None).eval() + device=None, dtype=None)#.eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # With momentum, eval From 3f680878d34a583e86d9c5ac2c7d34ba58160388 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 12:54:20 -0400 Subject: [PATCH 15/47] no need to pass center and scale since default ok --- .../torch/exported_program_translator.py | 43 +++++-------------- .../relax/test_from_exported_to_cuda.py | 3 +- 2 files changed, 11 insertions(+), 35 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index caad4ca0b60e..608b3946bb73 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -53,8 +53,8 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: dtype = x.struct_info.dtype weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) - # running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) - # running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) @@ -63,42 +63,19 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: data=x, gamma=weight, beta=bias, - # moving_mean=running_mean, - # moving_var=running_var, + moving_mean=running_mean, + moving_var=running_var, axis=1, # Always over channel epsilon=eps, - center=False, # TODO - scale=False, # TODO momentum=momentum, training=training, ) ) def _batch_norm_training(self, node: fx.Node) -> relax.Var: - import numpy as np - - x = self.env[node.args[0]] - channel = int(self.shape_of(x)[1]) - dtype = x.struct_info.dtype - weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) - bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) - running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) - running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) - momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) - eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) - - return self.block_builder.emit( - relax.op.nn.batch_norm( - x, - weight, - bias, - running_mean, - running_var, - axis=1, - epsilon=eps, - momentum=momentum, - ) - ) + # This method should only be called for torch exported programs corresponding to training mode + training = False + return self._batch_norm(node, training) def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: # This method should only be called for torch exported programs corresponding to eval mode @@ -300,10 +277,10 @@ def 
create_convert_map( # linear algebra "linalg_vector_norm.default": self._linalg_vector_norm, # neural network - # TODO figure out all calls to batchnorm HERE and in fx_translator + # TODO figure out all calls to batchnorm here AND in fx_translator + # "batch_norm.default": self._batch_norm_legit_no_training, # TODO keep or not? "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, - "batch_norm.default": self._batch_norm_legit_no_training, # TODO keep or not? - "_native_batch_norm_legit_functional.default": self._batch_norm_training, # when I don't do eval . TODO doesn't work right now! + "_native_batch_norm_legit_functional.default": self._batch_norm_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 488ddad9d676..f54aff01e0fc 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -293,7 +293,7 @@ def test_batch_norm(target, dev): raw_data = np.random.randn(8,8,4,4).astype(np.float32) torch_module0 = nn.BatchNorm2d(8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, - device=None, dtype=None)#.eval() + device=None, dtype=None).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # With momentum, eval @@ -303,7 +303,6 @@ def test_batch_norm(target, dev): device=None, dtype=None).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # Default args, eval raw_data = np.random.randn(4,2,2,2).astype(np.float32) torch_module0 = nn.BatchNorm2d(2).eval() From 5cd314d69e21b7b4af60bd48069d498adbc25b48 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 12:56:51 -0400 Subject: [PATCH 16/47] cleanup --- .../tvm/relax/frontend/torch/exported_program_translator.py | 2 -- python/tvm/relax/op/nn/nn.py | 1 - tests/python/relax/test_from_exported_to_cuda.py | 5 ----- 3 files changed, 8 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 608b3946bb73..38cfda23af38 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -315,7 +315,6 @@ def create_convert_map( "cat.default": self._cat, "clamp.Tensor": self._clamp, "concat.default": self._cat, - "copy.default": self._copy_, "copy_.default": self._copy_, "cumsum.default": self._cumsum, "expand.default": self._expand, @@ -438,7 +437,6 @@ def from_exported_program( assert ( func_name in self.convert_map ), f"Unsupported function type {func_name}" - print("Found a function called", func_name) self.env[node] = self.convert_map[func_name](node) else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 907be8a3f4c0..44377abed735 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1487,7 +1487,6 @@ def batch_norm( relax batch_norm is training mode. To transform it to inference mode, can use DecomposeOpsForInference. 
- Returns ------- result : relax.Expr diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index f54aff01e0fc..f8a44a841c15 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,11 +15,6 @@ # specific language governing permissions and limitations # under the License. -# TODO remove -import sys -sys.path.append('/ssd1/htalendr/tvm/python') - - import tvm from tvm import relax import tvm.testing From d5d30b7d16341f0535224b28c954741661c2af6e Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 12:59:06 -0400 Subject: [PATCH 17/47] cleanup --- .../tvm/relax/frontend/torch/exported_program_translator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 38cfda23af38..b1a6dc33ff1f 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -74,7 +74,7 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: def _batch_norm_training(self, node: fx.Node) -> relax.Var: # This method should only be called for torch exported programs corresponding to training mode - training = False + training = True return self._batch_norm(node, training) def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: @@ -277,8 +277,6 @@ def create_convert_map( # linear algebra "linalg_vector_norm.default": self._linalg_vector_norm, # neural network - # TODO figure out all calls to batchnorm here AND in fx_translator - # "batch_norm.default": self._batch_norm_legit_no_training, # TODO keep or not? "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "_native_batch_norm_legit_functional.default": self._batch_norm_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, From 125a9a6c8546cc20badf131c3ff6f8f87ed042d3 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 13:16:33 -0400 Subject: [PATCH 18/47] reformat --- .../torch/exported_program_translator.py | 11 ++++++----- python/tvm/relax/transform/legalize_ops/nn.py | 2 +- .../relax/test_from_exported_to_cuda.py | 19 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index b1a6dc33ff1f..d3faf5df0dd7 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -65,20 +65,21 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: beta=bias, moving_mean=running_mean, moving_var=running_var, - axis=1, # Always over channel + axis=1, # Always over channel epsilon=eps, momentum=momentum, training=training, ) ) - def _batch_norm_training(self, node: fx.Node) -> relax.Var: - # This method should only be called for torch exported programs corresponding to training mode + def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: + # This method is called for batch_norm in training mode + # TODO does not have correctness! 
training = True return self._batch_norm(node, training) def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: - # This method should only be called for torch exported programs corresponding to eval mode + # This method is called for batch_norm in eval mode training = False return self._batch_norm(node, training) @@ -277,8 +278,8 @@ def create_convert_map( # linear algebra "linalg_vector_norm.default": self._linalg_vector_norm, # neural network + "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, - "_native_batch_norm_legit_functional.default": self._batch_norm_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 7b2090f7366d..4c8bdbc6615c 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -551,7 +551,7 @@ def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: epsilon=call.attrs.epsilon, center=call.attrs.center, scale=call.attrs.scale, - training=call.attrs.training, + training=call.attrs.training, momentum=call.attrs.momentum, ) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index f8a44a841c15..297f4ec0479e 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -282,24 +282,23 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") - def test_batch_norm(target, dev): # No momentum, eval - raw_data = np.random.randn(8,8,4,4).astype(np.float32) - torch_module0 = nn.BatchNorm2d(8, eps=1e-02, momentum=0.0, - affine=False, track_running_stats=True, - device=None, dtype=None).eval() + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # With momentum, eval - raw_data = np.random.randn(1,4,2,2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(4, eps=1e-05, momentum=0.0, - affine=False, track_running_stats=True, - device=None, dtype=None).eval() + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # Default args, eval - raw_data = np.random.randn(4,2,2,2).astype(np.float32) + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d(2).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) From 3f0eaea97cfdd397c101c74352086e7c03042749 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 15:02:21 -0400 Subject: [PATCH 19/47] batch norm default and print torch version --- .../tvm/relax/frontend/torch/exported_program_translator.py | 1 + tests/python/relax/test_from_exported_to_cuda.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 101738ae89ed..76a7839976c8 100644 --- 
a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -297,6 +297,7 @@ def create_convert_map( # neural network "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, + "batch_norm.default": self._batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 297f4ec0479e..9ae2d79c29f9 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -303,5 +303,11 @@ def test_batch_norm(target, dev): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_dummy(target, dev): + version = torch.__version__ + assert 0, f"Torch version is {version}" + + if __name__ == "__main__": tvm.testing.main() From 79e3ec6232eef984667afa8b4ecdf0d508061ea2 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 15:22:49 -0400 Subject: [PATCH 20/47] whitespace --- python/tvm/relax/op/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 44377abed735..09a7df5149f9 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1484,7 +1484,7 @@ def batch_norm( training : bool A boolean value to indicate whether training or in eval mode. By default. - relax batch_norm is training mode. To transform it to inference mode, + relax batch_norm is training mode. To transform it to inference mode, can use DecomposeOpsForInference. Returns From 79c4a0e7e6197e371fc705b329f70ac0ba8493aa Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 16 Mar 2025 17:11:36 -0400 Subject: [PATCH 21/47] remove dummy test --- .../relax/test_from_exported_to_cuda.py | 53 +++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 9ae2d79c29f9..669e9c25de28 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. 
+# TODO remove +import sys +sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build + import tvm from tvm import relax import tvm.testing @@ -281,6 +285,33 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) +# TODO can combine the tests together (they are separete to know which test fails) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm(target, dev): + # No momentum, eval + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm(target, dev): + # With momentum, eval + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm(target, dev): + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_batch_norm(target, dev): # No momentum, eval @@ -304,9 +335,25 @@ def test_batch_norm(target, dev): @tvm.testing.parametrize_targets("cuda") -def test_dummy(target, dev): - version = torch.__version__ - assert 0, f"Torch version is {version}" +def test_batch_norm(target, dev): + # No momentum, eval + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # With momentum, eval + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) if __name__ == "__main__": From b9697f3901ccfe3e6ed7e0f26f932ff9010025ee Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 13:47:47 -0400 Subject: [PATCH 22/47] getting a tuple as output of batchnorm --- .../torch/exported_program_translator.py | 37 ++- .../test_from_exported_batch_norm_only.py | 211 ++++++++++++++++++ 2 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 tests/python/relax/test_from_exported_batch_norm_only.py diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 76a7839976c8..fe5efe95685d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -46,6 +46,7 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: ########## Neural Network ########## def _batch_norm(self, node: fx.Node, training) 
-> relax.Var: + print("Inside batch norm") import numpy as np x = self.env[node.args[0]] @@ -58,8 +59,8 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) - return self.block_builder.emit( - relax.op.nn.batch_norm( + # TODO restore + inside = relax.op.nn.batch_norm( data=x, gamma=weight, beta=bias, @@ -70,15 +71,23 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: momentum=momentum, training=training, ) + print("type of inside", type(inside)) # + + outside = self.block_builder.emit( + inside ) + print("type of outside", type(outside)) # + return outside def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: + print("Inside batch norm functional") # This method is called for batch_norm in training mode # TODO does not have correctness! training = True return self._batch_norm(node, training) def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + print("Inside batch norm no training") # This method is called for batch_norm in eval mode training = False return self._batch_norm(node, training) @@ -111,6 +120,7 @@ def _upsample_impl( method: str, align_corners: bool, ) -> relax.Var: + print("Inside upsample impl") coord_trans = "align_corners" if align_corners else "half_pixel" if size is None: @@ -124,13 +134,31 @@ def _upsample_impl( else: size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) - return self.block_builder.emit( - relax.op.image.resize2d( + # TODO restore + # return self.block_builder.emit( + # relax.op.image.resize2d( + # x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + # ) + # ) + + inside = relax.op.image.resize2d( x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans ) + + print("type of inside", type(inside)) # + + + outside = self.block_builder.emit( + inside ) + print("type of outside", type(outside)) # + + return outside + + def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: + print("Inside upsample bilinear 2d") x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) align_corners = ( @@ -142,6 +170,7 @@ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: ) def _upsample_nearest2d(self, node: fx.node) -> relax.Var: + print("Inside upsample nearest 2d") x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py new file mode 100644 index 000000000000..340ac25a7277 --- /dev/null +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# TODO remove +import sys +sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build + +import tvm +from tvm import relax +import tvm.testing +import numpy as np +import torch +from torch import nn +from torch.export import export +from tvm.relax.frontend.torch import from_exported_program +from torch.nn import Softmax, Upsample + + +def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev): + """ + This util ensures that a torch module can successfully be exported to TVM + using torch.export and that the resuling IR program gives the same result + as PyTorch when ran on CUDA. + """ + raw_data_for_tvm = raw_data.copy() # In case the data is modified + torch_data = torch.from_numpy(raw_data) + example_args = (torch_data,) + + with torch.no_grad(): + exported_program = export(torch_module, example_args) + mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) + + mod_from_torch.show() # TODO remove + + tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) + + relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda())) + ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline) + vm = relax.VirtualMachine(ex, dev) + + gpu_data = tvm.nd.array(raw_data_for_tvm, dev) + gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]] + gpu_out = vm["main"](gpu_data, *gpu_params) + + pytorch_out = torch_module(torch_data).detach().numpy() + + print("type of pytorch_out", type(pytorch_out)) + print("pytorch output shape", pytorch_out.shape) + + print("len of gpu_out", len(gpu_out)) # 1 for all tests + print("type of gpu_out[0]", type(gpu_out[0])) # tvm.ir.container.Array for batch norm, tvm.runtime.ndarray.NDArray for both existing tests + print("gpu_out[0] shape", gpu_out[0].shape) # defined for tests that work + + actual = gpu_out[0].numpy() + desired = pytorch_out + + + + np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_detach_no_change(target, dev): +# # In TVM, detach() is just identity +# class DetachTester(nn.Module): +# def forward(self, x): +# detached = x.detach() +# return detached + +# raw_data = np.ones((2, 2)).astype(np.float32) +# torch_module = DetachTester().eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +# assert 0 + + +# @tvm.testing.parametrize_targets("cuda") +# def test_upsample_with_scale_factor(target, dev): +# """ +# The Upsample module can be used with the size arugment or the scale +# factor argument but not both. This tests the latter. +# """ +# batch_size = 2 +# channels = 3 +# height, width = 32, 32 + +# torch_module = Upsample( +# size=None, scale_factor=7, mode="nearest", align_corners=None, recompute_scale_factor=True +# ) + +# raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +# assert 0 + + +# TODO in a program! 
to make sure dimensions work +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm_prog(target, dev): +# # No momentum, eval +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + +# class BatchNorm(nn.Module): +# def __init__(self): +# super(BatchNorm, self).__init__() +# self.bn = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True +# ) +# def forward(): + + +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +# # # TODO can combine the tests together (they are separete to know which test fails) +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm0(target, dev): +# # No momentum, eval, with running stats +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm1(target, dev): +# # With momentum, eval +# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm2(target, dev): +# # Default args, eval +# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d(2).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm3(target, dev): +# # No momentum, eval +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # With momentum, eval +# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # Default args, eval +# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d(2).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm4(target, dev): +# # No momentum, eval +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # With momentum, eval +# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # Default args, 
eval +# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d(2).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm5(target, dev): + # No momentum, eval, no running stats + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +if __name__ == "__main__": + tvm.testing.main() From bc181829cb7ef6fc5a4e00c2019b54ada54b3ea8 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 14:10:39 -0400 Subject: [PATCH 23/47] output now of the right dimension, and close! but is not exactly equal --- .../tvm/relax/frontend/torch/exported_program_translator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index fe5efe95685d..39ad1de8d958 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -70,7 +70,7 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: epsilon=eps, momentum=momentum, training=training, - ) + )[0] print("type of inside", type(inside)) # outside = self.block_builder.emit( @@ -83,6 +83,8 @@ def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: print("Inside batch norm functional") # This method is called for batch_norm in training mode # TODO does not have correctness! + # TODO we need to store the running mean and variance returned by the + # previous call to batch_norm and pass it again training = True return self._batch_norm(node, training) From e2e7263e440d46b3bad897288b076810bfc71e5f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 14:23:58 -0400 Subject: [PATCH 24/47] still not the same with 2 1 2 2 --- .../test_from_exported_batch_norm_only.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index 340ac25a7277..187f25d52296 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -44,7 +44,7 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar exported_program = export(torch_module, example_args) mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True) - mod_from_torch.show() # TODO remove + # mod_from_torch.show() # TODO remove tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch) @@ -68,8 +68,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar actual = gpu_out[0].numpy() desired = pytorch_out - - np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) @@ -112,19 +110,18 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar # # No momentum, eval # raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# class BatchNorm(nn.Module): +# class BatchNormWrapper(nn.Module): # def __init__(self): -# super(BatchNorm, self).__init__() +# super(BatchNormWrapper, self).__init__() # self.bn = nn.BatchNorm2d( # 8, eps=1e-02, 
momentum=0.0, affine=False, track_running_stats=True # ) -# def forward(): - - -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +# def forward(self, x): +# x = self.bn(x) +# x = x + 1 +# return x +# torch_module = BatchNormWrapper().eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) # # # TODO can combine the tests together (they are separete to know which test fails) @@ -197,15 +194,25 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar # torch_module0 = nn.BatchNorm2d(2).eval() # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm5(target, dev): +# # No momentum, eval, no running stats +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + @tvm.testing.parametrize_targets("cuda") -def test_batch_norm5(target, dev): - # No momentum, eval, no running stats - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +def test_batch_norm6(target, dev): + # Small input + raw_data = np.random.randn(2, 1, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + if __name__ == "__main__": tvm.testing.main() From 4cdb05a7b97649e4a8a155fe6942007c12c9f7c0 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 15:23:23 -0400 Subject: [PATCH 25/47] missing eps --- python/tvm/topi/nn/batch_norm.py | 19 +++++++++++++++---- .../test_from_exported_batch_norm_only.py | 4 ++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py index fa03aed73660..ff956d9eef9c 100644 --- a/python/tvm/topi/nn/batch_norm.py +++ b/python/tvm/topi/nn/batch_norm.py @@ -112,6 +112,14 @@ def batch_norm( shape[axis] = data.shape[axis] if training: + moving_mean_rs = topi.reshape(moving_mean, shape) + moving_var_rs = topi.reshape(moving_var, shape) + + + + out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon) + + else: reduce_axes = list(range(len(data.shape))) reduce_axes.remove(axis) shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) @@ -121,12 +129,15 @@ def batch_norm( topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod ) data_var_rs = topi.reshape(data_var, shape) + + print("data is", data) + print("data_mean_rs is", data_mean_rs) + print("data_var_rs is", data_var_rs) + print("epsilon is", epsilon) + print("sqrt of data_var_rs + epsilon is", topi.math.sqrt(data_var_rs + epsilon)) + out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon) - else: - moving_mean_rs = topi.reshape(moving_mean, shape) - moving_var_rs = topi.reshape(moving_var, shape) - out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon) if scale: out = out * topi.reshape(gamma, shape) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py 
b/tests/python/relax/test_from_exported_batch_norm_only.py index 187f25d52296..0d384b89c201 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -206,9 +206,9 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar @tvm.testing.parametrize_targets("cuda") def test_batch_norm6(target, dev): # Small input - raw_data = np.random.randn(2, 1, 2, 2).astype(np.float32) + raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + 8, eps=0.1, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) From b256163614f3466b6323ba811d424e4757f3b873 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 15:36:36 -0400 Subject: [PATCH 26/47] last small test passes, but most tests still fail --- .../torch/exported_program_translator.py | 11 +- .../test_from_exported_batch_norm_only.py | 258 +++++++++--------- 2 files changed, 139 insertions(+), 130 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 39ad1de8d958..7e967ce4d7eb 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -53,11 +53,20 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: channel = int(self.shape_of(x)[1]) dtype = x.struct_info.dtype weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + print("weight", weight) bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + print("bias", bias) running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + print("running mean", running_mean) running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + print("running var", running_var) momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) - eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + print("momentum", momentum) # TODO is this affine? + whatisThis = node.args[6] if len(node.args) > 6 else node.kwargs.get("??????????", "???????") + print("_batch_norm found an whatisThis", whatisThis) + eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) + print("node.args[7]", node.args[7]) # TODO that's eps !!!!! 
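# Note on the argument probing above: as far as I can tell, aten::batch_norm.default carries
# (input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled),
# so args[5] is the boolean training flag, args[6] is momentum, args[7] is eps, and args[8] is
# cudnn_enabled, while _native_batch_norm_legit_no_training.default carries only
# (input, weight, bias, running_mean, running_var, momentum, eps). The sketch below is a
# hypothetical standalone helper (not part of the translator) showing extraction under the
# assumed aten::batch_norm layout:
def extract_aten_batch_norm_scalars(args, kwargs):
    training = args[5] if len(args) > 5 else kwargs.get("training", False)
    momentum = args[6] if len(args) > 6 else kwargs.get("momentum", 0.1)
    eps = args[7] if len(args) > 7 else kwargs.get("eps", 1e-5)
    cudnn_enabled = args[8] if len(args) > 8 else kwargs.get("cudnn_enabled", True)
    return training, momentum, eps, cudnn_enabled

# For an eval-mode nn.BatchNorm2d(8, eps=1e-2, momentum=0.0) the trailing positions would then
# come through as (False, 0.0, 0.01, True).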
+ print("node.args[8]", node.args[8]) # TODO remove # TODO restore inside = relax.op.nn.batch_norm( diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index 0d384b89c201..78007893a2f5 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -71,143 +71,143 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) -# @tvm.testing.parametrize_targets("cuda") -# def test_detach_no_change(target, dev): -# # In TVM, detach() is just identity -# class DetachTester(nn.Module): -# def forward(self, x): -# detached = x.detach() -# return detached - -# raw_data = np.ones((2, 2)).astype(np.float32) -# torch_module = DetachTester().eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# assert 0 - - -# @tvm.testing.parametrize_targets("cuda") -# def test_upsample_with_scale_factor(target, dev): -# """ -# The Upsample module can be used with the size arugment or the scale -# factor argument but not both. This tests the latter. -# """ -# batch_size = 2 -# channels = 3 -# height, width = 32, 32 - -# torch_module = Upsample( -# size=None, scale_factor=7, mode="nearest", align_corners=None, recompute_scale_factor=True -# ) - -# raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# assert 0 +@tvm.testing.parametrize_targets("cuda") +def test_detach_no_change(target, dev): + # In TVM, detach() is just identity + class DetachTester(nn.Module): + def forward(self, x): + detached = x.detach() + return detached + + raw_data = np.ones((2, 2)).astype(np.float32) + torch_module = DetachTester().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + assert 0 + + +@tvm.testing.parametrize_targets("cuda") +def test_upsample_with_scale_factor(target, dev): + """ + The Upsample module can be used with the size arugment or the scale + factor argument but not both. This tests the latter. + """ + batch_size = 2 + channels = 3 + height, width = 32, 32 + + torch_module = Upsample( + size=None, scale_factor=7, mode="nearest", align_corners=None, recompute_scale_factor=True + ) + + raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + assert 0 # TODO in a program! 
to make sure dimensions work -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm_prog(target, dev): -# # No momentum, eval -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) - -# class BatchNormWrapper(nn.Module): -# def __init__(self): -# super(BatchNormWrapper, self).__init__() -# self.bn = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True -# ) -# def forward(self, x): -# x = self.bn(x) -# x = x + 1 -# return x -# torch_module = BatchNormWrapper().eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -# # # TODO can combine the tests together (they are separete to know which test fails) -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm0(target, dev): -# # No momentum, eval, with running stats -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm1(target, dev): -# # With momentum, eval -# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm2(target, dev): -# # Default args, eval -# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d(2).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm3(target, dev): -# # No momentum, eval -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # With momentum, eval -# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # Default args, eval -# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d(2).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm4(target, dev): -# # No momentum, eval -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # With momentum, eval -# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # Default args, eval -# raw_data = np.random.randn(4, 2, 
2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d(2).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm5(target, dev): -# # No momentum, eval, no running stats -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm_prog(target, dev): + # No momentum, eval + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + + class BatchNormWrapper(nn.Module): + def __init__(self): + super(BatchNormWrapper, self).__init__() + self.bn = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True + ) + def forward(self, x): + x = self.bn(x) + x = x + 1 + return x + torch_module = BatchNormWrapper().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +# # TODO can combine the tests together (they are separete to know which test fails) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm0(target, dev): + # No momentum, eval, with running stats + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm1(target, dev): + # With momentum, eval + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm2(target, dev): + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm3(target, dev): + # No momentum, eval + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # With momentum, eval + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm4(target, dev): + # No momentum, eval + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + 
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # With momentum, eval + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm5(target, dev): + # No momentum, eval, no running stats + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @tvm.testing.parametrize_targets("cuda") def test_batch_norm6(target, dev): # Small input raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) - torch_module0 = nn.BatchNorm2d( + torch_module0 = nn.BatchNorm2d( # TODO what does the 8 do? (feature num) 8, eps=0.1, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) From ab8d75cbebdc830ec3cbcd825069c522c4fdd60f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 15:51:28 -0400 Subject: [PATCH 27/47] passes --- .../torch/exported_program_translator.py | 8 +- .../test_from_exported_batch_norm_only.py | 176 +++++++++--------- 2 files changed, 91 insertions(+), 93 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7e967ce4d7eb..e66d0f4ef2d0 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -60,12 +60,12 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: print("running mean", running_mean) running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) print("running var", running_var) - momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) - print("momentum", momentum) # TODO is this affine? - whatisThis = node.args[6] if len(node.args) > 6 else node.kwargs.get("??????????", "???????") + whatisThis = node.args[5] if len(node.args) > 5 else node.kwargs.get("??????????", "???????") print("_batch_norm found an whatisThis", whatisThis) + momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) + print("momentum", momentum) # TODO is this affine? eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) - print("node.args[7]", node.args[7]) # TODO that's eps !!!!! + print("eps", node.args[7]) # TODO that's eps !!!!! 
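# Note on the renamed args[5] ("whatisThis") above: it looks like the training flag of
# aten::batch_norm rather than anything momentum-related. As I understand nn.BatchNorm's
# forward, it passes bn_training = self.training or (running_mean is None and running_var is None)
# down to F.batch_norm, so an .eval() module built with track_running_stats=False still exports
# with this flag set to True and is normalized with per-batch statistics. A small standalone
# check of that behaviour (details may vary across PyTorch versions):
import torch
from torch import nn

x = torch.randn(4, 3, 2, 2)

bn_stats = nn.BatchNorm2d(3, affine=False, track_running_stats=True).eval()
bn_nostats = nn.BatchNorm2d(3, affine=False, track_running_stats=False).eval()

out_stats = bn_stats(x)      # uses running_mean=0, running_var=1 from initialization
out_nostats = bn_nostats(x)  # no buffers exist, so the batch mean/var of x are used instead

print(out_stats.mean().item())    # roughly the mean of x itself
print(out_nostats.mean().item())  # roughly 0, since x was re-normalized with its own stats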
print("node.args[8]", node.args[8]) # TODO remove # TODO restore diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index 78007893a2f5..9084baa0c858 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -82,7 +82,6 @@ def forward(self, x): raw_data = np.ones((2, 2)).astype(np.float32) torch_module = DetachTester().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - assert 0 @tvm.testing.parametrize_targets("cuda") @@ -101,27 +100,26 @@ def test_upsample_with_scale_factor(target, dev): raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - assert 0 -# TODO in a program! to make sure dimensions work -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm_prog(target, dev): - # No momentum, eval - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# # TODO in a program! to make sure dimensions work +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm_prog(target, dev): +# # No momentum, eval +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) - class BatchNormWrapper(nn.Module): - def __init__(self): - super(BatchNormWrapper, self).__init__() - self.bn = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True - ) - def forward(self, x): - x = self.bn(x) - x = x + 1 - return x - torch_module = BatchNormWrapper().eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +# class BatchNormWrapper(nn.Module): +# def __init__(self): +# super(BatchNormWrapper, self).__init__() +# self.bn = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True +# ) +# def forward(self, x): +# x = self.bn(x) +# x = x + 1 +# return x +# torch_module = BatchNormWrapper().eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) # # TODO can combine the tests together (they are separete to know which test fails) @@ -130,85 +128,85 @@ def test_batch_norm0(target, dev): # No momentum, eval, with running stats raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm1(target, dev): - # With momentum, eval - raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm2(target, dev): - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm3(target, dev): - # No momentum, eval - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, 
dtype=None + 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # With momentum, eval - raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm4(target, dev): - # No momentum, eval - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - # With momentum, eval - raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm5(target, dev): - # No momentum, eval, no running stats - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm1(target, dev): +# # With momentum, eval +# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm2(target, dev): +# # Default args, eval +# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d(2).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm3(target, dev): +# # No momentum, eval +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # With momentum, eval +# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # Default args, eval +# raw_data = np.random.randn(4, 2, 2, 
2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d(2).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm4(target, dev): +# # No momentum, eval +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # With momentum, eval +# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# # Default args, eval +# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) +# torch_module0 = nn.BatchNorm2d(2).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm5(target, dev): +# # No momentum, eval, no running stats +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @tvm.testing.parametrize_targets("cuda") def test_batch_norm6(target, dev): # Small input raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) torch_module0 = nn.BatchNorm2d( # TODO what does the 8 do? (feature num) - 8, eps=0.1, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + 8, eps=0.2, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) From 7cb5a56c94866b1a6c33459558efb5f2e7c03267 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 15:53:50 -0400 Subject: [PATCH 28/47] passes --- .../tvm/relax/frontend/torch/exported_program_translator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index e66d0f4ef2d0..a941450642ee 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -60,8 +60,9 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: print("running mean", running_mean) running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) print("running var", running_var) - whatisThis = node.args[5] if len(node.args) > 5 else node.kwargs.get("??????????", "???????") - print("_batch_norm found an whatisThis", whatisThis) + ignore_running_stats = node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) + track_running_stats = not ignore_running_stats + print("_batch_norm found a track_running_stats =", track_running_stats) momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) print("momentum", momentum) # TODO is this affine? 
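# Note: reading args[5] as "not track_running_stats" here matches the training-flag reading
# above for .eval() modules (no running stats forces per-batch normalization), though in
# training mode that flag is True regardless. On the separate "# TODO what does the 8 do?
# (feature num)" question in test_batch_norm6: num_features only sizes the affine parameters
# and the running-stat buffers; with affine=False and track_running_stats=False the module has
# neither, so the value is never checked against the input's channel count and BatchNorm2d(8)
# happily normalizes a 1-channel input with per-batch statistics. A small PyTorch-only
# illustration (independent of the TVM side):
import numpy as np
import torch
from torch import nn

x = torch.from_numpy(np.array([[[[0.5]]], [[[1.5]]]]).astype(np.float32))  # shape (2, 1, 1, 1)

bn = nn.BatchNorm2d(8, eps=0.2, momentum=0.0, affine=False, track_running_stats=False).eval()
print(bn(x))  # runs: there are no 8-element parameters or buffers to collide with 1 channel

bn_strict = nn.BatchNorm2d(8).eval()
# bn_strict(x)  # would raise, since running_mean/weight hold 8 elements but x has 1 channel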
eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) From 536310ab4ea16cd9c8c156c3cb27762647dde7b2 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 15:59:26 -0400 Subject: [PATCH 29/47] need to fix test_batch_norm7 --- .../test_from_exported_batch_norm_only.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index 9084baa0c858..f9aaafb2cb51 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -123,18 +123,18 @@ def test_upsample_with_scale_factor(target, dev): # # TODO can combine the tests together (they are separete to know which test fails) -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm0(target, dev): - # No momentum, eval, with running stats - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm0(target, dev): +# # Eval, no momentum, with affine, without running stats +# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( +# 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # @tvm.testing.parametrize_targets("cuda") # def test_batch_norm1(target, dev): -# # With momentum, eval +# # With momentum, no affine, with running stats # raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) # torch_module0 = nn.BatchNorm2d( # 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None @@ -201,16 +201,26 @@ def test_batch_norm0(target, dev): # ).eval() # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +# @tvm.testing.parametrize_targets("cuda") +# def test_batch_norm6(target, dev): +# # Small input +# raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) +# torch_module0 = nn.BatchNorm2d( # TODO what does the 8 do? (feature num) +# 8, eps=0.2, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None +# ).eval() +# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + @tvm.testing.parametrize_targets("cuda") -def test_batch_norm6(target, dev): - # Small input +def test_batch_norm7(target, dev): + # Eval, small input, no momentum, with affine, with running stats raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) - torch_module0 = nn.BatchNorm2d( # TODO what does the 8 do? 
(feature num) - 8, eps=0.2, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + if __name__ == "__main__": tvm.testing.main() From 4c55f2062f2aae96954ddef27f87bcfe2919a4ec Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Fri, 21 Mar 2025 16:03:56 -0400 Subject: [PATCH 30/47] commented out tests that pass --- .../test_from_exported_batch_norm_only.py | 59 +++++++++++-------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index f9aaafb2cb51..d85b9dd23930 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -123,14 +123,14 @@ def test_upsample_with_scale_factor(target, dev): # # TODO can combine the tests together (they are separete to know which test fails) -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm0(target, dev): -# # Eval, no momentum, with affine, without running stats -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm0(target, dev): + # Eval, no momentum, with affine, without running stats + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # @tvm.testing.parametrize_targets("cuda") # def test_batch_norm1(target, dev): @@ -192,33 +192,42 @@ def test_upsample_with_scale_factor(target, dev): # torch_module0 = nn.BatchNorm2d(2).eval() # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm5(target, dev): + # No momentum, eval, no running stats + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm6(target, dev): + # Small input + raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) + torch_module0 = nn.BatchNorm2d( # TODO what does the 8 do? 
(feature num) + 8, eps=0.2, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + # @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm5(target, dev): -# # No momentum, eval, no running stats -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +# def test_batch_norm7(target, dev): +# # Eval, small input, no momentum, with affine, with running stats +# raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) # torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None +# 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None # ).eval() # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm6(target, dev): -# # Small input +# def test_batch_norm7(target, dev): +# # Eval, small input, no momentum, no affine, with running stats # raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( # TODO what does the 8 do? (feature num) -# 8, eps=0.2, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None +# torch_module0 = nn.BatchNorm2d( +# 2, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None # ).eval() # assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm7(target, dev): - # Eval, small input, no momentum, with affine, with running stats - raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - From e99d659d1753c7e4abb7c39698983717fc5cd159 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 15:25:07 -0400 Subject: [PATCH 31/47] legalize tests --- python/tvm/topi/nn/batch_norm.py | 17 +- .../test_transform_legalize_ops_nn_copy.py | 508 ++++++++++++++++++ 2 files changed, 517 insertions(+), 8 deletions(-) create mode 100644 tests/python/relax/test_transform_legalize_ops_nn_copy.py diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py index ff956d9eef9c..234a4b0d2e48 100644 --- a/python/tvm/topi/nn/batch_norm.py +++ b/python/tvm/topi/nn/batch_norm.py @@ -111,24 +111,25 @@ def batch_norm( shape = [1] * len(data.shape) shape[axis] = data.shape[axis] + + data_mean = topi.sum(data, axis=reduce_axes) / shape_prod + data_mean_rs = topi.reshape(data_mean, shape) + data_var = ( + topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod + ) + data_var_rs = topi.reshape(data_var, shape) + + if training: moving_mean_rs = topi.reshape(moving_mean, shape) moving_var_rs = topi.reshape(moving_var, shape) - - out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon) else: reduce_axes = list(range(len(data.shape))) reduce_axes.remove(axis) shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) - data_mean = topi.sum(data, axis=reduce_axes) / shape_prod - data_mean_rs = topi.reshape(data_mean, shape) - data_var = ( - topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod - ) - data_var_rs = 
topi.reshape(data_var, shape) print("data is", data) print("data_mean_rs is", data_mean_rs) diff --git a/tests/python/relax/test_transform_legalize_ops_nn_copy.py b/tests/python/relax/test_transform_legalize_ops_nn_copy.py new file mode 100644 index 000000000000..ddc9a1283d0c --- /dev/null +++ b/tests/python/relax/test_transform_legalize_ops_nn_copy.py @@ -0,0 +1,508 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# TODO remove +import sys +sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build + +import pytest + +import tvm +import tvm.testing +from tvm.relax.transform import LegalizeOps +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + +##################### Neural network ##################### + + + +def test_batch_norm(): + # fmt: off + @tvm.script.ir_module + class BatchNorm: + @R.function + def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32"), beta: R.Tensor((3,), "float32"), moving_mean: R.Tensor((3,), "float32"), moving_var: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")): + gv: R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func(private=True) + def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3),), "float32"), rxplaceholder_2: T.Buffer((T.int64(3),), "float32"), rxplaceholder_3: T.Buffer((T.int64(3),), "float32"), rxplaceholder_4: T.Buffer((T.int64(3),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), T_add_1: T.Buffer((T.int64(3),), "float32"), T_add_2: T.Buffer((T.int64(3),), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + rxplaceholder_red = T.alloc_buffer((T.int64(3),)) + T_divide = T.alloc_buffer((T.int64(3),)) + T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_red = T.alloc_buffer((T.int64(3),)) + T_divide_1 = T.alloc_buffer((T.int64(3),)) + T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), T.int64(3), 
T.int64(1), T.int64(1))) + T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply_2 = T.alloc_buffer((T.int64(3),)) + T_multiply_3 = T.alloc_buffer((T.int64(3),)) + T_multiply_4 = T.alloc_buffer((T.int64(3),)) + T_subtract_3 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_subtract_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_red_1 = T.alloc_buffer((T.int64(3),)) + T_divide_3 = T.alloc_buffer((T.int64(3),)) + T_multiply_6 = T.alloc_buffer((T.int64(3),)) + for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): + with T.block("rxplaceholder_red"): + v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) + T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) + T.writes(rxplaceholder_red[v_ax0]) + with T.init(): + rxplaceholder_red[v_ax0] = T.float32(0) + rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(rxplaceholder_red[v_ax0]) + T.writes(T_divide[v_ax0]) + T_divide[v_ax0] = rxplaceholder_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_subtract"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_subtract_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_subtract_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_multiply"): + v_ax0, 
v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): + with T.block("T_multiply_red"): + v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) + T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_divide_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): + with T.block("T_reshape_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), 
T.int64(1), T.int64(1)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(rxplaceholder_3[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(rxplaceholder_4[v_ax0]) + T.writes(T_multiply_4[v_ax0]) + T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_subtract_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_subtract_4"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): + with T.block("T_multiply_5"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): + with T.block("T_multiply_red_1"): + v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) + T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red_1[v_ax0]) + with T.init(): + 
T_multiply_red_1[v_ax0] = T.float32(0) + T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_red_1[v_ax0]) + T.writes(T_divide_3[v_ax0]) + T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] * T.float32(0.00063775510204081628) + for ax0 in range(T.int64(3)): + with T.block("T_multiply_6"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide_3[v_ax0]) + T.writes(T_multiply_6[v_ax0]) + T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) + T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] + + @R.function + def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")): + gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) + return gv + # fmt: on + + mod = LegalizeOps()(BatchNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_batch_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class BatchNorm: + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() + gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func(private=True) + def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): + T.func_attr({"tir.noalias": True}) + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,)) + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,)) + rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,)) + rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,)) + T_add = T.match_buffer(var_T_add, (n, h, w, c)) + T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) + T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) + # with T.block("root"): + rxplaceholder_red = T.alloc_buffer((h,)) + T_divide = T.alloc_buffer((h,)) + T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((n, h, w, c)) + T_subtract_1 = T.alloc_buffer((n, h, w, c)) + T_subtract_2 = T.alloc_buffer((n, h, w, c)) + T_multiply = T.alloc_buffer((n, h, w, c)) + T_multiply_red = T.alloc_buffer((h,)) + T_divide_1 = T.alloc_buffer((h,)) + T_reshape_1 = 
T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_divide_2 = T.alloc_buffer((n, h, w, c)) + T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((n, h, w, c)) + T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply_2 = T.alloc_buffer((c,)) + T_multiply_3 = T.alloc_buffer((h,)) + T_multiply_4 = T.alloc_buffer((c,)) + T_subtract_3 = T.alloc_buffer((n, h, w, c)) + T_subtract_4 = T.alloc_buffer((n, h, w, c)) + T_multiply_5 = T.alloc_buffer((n, h, w, c)) + T_multiply_red_1 = T.alloc_buffer((h,)) + T_divide_3 = T.alloc_buffer((h,)) + T_multiply_6 = T.alloc_buffer((h,)) + for ax0, k0, k2, k3 in T.grid(h, n, w, c): + with T.block("rxplaceholder_red"): + v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) + T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) + T.writes(rxplaceholder_red[v_ax0]) + with T.init(): + rxplaceholder_red[v_ax0] = T.float32(0) + rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(rxplaceholder_red[v_ax0]) + T.writes(T_divide[v_ax0]) + T_divide[v_ax0] = rxplaceholder_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_subtract"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_subtract_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_subtract_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, 
v_ax1, v_ax2, v_ax3] + for ax0, k0, k2, k3 in T.grid(h, n, w, c): + with T.block("T_multiply_red"): + v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) + T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): + with T.block("T_reshape_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_divide_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): + with T.block("T_reshape_2"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(c): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(rxplaceholder_3[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] + for ax0 in range(h): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] + for ax0 in range(c): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(rxplaceholder_4[v_ax0]) + T.writes(T_multiply_4[v_ax0]) + T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_subtract_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_subtract_4"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): + with T.block("T_multiply_5"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, k0, k2, k3 in T.grid(h, n, w, c): + with T.block("T_multiply_red_1"): + v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) + T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red_1[v_ax0]) + with T.init(): + T_multiply_red_1[v_ax0] = T.float32(0) + T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_3"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_multiply_red_1[v_ax0]) + T.writes(T_divide_3[v_ax0]) + T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] / T.Cast("float32", n * w * c) + for ax0 in range(h): + with T.block("T_multiply_6"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide_3[v_ax0]) + T.writes(T_multiply_6[v_ax0]) + T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_4[v_ax0], 
T_multiply_6[v_ax0]) + T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] + + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32")): + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() + gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) + return gv + # fmt: on + + mod = LegalizeOps()(BatchNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + + + +if __name__ == "__main__": + tvm.testing.main() From 56b3999cebc49a03247c65baf855d8bf6e1a72bc Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 15:31:19 -0400 Subject: [PATCH 32/47] correct calc of data for everyone --- python/tvm/topi/nn/batch_norm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py index 234a4b0d2e48..6225b5aa0852 100644 --- a/python/tvm/topi/nn/batch_norm.py +++ b/python/tvm/topi/nn/batch_norm.py @@ -111,6 +111,9 @@ def batch_norm( shape = [1] * len(data.shape) shape[axis] = data.shape[axis] + reduce_axes = list(range(len(data.shape))) + reduce_axes.remove(axis) + shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) data_mean = topi.sum(data, axis=reduce_axes) / shape_prod data_mean_rs = topi.reshape(data_mean, shape) @@ -127,9 +130,6 @@ def batch_norm( out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon) else: - reduce_axes = list(range(len(data.shape))) - reduce_axes.remove(axis) - shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) print("data is", data) print("data_mean_rs is", data_mean_rs) From fc6b03a9346c2a202e1e9fa17d68feeada9c1faa Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 15:38:19 -0400 Subject: [PATCH 33/47] track running stats is equivalent to training! passes all --- .../torch/exported_program_translator.py | 3 ++ .../test_from_exported_batch_norm_only.py | 49 +++++++++---------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index a941450642ee..bc440fcf11b0 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -69,6 +69,9 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: print("eps", node.args[7]) # TODO that's eps !!!!! 
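# Reference sketch (not part of the patch itself): the eval-mode formula that the
# topi batch_norm change above computes, written as minimal NumPy for an NCHW input.
# The helper name and signature here are illustrative assumptions.
import numpy as np

def batch_norm_infer(x, gamma, beta, moving_mean, moving_var, eps=1e-5):
    # Broadcast the per-channel statistics over the N, H, W axes.
    shape = (1, -1, 1, 1)
    x_hat = (x - moving_mean.reshape(shape)) / np.sqrt(moving_var.reshape(shape) + eps)
    return gamma.reshape(shape) * x_hat + beta.reshape(shape)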
print("node.args[8]", node.args[8]) # TODO remove + if track_running_stats: + training = True + # TODO restore inside = relax.op.nn.batch_norm( data=x, diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index d85b9dd23930..dda3e5c8dbf9 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -102,24 +102,23 @@ def test_upsample_with_scale_factor(target, dev): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) -# # TODO in a program! to make sure dimensions work -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm_prog(target, dev): -# # No momentum, eval -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm_prog(target, dev): + # No momentum, eval, in a pytorch program + raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# class BatchNormWrapper(nn.Module): -# def __init__(self): -# super(BatchNormWrapper, self).__init__() -# self.bn = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True -# ) -# def forward(self, x): -# x = self.bn(x) -# x = x + 1 -# return x -# torch_module = BatchNormWrapper().eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + class BatchNormWrapper(nn.Module): + def __init__(self): + super(BatchNormWrapper, self).__init__() + self.bn = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False + ) + def forward(self, x): + x = self.bn(x) + x = x + 1 + return x + torch_module = BatchNormWrapper().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) # # TODO can combine the tests together (they are separete to know which test fails) @@ -132,14 +131,14 @@ def test_batch_norm0(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm1(target, dev): -# # With momentum, no affine, with running stats -# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm1(target, dev): + # With momentum, no affine, with running stats + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) # @tvm.testing.parametrize_targets("cuda") # def test_batch_norm2(target, dev): From 7139590ab0cf829dbfc3d211c648b50550f85d75 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 15:46:43 -0400 Subject: [PATCH 34/47] all tests pass except for cache size --- .../test_from_exported_batch_norm_only.py | 163 +++++++----------- 1 file changed, 61 insertions(+), 102 deletions(-) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index dda3e5c8dbf9..0c537865df87 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ 
b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -70,42 +70,10 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) - -@tvm.testing.parametrize_targets("cuda") -def test_detach_no_change(target, dev): - # In TVM, detach() is just identity - class DetachTester(nn.Module): - def forward(self, x): - detached = x.detach() - return detached - - raw_data = np.ones((2, 2)).astype(np.float32) - torch_module = DetachTester().eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_upsample_with_scale_factor(target, dev): - """ - The Upsample module can be used with the size arugment or the scale - factor argument but not both. This tests the latter. - """ - batch_size = 2 - channels = 3 - height, width = 32, 32 - - torch_module = Upsample( - size=None, scale_factor=7, mode="nearest", align_corners=None, recompute_scale_factor=True - ) - - raw_data = np.random.rand(batch_size, channels, height, width).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) - - @tvm.testing.parametrize_targets("cuda") def test_batch_norm_prog(target, dev): # No momentum, eval, in a pytorch program - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + raw_data = np.random.randn(8, 2, 4, 4).astype(np.float32) class BatchNormWrapper(nn.Module): def __init__(self): @@ -125,9 +93,9 @@ def forward(self, x): @tvm.testing.parametrize_targets("cuda") def test_batch_norm0(target, dev): # Eval, no momentum, with affine, without running stats - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) + raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None + 3, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @@ -140,56 +108,56 @@ def test_batch_norm1(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm2(target, dev): -# # Default args, eval -# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d(2).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm3(target, dev): -# # No momentum, eval -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # With momentum, eval -# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # Default args, eval -# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d(2).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - -# 
@tvm.testing.parametrize_targets("cuda") -# def test_batch_norm4(target, dev): -# # No momentum, eval -# raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # With momentum, eval -# raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# # Default args, eval -# raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) -# torch_module0 = nn.BatchNorm2d(2).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm2(target, dev): + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm3(target, dev): + # No momentum, eval + raw_data = np.random.randn(8, 1, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 1, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # With momentum, eval + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm4(target, dev): + # No momentum, eval + raw_data = np.random.randn(3, 8, 4, 4).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # With momentum, eval + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + + # Default args, eval + raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) + torch_module0 = nn.BatchNorm2d(2).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @tvm.testing.parametrize_targets("cuda") def test_batch_norm5(target, dev): @@ -209,23 +177,14 @@ def test_batch_norm6(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm7(target, dev): -# # Eval, small input, no momentum, with affine, with running stats -# raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) -# torch_module0 = 
nn.BatchNorm2d( -# 8, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -# @tvm.testing.parametrize_targets("cuda") -# def test_batch_norm7(target, dev): -# # Eval, small input, no momentum, no affine, with running stats -# raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) -# torch_module0 = nn.BatchNorm2d( -# 2, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None -# ).eval() -# assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm7(target, dev): + # Eval, small input, no momentum, no affine, with running stats + raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 1, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) From 267f011254a4b0f3aa7003f8e13497f467bf58db Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 15:52:15 -0400 Subject: [PATCH 35/47] all batch norm only pass! --- .../test_from_exported_batch_norm_only.py | 36 +++---------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index 0c537865df87..59bbe4e3de89 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -118,47 +118,22 @@ def test_batch_norm2(target, dev): @tvm.testing.parametrize_targets("cuda") def test_batch_norm3(target, dev): - # No momentum, eval - raw_data = np.random.randn(8, 1, 4, 4).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 1, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # With momentum, eval - raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + raw_data = np.random.randn(1, 3, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + 3, eps=1e-05, momentum=0.2, affine=False, track_running_stats=True, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - @tvm.testing.parametrize_targets("cuda") def test_batch_norm4(target, dev): # No momentum, eval - raw_data = np.random.randn(3, 8, 4, 4).astype(np.float32) + raw_data = np.random.randn(3, 3, 4, 4).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + 3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # With momentum, eval - raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, 
dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - @tvm.testing.parametrize_targets("cuda") def test_batch_norm5(target, dev): # No momentum, eval, no running stats @@ -186,8 +161,5 @@ def test_batch_norm7(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - - if __name__ == "__main__": tvm.testing.main() From 7a5cadd9f057e845f878b0ff09f9a07f94dd1130 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 16:00:44 -0400 Subject: [PATCH 36/47] all exported tests work, moved to main script --- .../test_from_exported_batch_norm_only.py | 65 ++++----------- .../relax/test_from_exported_to_cuda.py | 81 ++++++++----------- 2 files changed, 46 insertions(+), 100 deletions(-) diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index 59bbe4e3de89..c1867c8eec7d 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -72,15 +72,13 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar @tvm.testing.parametrize_targets("cuda") def test_batch_norm_prog(target, dev): - # No momentum, eval, in a pytorch program - raw_data = np.random.randn(8, 2, 4, 4).astype(np.float32) + # Default args, in a pytorch program (to ensure output is in proper type and format) + raw_data = np.random.randn(2, 3, 2, 2).astype(np.float32) class BatchNormWrapper(nn.Module): def __init__(self): super(BatchNormWrapper, self).__init__() - self.bn = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False - ) + self.bn = nn.BatchNorm2d(3) def forward(self, x): x = self.bn(x) x = x + 1 @@ -92,74 +90,39 @@ def forward(self, x): # # TODO can combine the tests together (they are separete to know which test fails) @tvm.testing.parametrize_targets("cuda") def test_batch_norm0(target, dev): - # Eval, no momentum, with affine, without running stats + # Eval, no momentum, no affine, no running stats raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 3, eps=1e-02, momentum=0.0, affine=True, track_running_stats=False, device=None, dtype=None + 3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @tvm.testing.parametrize_targets("cuda") def test_batch_norm1(target, dev): - # With momentum, no affine, with running stats + # Eval, with momentum, no affine, with running stats raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + 4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @tvm.testing.parametrize_targets("cuda") def test_batch_norm2(target, dev): - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, 
target, dev) - - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm3(target, dev): - # With momentum, eval - raw_data = np.random.randn(1, 3, 2, 2).astype(np.float32) + # Eval, with momentum, affine, no running stats + raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 3, eps=1e-05, momentum=0.2, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() + 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm4(target, dev): - # No momentum, eval - raw_data = np.random.randn(3, 3, 4, 4).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @tvm.testing.parametrize_targets("cuda") -def test_batch_norm5(target, dev): - # No momentum, eval, no running stats - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +def test_batch_norm3(target, dev): + # Eval, no momentum, affine, with running stats + raw_data = np.random.randn(1, 3, 3, 3).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm6(target, dev): - # Small input - raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) - torch_module0 = nn.BatchNorm2d( # TODO what does the 8 do? (feature num) - 8, eps=0.2, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None - ).eval() + 3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm7(target, dev): - # Eval, small input, no momentum, no affine, with running stats - raw_data = np.array([[[[ 0.5]]], [[[1.5]]]]).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 1, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index 669e9c25de28..0b6708191031 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -285,74 +285,57 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) -# TODO can combine the tests together (they are separete to know which test fails) @tvm.testing.parametrize_targets("cuda") -def test_batch_norm(target, dev): - # No momentum, eval - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) +def test_batch_norm_prog(target, dev): + # Default args, in a pytorch program (to ensure output is in proper type and format) + raw_data = np.random.randn(2, 3, 2, 
2).astype(np.float32) -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm(target, dev): - # With momentum, eval - raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - -@tvm.testing.parametrize_targets("cuda") -def test_batch_norm(target, dev): - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + class BatchNormWrapper(nn.Module): + def __init__(self): + super(BatchNormWrapper, self).__init__() + self.bn = nn.BatchNorm2d(3) + def forward(self, x): + x = self.bn(x) + x = x + 1 + return x + torch_module = BatchNormWrapper().eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +# # TODO can combine the tests together (they are separete to know which test fails) @tvm.testing.parametrize_targets("cuda") -def test_batch_norm(target, dev): - # No momentum, eval - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +def test_batch_norm0(target, dev): + # Eval, no momentum, no affine, no running stats + raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + 3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # With momentum, eval +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm1(target, dev): + # Eval, with momentum, no affine, with running stats raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None + 4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True, device=None, dtype=None ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - - @tvm.testing.parametrize_targets("cuda") -def test_batch_norm(target, dev): - # No momentum, eval - raw_data = np.random.randn(8, 8, 4, 4).astype(np.float32) +def test_batch_norm2(target, dev): + # Eval, with momentum, affine, no running stats + raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 8, eps=1e-02, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() + 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # With momentum, eval - raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.0, affine=False, track_running_stats=True, device=None, dtype=None - ).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) - # Default args, eval - raw_data = np.random.randn(4, 2, 2, 2).astype(np.float32) - torch_module0 = nn.BatchNorm2d(2).eval() 
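# Reference sketch (illustrative only, not part of the patch): what the eval-mode tests
# above and below check against. A freshly constructed BatchNorm2d still has its initial
# running statistics (mean 0, variance 1) and affine parameters (gamma 1, beta 0), so the
# default-argument case reduces to x / sqrt(1 + eps); with track_running_stats=False,
# PyTorch normalizes with the batch statistics over the (N, H, W) axes even in eval mode.
import numpy as np

def expected_batch_norm_eval(x, eps=1e-5, track_running_stats=True):
    if track_running_stats:
        # Initial running stats: mean 0, biased variance 1.
        return x / np.sqrt(1.0 + eps)
    # No running stats: per-channel batch statistics (biased variance).
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)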
+@tvm.testing.parametrize_targets("cuda") +def test_batch_norm3(target, dev): + # Eval, no momentum, affine, with running stats + raw_data = np.random.randn(1, 3, 3, 3).astype(np.float32) + torch_module0 = nn.BatchNorm2d( + 3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) From e76ab8c60da70e65d8ebb3875477c6bf187dd9c3 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 16:21:56 -0400 Subject: [PATCH 37/47] need to fix legalize tests --- .../test_transform_legalize_ops_nn_copy.py | 480 +++++++++--------- 1 file changed, 240 insertions(+), 240 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_nn_copy.py b/tests/python/relax/test_transform_legalize_ops_nn_copy.py index ddc9a1283d0c..3969c33730b7 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn_copy.py +++ b/tests/python/relax/test_transform_legalize_ops_nn_copy.py @@ -257,249 +257,249 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dty tvm.ir.assert_structural_equal(mod, Expected) -def test_batch_norm_symbolic(): - # fmt: off - @tvm.script.ir_module - class BatchNorm: - @R.function - def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): - n = T.int64() - h = T.int64() - w = T.int64() - c = T.int64() - gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) - return gv +# def test_batch_norm_symbolic(): +# # fmt: off +# @tvm.script.ir_module +# class BatchNorm: +# @R.function +# def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): +# n = T.int64() +# h = T.int64() +# w = T.int64() +# c = T.int64() +# gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) +# return gv - @tvm.script.ir_module - class Expected: - @T.prim_func(private=True) - def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): - T.func_attr({"tir.noalias": True}) - n = T.int64() - h = T.int64() - w = T.int64() - c = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,)) - rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,)) - rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,)) - rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,)) - T_add = T.match_buffer(var_T_add, (n, h, w, c)) - T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) - T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) - # with T.block("root"): - rxplaceholder_red = T.alloc_buffer((h,)) - T_divide = T.alloc_buffer((h,)) - T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), 
T.int64(1))) - T_subtract = T.alloc_buffer((n, h, w, c)) - T_subtract_1 = T.alloc_buffer((n, h, w, c)) - T_subtract_2 = T.alloc_buffer((n, h, w, c)) - T_multiply = T.alloc_buffer((n, h, w, c)) - T_multiply_red = T.alloc_buffer((h,)) - T_divide_1 = T.alloc_buffer((h,)) - T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_divide_2 = T.alloc_buffer((n, h, w, c)) - T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((n, h, w, c)) - T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_multiply_2 = T.alloc_buffer((c,)) - T_multiply_3 = T.alloc_buffer((h,)) - T_multiply_4 = T.alloc_buffer((c,)) - T_subtract_3 = T.alloc_buffer((n, h, w, c)) - T_subtract_4 = T.alloc_buffer((n, h, w, c)) - T_multiply_5 = T.alloc_buffer((n, h, w, c)) - T_multiply_red_1 = T.alloc_buffer((h,)) - T_divide_3 = T.alloc_buffer((h,)) - T_multiply_6 = T.alloc_buffer((h,)) - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("rxplaceholder_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) - T.writes(rxplaceholder_red[v_ax0]) - with T.init(): - rxplaceholder_red[v_ax0] = T.float32(0) - rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(rxplaceholder_red[v_ax0]) - T.writes(T_divide[v_ax0]) - T_divide[v_ax0] = rxplaceholder_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2, v_ax3 = 
T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("T_multiply_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_divide_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % 
c]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(c): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(rxplaceholder_3[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] - for ax0 in range(h): - with T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_divide[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] - for ax0 in range(c): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(rxplaceholder_4[v_ax0]) - T.writes(T_multiply_4[v_ax0]) - T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_4"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply_5"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("T_multiply_red_1"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red_1[v_ax0]) - with T.init(): - T_multiply_red_1[v_ax0] = T.float32(0) - T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_3"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_multiply_red_1[v_ax0]) - T.writes(T_divide_3[v_ax0]) - T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] / T.Cast("float32", n * w * c) - for ax0 in range(h): - with T.block("T_multiply_6"): - v_ax0 = 
T.axis.spatial(h, ax0) - T.reads(T_divide_3[v_ax0]) - T.writes(T_multiply_6[v_ax0]) - T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] +# @tvm.script.ir_module +# class Expected: +# @T.prim_func(private=True) +# def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): +# T.func_attr({"tir.noalias": True}) +# n = T.int64() +# h = T.int64() +# w = T.int64() +# c = T.int64() +# rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c)) +# rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,)) +# rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,)) +# rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,)) +# rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,)) +# T_add = T.match_buffer(var_T_add, (n, h, w, c)) +# T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) +# T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) +# # with T.block("root"): +# rxplaceholder_red = T.alloc_buffer((h,)) +# T_divide = T.alloc_buffer((h,)) +# T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) +# T_subtract = T.alloc_buffer((n, h, w, c)) +# T_subtract_1 = T.alloc_buffer((n, h, w, c)) +# T_subtract_2 = T.alloc_buffer((n, h, w, c)) +# T_multiply = T.alloc_buffer((n, h, w, c)) +# T_multiply_red = T.alloc_buffer((h,)) +# T_divide_1 = T.alloc_buffer((h,)) +# T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) +# T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) +# compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) +# T_divide_2 = T.alloc_buffer((n, h, w, c)) +# T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) +# T_multiply_1 = T.alloc_buffer((n, h, w, c)) +# T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) +# T_multiply_2 = T.alloc_buffer((c,)) +# T_multiply_3 = T.alloc_buffer((h,)) +# T_multiply_4 = T.alloc_buffer((c,)) +# T_subtract_3 = T.alloc_buffer((n, h, w, c)) +# T_subtract_4 = T.alloc_buffer((n, h, w, c)) +# T_multiply_5 = T.alloc_buffer((n, h, w, c)) +# T_multiply_red_1 = T.alloc_buffer((h,)) +# T_divide_3 = T.alloc_buffer((h,)) +# T_multiply_6 = T.alloc_buffer((h,)) +# for ax0, k0, k2, k3 in T.grid(h, n, w, c): +# with T.block("rxplaceholder_red"): +# v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) +# T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) +# T.writes(rxplaceholder_red[v_ax0]) +# with T.init(): +# rxplaceholder_red[v_ax0] = T.float32(0) +# rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] +# for ax0 in range(h): +# with T.block("T_divide"): +# v_ax0 = T.axis.spatial(h, ax0) +# T.reads(rxplaceholder_red[v_ax0]) +# T.writes(T_divide[v_ax0]) +# T_divide[v_ax0] = rxplaceholder_red[v_ax0] / T.Cast("float32", n * w * c) +# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): +# with T.block("T_reshape"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) +# T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax0 * 
h + v_ax1 + v_ax2 + v_ax3) % h] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_subtract"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_subtract_1"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_subtract_2"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_multiply"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) +# T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] +# for ax0, k0, k2, k3 in T.grid(h, n, w, c): +# with T.block("T_multiply_red"): +# v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) +# T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) +# T.writes(T_multiply_red[v_ax0]) +# with T.init(): +# T_multiply_red[v_ax0] = T.float32(0) +# T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] +# for ax0 in range(h): +# with T.block("T_divide_1"): +# v_ax0 = T.axis.spatial(h, ax0) +# T.reads(T_multiply_red[v_ax0]) +# T.writes(T_divide_1[v_ax0]) +# T_divide_1[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) +# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): +# with T.block("T_reshape_1"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) +# T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] +# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): +# with T.block("T_add"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) +# T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) +# for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): +# with T.block("compute"): +# v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) +# T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) +# T.writes(compute[v_i0, v_i1, v_i2, v_i3]) +# compute[v_i0, v_i1, v_i2, v_i3] = 
T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_divide_2"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): +# with T.block("T_reshape_2"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) +# T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_multiply_1"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): +# with T.block("T_reshape_3"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) +# T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_add_1"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0 in range(c): +# with T.block("T_multiply_2"): +# v_ax0 = T.axis.spatial(c, ax0) +# T.reads(rxplaceholder_3[v_ax0]) +# T.writes(T_multiply_2[v_ax0]) +# T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] +# for ax0 in range(h): +# with T.block("T_multiply_3"): +# v_ax0 = T.axis.spatial(h, ax0) +# T.reads(T_divide[v_ax0]) +# T.writes(T_multiply_3[v_ax0]) +# T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] +# for ax0 in range(T.max(c, h)): +# with T.block("T_add_2"): +# v_ax0 = T.axis.spatial(T.max(c, h), ax0) +# T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) +# T.writes(T_add_1[v_ax0]) +# T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] +# for ax0 in range(c): +# with T.block("T_multiply_4"): +# v_ax0 = T.axis.spatial(c, ax0) +# T.reads(rxplaceholder_4[v_ax0]) +# T.writes(T_multiply_4[v_ax0]) +# T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_subtract_3"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, 
v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_subtract_4"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) +# T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] +# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): +# with T.block("T_multiply_5"): +# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) +# T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) +# T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) +# T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] +# for ax0, k0, k2, k3 in T.grid(h, n, w, c): +# with T.block("T_multiply_red_1"): +# v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) +# T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) +# T.writes(T_multiply_red_1[v_ax0]) +# with T.init(): +# T_multiply_red_1[v_ax0] = T.float32(0) +# T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] +# for ax0 in range(h): +# with T.block("T_divide_3"): +# v_ax0 = T.axis.spatial(h, ax0) +# T.reads(T_multiply_red_1[v_ax0]) +# T.writes(T_divide_3[v_ax0]) +# T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] / T.Cast("float32", n * w * c) +# for ax0 in range(h): +# with T.block("T_multiply_6"): +# v_ax0 = T.axis.spatial(h, ax0) +# T.reads(T_divide_3[v_ax0]) +# T.writes(T_multiply_6[v_ax0]) +# T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] +# for ax0 in range(T.max(c, h)): +# with T.block("T_add_3"): +# v_ax0 = T.axis.spatial(T.max(c, h), ax0) +# T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) +# T.writes(T_add_2[v_ax0]) +# T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] - @R.function - def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32")): - n = T.int64() - h = T.int64() - w = T.int64() - c = T.int64() - gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) - return gv - # fmt: on +# @R.function +# def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32")): +# n = T.int64() +# h = T.int64() +# w = T.int64() +# c = T.int64() +# gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) +# return gv +# # fmt: on - mod = LegalizeOps()(BatchNorm) - 
tvm.ir.assert_structural_equal(mod, Expected) +# mod = LegalizeOps()(BatchNorm) +# tvm.ir.assert_structural_equal(mod, Expected) From 54f00f1daa523177a3b85f65b4ad42482e496cb3 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 20:20:24 -0400 Subject: [PATCH 38/47] first legalize test passes --- .../test_transform_legalize_ops_nn_copy.py | 491 ++++++++++-------- 1 file changed, 282 insertions(+), 209 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_nn_copy.py b/tests/python/relax/test_transform_legalize_ops_nn_copy.py index 3969c33730b7..7ee0a6cb40cc 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn_copy.py +++ b/tests/python/relax/test_transform_legalize_ops_nn_copy.py @@ -30,8 +30,6 @@ ##################### Neural network ##################### - - def test_batch_norm(): # fmt: off @tvm.script.ir_module @@ -44,212 +42,289 @@ def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32" @tvm.script.ir_module class Expected: @T.prim_func(private=True) - def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3),), "float32"), rxplaceholder_2: T.Buffer((T.int64(3),), "float32"), rxplaceholder_3: T.Buffer((T.int64(3),), "float32"), rxplaceholder_4: T.Buffer((T.int64(3),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), T_add_1: T.Buffer((T.int64(3),), "float32"), T_add_2: T.Buffer((T.int64(3),), "float32")): - T.func_attr({"tir.noalias": True}) - # with T.block("root"): - rxplaceholder_red = T.alloc_buffer((T.int64(3),)) - T_divide = T.alloc_buffer((T.int64(3),)) - T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_red = T.alloc_buffer((T.int64(3),)) - T_divide_1 = T.alloc_buffer((T.int64(3),)) - T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply_2 = T.alloc_buffer((T.int64(3),)) - T_multiply_3 = T.alloc_buffer((T.int64(3),)) - T_multiply_4 = T.alloc_buffer((T.int64(3),)) - T_subtract_3 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_red_1 = T.alloc_buffer((T.int64(3),)) - T_divide_3 = T.alloc_buffer((T.int64(3),)) - T_multiply_6 = T.alloc_buffer((T.int64(3),)) - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("rxplaceholder_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) - 
T.writes(rxplaceholder_red[v_ax0]) - with T.init(): - rxplaceholder_red[v_ax0] = T.float32(0) - rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_red[v_ax0]) - T.writes(T_divide[v_ax0]) - T_divide[v_ax0] = rxplaceholder_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("T_multiply_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_1"): - v_ax0, v_ax1, v_ax2, 
v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_divide_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_3[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] - for ax0 in range(T.int64(3)): - with 
T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_4[v_ax0]) - T.writes(T_multiply_4[v_ax0]) - T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_4"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply_5"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("T_multiply_red_1"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red_1[v_ax0]) - with T.init(): - T_multiply_red_1[v_ax0] = T.float32(0) - T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_red_1[v_ax0]) - T.writes(T_divide_3[v_ax0]) - T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] * T.float32(0.00063775510204081628) - for ax0 in range(T.int64(3)): - with T.block("T_multiply_6"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide_3[v_ax0]) - T.writes(T_multiply_6[v_ax0]) - T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] - + def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + x = 
T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + gamma = T.match_buffer(var_gamma, (T.int64(3),)) + beta = T.match_buffer(var_beta, (T.int64(3),)) + moving_mean = T.match_buffer(var_moving_mean, (T.int64(3),)) + moving_var = T.match_buffer(var_moving_var, (T.int64(3),)) + T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),)) + T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),)) + with T.block("root"): + T.reads() + T.writes() + T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_divide = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((T.int64(3),)) + x_red = T.alloc_buffer((T.int64(3),)) + T_divide_1 = T.alloc_buffer((T.int64(3),)) + T_multiply_2 = T.alloc_buffer((T.int64(3),)) + T_multiply_3 = T.alloc_buffer((T.int64(3),)) + T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_red = T.alloc_buffer((T.int64(3),)) + T_divide_2 = T.alloc_buffer((T.int64(3),)) + T_multiply_5 = T.alloc_buffer((T.int64(3),)) + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_1"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + 
T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0 in range(T.int64(1)): + for i1 in range(T.int64(3)): + for i2 in range(T.int64(1)): + for i3 in range(T.int64(1)): + with T.block("compute"): + v_i0 = T.axis.spatial(T.int64(1), i0) + v_i1 = T.axis.spatial(T.int64(3), i1) + v_i2 = T.axis.spatial(T.int64(1), i2) + v_i3 = T.axis.spatial(T.int64(1), i3) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_2"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_3"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_add_1"): + v_ax0 = 
T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_1"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(moving_mean[v_ax0]) + T.writes(T_multiply_1[v_ax0]) + T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] + for ax0 in range(T.int64(3)): + for k0 in range(T.int64(2)): + for k2 in range(T.int64(28)): + for k3 in range(T.int64(28)): + with T.block("x_red"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + v_k0 = T.axis.reduce(T.int64(2), k0) + v_k2 = T.axis.reduce(T.int64(28), k2) + v_k3 = T.axis.reduce(T.int64(28), k3) + T.reads(x[v_k0, v_ax0, v_k2, v_k3]) + T.writes(x_red[v_ax0]) + with T.init(): + x_red[v_ax0] = T.float32(0.0) + x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(x_red[v_ax0]) + T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = x_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0 in range(T.int64(3)): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide_1[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(moving_var[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_4"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract_1"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract_2"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = 
T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0 in range(T.int64(3)): + for k0 in range(T.int64(2)): + for k2 in range(T.int64(28)): + for k3 in range(T.int64(28)): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + v_k0 = T.axis.reduce(T.int64(2), k0) + v_k2 = T.axis.reduce(T.int64(28), k2) + v_k3 = T.axis.reduce(T.int64(28), k3) + T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0.0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_2[v_ax0]) + T_divide_2[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0 in range(T.int64(3)): + with T.block("T_multiply_5"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide_2[v_ax0]) + T.writes(T_multiply_5[v_ax0]) + T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) + T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] + @R.function def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")): - gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) + cls = Expected + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) return gv # fmt: on @@ -257,6 +332,7 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dty tvm.ir.assert_structural_equal(mod, Expected) + # def test_batch_norm_symbolic(): # # fmt: off # @tvm.script.ir_module @@ -501,8 +577,5 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dty # mod = LegalizeOps()(BatchNorm) # tvm.ir.assert_structural_equal(mod, Expected) - - - if __name__ == "__main__": tvm.testing.main() From d74cfbfb561552b58f7d930b7126c2144042d5a7 Mon Sep 
17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 20:24:05 -0400 Subject: [PATCH 39/47] all legalize pass --- .../relax/test_transform_legalize_ops_nn.py | 986 ++++++++++-------- .../test_transform_legalize_ops_nn_copy.py | 547 +++++----- 2 files changed, 871 insertions(+), 662 deletions(-) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index d83d0567e482..392183cbd383 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import sys +sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build import pytest @@ -1942,7 +1944,6 @@ def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: mod = LegalizeOps()(CrossEntropyWithLogits) tvm.ir.assert_structural_equal(mod, Expected) - def test_batch_norm(): # fmt: off @tvm.script.ir_module @@ -1955,212 +1956,289 @@ def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32" @tvm.script.ir_module class Expected: @T.prim_func(private=True) - def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3),), "float32"), rxplaceholder_2: T.Buffer((T.int64(3),), "float32"), rxplaceholder_3: T.Buffer((T.int64(3),), "float32"), rxplaceholder_4: T.Buffer((T.int64(3),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), T_add_1: T.Buffer((T.int64(3),), "float32"), T_add_2: T.Buffer((T.int64(3),), "float32")): - T.func_attr({"tir.noalias": True}) - # with T.block("root"): - rxplaceholder_red = T.alloc_buffer((T.int64(3),)) - T_divide = T.alloc_buffer((T.int64(3),)) - T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_red = T.alloc_buffer((T.int64(3),)) - T_divide_1 = T.alloc_buffer((T.int64(3),)) - T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply_2 = T.alloc_buffer((T.int64(3),)) - T_multiply_3 = T.alloc_buffer((T.int64(3),)) - T_multiply_4 = T.alloc_buffer((T.int64(3),)) - T_subtract_3 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_red_1 = T.alloc_buffer((T.int64(3),)) - T_divide_3 = T.alloc_buffer((T.int64(3),)) - T_multiply_6 = T.alloc_buffer((T.int64(3),)) - for ax0, k0, k2, k3 in T.grid(T.int64(3), 
T.int64(2), T.int64(28), T.int64(28)): - with T.block("rxplaceholder_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) - T.writes(rxplaceholder_red[v_ax0]) - with T.init(): - rxplaceholder_red[v_ax0] = T.float32(0) - rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_red[v_ax0]) - T.writes(T_divide[v_ax0]) - T_divide[v_ax0] = rxplaceholder_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("T_multiply_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = 
T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_divide_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - 
T.reads(rxplaceholder_3[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_4[v_ax0]) - T.writes(T_multiply_4[v_ax0]) - T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_4"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply_5"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("T_multiply_red_1"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red_1[v_ax0]) - with T.init(): - T_multiply_red_1[v_ax0] = T.float32(0) - T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_red_1[v_ax0]) - T.writes(T_divide_3[v_ax0]) - T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] * T.float32(0.00063775510204081628) - for ax0 in range(T.int64(3)): - with T.block("T_multiply_6"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide_3[v_ax0]) - T.writes(T_multiply_6[v_ax0]) - T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] - + def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: 
T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + x = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + gamma = T.match_buffer(var_gamma, (T.int64(3),)) + beta = T.match_buffer(var_beta, (T.int64(3),)) + moving_mean = T.match_buffer(var_moving_mean, (T.int64(3),)) + moving_var = T.match_buffer(var_moving_var, (T.int64(3),)) + T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),)) + T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),)) + with T.block("root"): + T.reads() + T.writes() + T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_divide = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((T.int64(3),)) + x_red = T.alloc_buffer((T.int64(3),)) + T_divide_1 = T.alloc_buffer((T.int64(3),)) + T_multiply_2 = T.alloc_buffer((T.int64(3),)) + T_multiply_3 = T.alloc_buffer((T.int64(3),)) + T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_red = T.alloc_buffer((T.int64(3),)) + T_divide_2 = T.alloc_buffer((T.int64(3),)) + T_multiply_5 = T.alloc_buffer((T.int64(3),)) + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_1"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = 
T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0 in range(T.int64(1)): + for i1 in range(T.int64(3)): + for i2 in range(T.int64(1)): + for i3 in range(T.int64(1)): + with T.block("compute"): + v_i0 = T.axis.spatial(T.int64(1), i0) + v_i1 = T.axis.spatial(T.int64(3), i1) + v_i2 = T.axis.spatial(T.int64(1), i2) + v_i3 = T.axis.spatial(T.int64(1), i3) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_2"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_3"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax1 + v_ax2 + v_ax3) % 
T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_add_1"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_1"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(moving_mean[v_ax0]) + T.writes(T_multiply_1[v_ax0]) + T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] + for ax0 in range(T.int64(3)): + for k0 in range(T.int64(2)): + for k2 in range(T.int64(28)): + for k3 in range(T.int64(28)): + with T.block("x_red"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + v_k0 = T.axis.reduce(T.int64(2), k0) + v_k2 = T.axis.reduce(T.int64(28), k2) + v_k3 = T.axis.reduce(T.int64(28), k3) + T.reads(x[v_k0, v_ax0, v_k2, v_k3]) + T.writes(x_red[v_ax0]) + with T.init(): + x_red[v_ax0] = T.float32(0.0) + x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(x_red[v_ax0]) + T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = x_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0 in range(T.int64(3)): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide_1[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(moving_var[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_4"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract_1"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + 
for ax3 in range(T.int64(28)): + with T.block("T_subtract_2"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0 in range(T.int64(3)): + for k0 in range(T.int64(2)): + for k2 in range(T.int64(28)): + for k3 in range(T.int64(28)): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + v_k0 = T.axis.reduce(T.int64(2), k0) + v_k2 = T.axis.reduce(T.int64(28), k2) + v_k3 = T.axis.reduce(T.int64(28), k3) + T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0.0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_2[v_ax0]) + T_divide_2[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0 in range(T.int64(3)): + with T.block("T_multiply_5"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide_2[v_ax0]) + T.writes(T_multiply_5[v_ax0]) + T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) + T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] + @R.function def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")): - gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) + cls = Expected + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) return gv # fmt: on @@ -2168,6 +2246,7 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dty tvm.ir.assert_structural_equal(mod, Expected) + def test_batch_norm_symbolic(): # fmt: off @tvm.script.ir_module @@ -2184,230 +2263,295 @@ def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), " @tvm.script.ir_module class 
Expected: @T.prim_func(private=True) - def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): - T.func_attr({"tir.noalias": True}) - n = T.int64() - h = T.int64() - w = T.int64() - c = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,)) - rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,)) - rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,)) - rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,)) + def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64() + x = T.match_buffer(var_x, (n, h, w, c)) + gamma = T.match_buffer(var_gamma, (c,)) + beta = T.match_buffer(var_beta, (c,)) + moving_mean = T.match_buffer(var_moving_mean, (c,)) + moving_var = T.match_buffer(var_moving_var, (c,)) T_add = T.match_buffer(var_T_add, (n, h, w, c)) T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) - # with T.block("root"): - rxplaceholder_red = T.alloc_buffer((h,)) - T_divide = T.alloc_buffer((h,)) - T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_subtract = T.alloc_buffer((n, h, w, c)) - T_subtract_1 = T.alloc_buffer((n, h, w, c)) - T_subtract_2 = T.alloc_buffer((n, h, w, c)) - T_multiply = T.alloc_buffer((n, h, w, c)) - T_multiply_red = T.alloc_buffer((h,)) - T_divide_1 = T.alloc_buffer((h,)) - T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_divide_2 = T.alloc_buffer((n, h, w, c)) - T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((n, h, w, c)) - T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_multiply_2 = T.alloc_buffer((c,)) - T_multiply_3 = T.alloc_buffer((h,)) - T_multiply_4 = T.alloc_buffer((c,)) - T_subtract_3 = T.alloc_buffer((n, h, w, c)) - T_subtract_4 = T.alloc_buffer((n, h, w, c)) - T_multiply_5 = T.alloc_buffer((n, h, w, c)) - T_multiply_red_1 = T.alloc_buffer((h,)) - T_divide_3 = T.alloc_buffer((h,)) - T_multiply_6 = T.alloc_buffer((h,)) - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("rxplaceholder_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) - T.writes(rxplaceholder_red[v_ax0]) - with T.init(): - rxplaceholder_red[v_ax0] = T.float32(0) - rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(rxplaceholder_red[v_ax0]) - T.writes(T_divide[v_ax0]) - T_divide[v_ax0] = rxplaceholder_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, 
v_ax3] = T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("T_multiply_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for 
ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_divide_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(c): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(rxplaceholder_3[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] - for ax0 in range(h): - with T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_divide[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] - for ax0 in range(c): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(rxplaceholder_4[v_ax0]) - T.writes(T_multiply_4[v_ax0]) - T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, 
h, w, c): - with T.block("T_subtract_4"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply_5"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("T_multiply_red_1"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red_1[v_ax0]) - with T.init(): - T_multiply_red_1[v_ax0] = T.float32(0) - T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_3"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_multiply_red_1[v_ax0]) - T.writes(T_divide_3[v_ax0]) - T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] / T.Cast("float32", n * w * c) - for ax0 in range(h): - with T.block("T_multiply_6"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_divide_3[v_ax0]) - T.writes(T_multiply_6[v_ax0]) - T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] - - @R.function - def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32")): + with T.block("root"): + T.reads() + T.writes() + T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((n, h, w, c)) + T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_divide = T.alloc_buffer((n, h, w, c)) + T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply = T.alloc_buffer((n, h, w, c)) + T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((c,)) + x_red = T.alloc_buffer((h,)) + T_divide_1 = T.alloc_buffer((h,)) + T_multiply_2 = T.alloc_buffer((h,)) + T_multiply_3 = T.alloc_buffer((c,)) + T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_subtract_1 = T.alloc_buffer((n, h, w, c)) + T_subtract_2 = T.alloc_buffer((n, h, w, c)) + T_multiply_4 = T.alloc_buffer((n, h, w, c)) + T_multiply_red = T.alloc_buffer((h,)) + T_divide_2 = T.alloc_buffer((h,)) + T_multiply_5 = T.alloc_buffer((h,)) + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in 
range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_1"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0 in range(T.int64(1)): + for i1 in range(h): + for i2 in range(T.int64(1)): + for i3 in range(T.int64(1)): + with T.block("compute"): + v_i0 = T.axis.spatial(T.int64(1), i0) + v_i1 = T.axis.spatial(h, i1) + v_i2 = T.axis.spatial(T.int64(1), i2) + v_i3 = T.axis.spatial(T.int64(1), i3) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_2"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + 
for ax3 in range(c): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_3"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_add_1"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(c): + with T.block("T_multiply_1"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(moving_mean[v_ax0]) + T.writes(T_multiply_1[v_ax0]) + T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] + for ax0 in range(h): + for k0 in range(n): + for k2 in range(w): + for k3 in range(c): + with T.block("x_red"): + v_ax0 = T.axis.spatial(h, ax0) + v_k0 = T.axis.reduce(n, k0) + v_k2 = T.axis.reduce(w, k2) + v_k3 = T.axis.reduce(c, k3) + T.reads(x[v_k0, v_ax0, v_k2, v_k3]) + T.writes(x_red[v_ax0]) + with T.init(): + x_red[v_ax0] = T.float32(0.0) + x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(x_red[v_ax0]) + T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0 in range(h): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide_1[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] + for ax0 in range(c): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(moving_var[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_4"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) + T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + 
v_ax2 + v_ax3) % h] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract_1"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract_2"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0 in range(h): + for k0 in range(n): + for k2 in range(w): + for k3 in range(c): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(h, ax0) + v_k0 = T.axis.reduce(n, k0) + v_k2 = T.axis.reduce(w, k2) + v_k3 = T.axis.reduce(c, k3) + T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0.0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_2"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_2[v_ax0]) + T_divide_2[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0 in range(h): + with T.block("T_multiply_5"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide_2[v_ax0]) + T.writes(T_multiply_5[v_ax0]) + T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) + T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] + + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32")): n = T.int64() h = T.int64() w = T.int64() c = T.int64() - gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) + cls = Expected + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, 
moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) return gv - # fmt: on mod = LegalizeOps()(BatchNorm) tvm.ir.assert_structural_equal(mod, Expected) diff --git a/tests/python/relax/test_transform_legalize_ops_nn_copy.py b/tests/python/relax/test_transform_legalize_ops_nn_copy.py index 7ee0a6cb40cc..27f0e4bb6c23 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn_copy.py +++ b/tests/python/relax/test_transform_legalize_ops_nn_copy.py @@ -333,249 +333,314 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dty -# def test_batch_norm_symbolic(): -# # fmt: off -# @tvm.script.ir_module -# class BatchNorm: -# @R.function -# def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): -# n = T.int64() -# h = T.int64() -# w = T.int64() -# c = T.int64() -# gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) -# return gv - -# @tvm.script.ir_module -# class Expected: -# @T.prim_func(private=True) -# def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): -# T.func_attr({"tir.noalias": True}) -# n = T.int64() -# h = T.int64() -# w = T.int64() -# c = T.int64() -# rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c)) -# rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,)) -# rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,)) -# rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,)) -# rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,)) -# T_add = T.match_buffer(var_T_add, (n, h, w, c)) -# T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) -# T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) -# # with T.block("root"): -# rxplaceholder_red = T.alloc_buffer((h,)) -# T_divide = T.alloc_buffer((h,)) -# T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) -# T_subtract = T.alloc_buffer((n, h, w, c)) -# T_subtract_1 = T.alloc_buffer((n, h, w, c)) -# T_subtract_2 = T.alloc_buffer((n, h, w, c)) -# T_multiply = T.alloc_buffer((n, h, w, c)) -# T_multiply_red = T.alloc_buffer((h,)) -# T_divide_1 = T.alloc_buffer((h,)) -# T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) -# T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) -# compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) -# T_divide_2 = T.alloc_buffer((n, h, w, c)) -# T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) -# T_multiply_1 = T.alloc_buffer((n, h, w, c)) -# T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) -# T_multiply_2 = T.alloc_buffer((c,)) -# T_multiply_3 = T.alloc_buffer((h,)) -# T_multiply_4 = T.alloc_buffer((c,)) -# T_subtract_3 = T.alloc_buffer((n, h, w, c)) -# T_subtract_4 = T.alloc_buffer((n, h, w, c)) -# T_multiply_5 = T.alloc_buffer((n, h, w, c)) -# T_multiply_red_1 = T.alloc_buffer((h,)) -# T_divide_3 = T.alloc_buffer((h,)) -# T_multiply_6 = T.alloc_buffer((h,)) -# for ax0, k0, k2, k3 in 
T.grid(h, n, w, c): -# with T.block("rxplaceholder_red"): -# v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) -# T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) -# T.writes(rxplaceholder_red[v_ax0]) -# with T.init(): -# rxplaceholder_red[v_ax0] = T.float32(0) -# rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] -# for ax0 in range(h): -# with T.block("T_divide"): -# v_ax0 = T.axis.spatial(h, ax0) -# T.reads(rxplaceholder_red[v_ax0]) -# T.writes(T_divide[v_ax0]) -# T_divide[v_ax0] = rxplaceholder_red[v_ax0] / T.Cast("float32", n * w * c) -# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): -# with T.block("T_reshape"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) -# T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_subtract"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_subtract_1"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_subtract_2"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_multiply"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) -# T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] -# for ax0, k0, k2, k3 in T.grid(h, n, w, c): -# with T.block("T_multiply_red"): -# v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) -# T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) -# T.writes(T_multiply_red[v_ax0]) -# with T.init(): -# T_multiply_red[v_ax0] = T.float32(0) -# T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] -# for ax0 in range(h): -# with T.block("T_divide_1"): -# v_ax0 = T.axis.spatial(h, ax0) -# T.reads(T_multiply_red[v_ax0]) -# T.writes(T_divide_1[v_ax0]) -# T_divide_1[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) -# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): -# with T.block("T_reshape_1"): -# v_ax0, v_ax1, v_ax2, v_ax3 = 
T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) -# T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] -# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): -# with T.block("T_add"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) -# T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) -# for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): -# with T.block("compute"): -# v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) -# T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) -# T.writes(compute[v_i0, v_i1, v_i2, v_i3]) -# compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_divide_2"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): -# with T.block("T_reshape_2"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) -# T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_multiply_1"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): -# with T.block("T_reshape_3"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) -# T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_add_1"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0 in range(c): -# with T.block("T_multiply_2"): -# v_ax0 = T.axis.spatial(c, ax0) -# T.reads(rxplaceholder_3[v_ax0]) -# T.writes(T_multiply_2[v_ax0]) -# T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] -# for ax0 in range(h): -# with T.block("T_multiply_3"): -# v_ax0 = T.axis.spatial(h, ax0) -# T.reads(T_divide[v_ax0]) -# T.writes(T_multiply_3[v_ax0]) -# 
T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] -# for ax0 in range(T.max(c, h)): -# with T.block("T_add_2"): -# v_ax0 = T.axis.spatial(T.max(c, h), ax0) -# T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) -# T.writes(T_add_1[v_ax0]) -# T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] -# for ax0 in range(c): -# with T.block("T_multiply_4"): -# v_ax0 = T.axis.spatial(c, ax0) -# T.reads(rxplaceholder_4[v_ax0]) -# T.writes(T_multiply_4[v_ax0]) -# T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_subtract_3"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_subtract_4"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) -# T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] -# for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): -# with T.block("T_multiply_5"): -# v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) -# T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) -# T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) -# T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] -# for ax0, k0, k2, k3 in T.grid(h, n, w, c): -# with T.block("T_multiply_red_1"): -# v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) -# T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) -# T.writes(T_multiply_red_1[v_ax0]) -# with T.init(): -# T_multiply_red_1[v_ax0] = T.float32(0) -# T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] -# for ax0 in range(h): -# with T.block("T_divide_3"): -# v_ax0 = T.axis.spatial(h, ax0) -# T.reads(T_multiply_red_1[v_ax0]) -# T.writes(T_divide_3[v_ax0]) -# T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] / T.Cast("float32", n * w * c) -# for ax0 in range(h): -# with T.block("T_multiply_6"): -# v_ax0 = T.axis.spatial(h, ax0) -# T.reads(T_divide_3[v_ax0]) -# T.writes(T_multiply_6[v_ax0]) -# T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] -# for ax0 in range(T.max(c, h)): -# with T.block("T_add_3"): -# v_ax0 = T.axis.spatial(T.max(c, h), ax0) -# T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) -# T.writes(T_add_2[v_ax0]) -# T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] +def test_batch_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class BatchNorm: + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() + gv: R.Tuple(R.Tensor((n, h, w, c), 
"float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) + return gv -# @R.function -# def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32")): -# n = T.int64() -# h = T.int64() -# w = T.int64() -# c = T.int64() -# gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) -# return gv -# # fmt: on + @tvm.script.ir_module + class Expected: + @T.prim_func(private=True) + def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64() + x = T.match_buffer(var_x, (n, h, w, c)) + gamma = T.match_buffer(var_gamma, (c,)) + beta = T.match_buffer(var_beta, (c,)) + moving_mean = T.match_buffer(var_moving_mean, (c,)) + moving_var = T.match_buffer(var_moving_var, (c,)) + T_add = T.match_buffer(var_T_add, (n, h, w, c)) + T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) + T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) + with T.block("root"): + T.reads() + T.writes() + T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((n, h, w, c)) + T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_divide = T.alloc_buffer((n, h, w, c)) + T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply = T.alloc_buffer((n, h, w, c)) + T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((c,)) + x_red = T.alloc_buffer((h,)) + T_divide_1 = T.alloc_buffer((h,)) + T_multiply_2 = T.alloc_buffer((h,)) + T_multiply_3 = T.alloc_buffer((c,)) + T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_subtract_1 = T.alloc_buffer((n, h, w, c)) + T_subtract_2 = T.alloc_buffer((n, h, w, c)) + T_multiply_4 = T.alloc_buffer((n, h, w, c)) + T_multiply_red = T.alloc_buffer((h,)) + T_divide_2 = T.alloc_buffer((h,)) + T_multiply_5 = T.alloc_buffer((h,)) + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], 
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_1"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0 in range(T.int64(1)): + for i1 in range(h): + for i2 in range(T.int64(1)): + for i3 in range(T.int64(1)): + with T.block("compute"): + v_i0 = T.axis.spatial(T.int64(1), i0) + v_i1 = T.axis.spatial(h, i1) + v_i2 = T.axis.spatial(T.int64(1), i2) + v_i3 = T.axis.spatial(T.int64(1), i3) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_2"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_3"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = 
T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_add_1"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(c): + with T.block("T_multiply_1"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(moving_mean[v_ax0]) + T.writes(T_multiply_1[v_ax0]) + T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] + for ax0 in range(h): + for k0 in range(n): + for k2 in range(w): + for k3 in range(c): + with T.block("x_red"): + v_ax0 = T.axis.spatial(h, ax0) + v_k0 = T.axis.reduce(n, k0) + v_k2 = T.axis.reduce(w, k2) + v_k3 = T.axis.reduce(c, k3) + T.reads(x[v_k0, v_ax0, v_k2, v_k3]) + T.writes(x_red[v_ax0]) + with T.init(): + x_red[v_ax0] = T.float32(0.0) + x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(x_red[v_ax0]) + T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0 in range(h): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide_1[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] + for ax0 in range(c): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(moving_var[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_4"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) + T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract_1"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract_2"): + v_ax0 = T.axis.spatial(n, 
ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0 in range(h): + for k0 in range(n): + for k2 in range(w): + for k3 in range(c): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(h, ax0) + v_k0 = T.axis.reduce(n, k0) + v_k2 = T.axis.reduce(w, k2) + v_k3 = T.axis.reduce(c, k3) + T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0.0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_2"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_2[v_ax0]) + T_divide_2[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0 in range(h): + with T.block("T_multiply_5"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide_2[v_ax0]) + T.writes(T_multiply_5[v_ax0]) + T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) + T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] + + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32")): + n = T.int64() + h = T.int64() + w = T.int64() + c = T.int64() + cls = Expected + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) + return gv -# mod = LegalizeOps()(BatchNorm) -# tvm.ir.assert_structural_equal(mod, Expected) + mod = LegalizeOps()(BatchNorm) + tvm.ir.assert_structural_equal(mod, Expected) if __name__ == "__main__": tvm.testing.main() From a9de5efcfce9637c449af392292f701512bda128 Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 20:26:14 -0400 Subject: [PATCH 40/47] all tests pass --- .../relax/test_from_exported_to_cuda.py | 4 - .../relax/test_transform_legalize_ops_nn.py | 2 - .../test_transform_legalize_ops_nn_copy.py | 646 ------------------ 3 files changed, 652 deletions(-) delete mode 100644 tests/python/relax/test_transform_legalize_ops_nn_copy.py diff --git a/tests/python/relax/test_from_exported_to_cuda.py 
b/tests/python/relax/test_from_exported_to_cuda.py index 0b6708191031..f83d38273126 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -15,10 +15,6 @@ # specific language governing permissions and limitations # under the License. -# TODO remove -import sys -sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build - import tvm from tvm import relax import tvm.testing diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 392183cbd383..45205c4f2ae4 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys -sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build import pytest diff --git a/tests/python/relax/test_transform_legalize_ops_nn_copy.py b/tests/python/relax/test_transform_legalize_ops_nn_copy.py deleted file mode 100644 index 27f0e4bb6c23..000000000000 --- a/tests/python/relax/test_transform_legalize_ops_nn_copy.py +++ /dev/null @@ -1,646 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# TODO remove -import sys -sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build - -import pytest - -import tvm -import tvm.testing -from tvm.relax.transform import LegalizeOps -from tvm.script import ir as I -from tvm.script import relax as R -from tvm.script import tir as T - -##################### Neural network ##################### - -def test_batch_norm(): - # fmt: off - @tvm.script.ir_module - class BatchNorm: - @R.function - def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32"), beta: R.Tensor((3,), "float32"), moving_mean: R.Tensor((3,), "float32"), moving_var: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")): - gv: R.Tuple(R.Tensor((2, 3, 28, 28), "float32"), R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) - return gv - - @tvm.script.ir_module - class Expected: - @T.prim_func(private=True) - def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - x = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - gamma = T.match_buffer(var_gamma, (T.int64(3),)) - beta = T.match_buffer(var_beta, (T.int64(3),)) - moving_mean = T.match_buffer(var_moving_mean, (T.int64(3),)) - moving_var = T.match_buffer(var_moving_var, (T.int64(3),)) - T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),)) - T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),)) - with T.block("root"): - T.reads() - T.writes() - T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_divide = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((T.int64(3),)) - x_red = T.alloc_buffer((T.int64(3),)) - T_divide_1 = T.alloc_buffer((T.int64(3),)) - T_multiply_2 = T.alloc_buffer((T.int64(3),)) - T_multiply_3 = T.alloc_buffer((T.int64(3),)) - T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_red = T.alloc_buffer((T.int64(3),)) - T_divide_2 = T.alloc_buffer((T.int64(3),)) - T_multiply_5 = T.alloc_buffer((T.int64(3),)) - for ax0 in range(T.int64(1)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - 
T.reads(moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(28)): - for ax3 in range(T.int64(28)): - with T.block("T_subtract"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(28), ax2) - v_ax3 = T.axis.spatial(T.int64(28), ax3) - T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(1)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_1"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0 in range(T.int64(1)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_add"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0 in range(T.int64(1)): - for i1 in range(T.int64(3)): - for i2 in range(T.int64(1)): - for i3 in range(T.int64(1)): - with T.block("compute"): - v_i0 = T.axis.spatial(T.int64(1), i0) - v_i1 = T.axis.spatial(T.int64(3), i1) - v_i2 = T.axis.spatial(T.int64(1), i2) - v_i3 = T.axis.spatial(T.int64(1), i3) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(28)): - for ax3 in range(T.int64(28)): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(28), ax2) - v_ax3 = T.axis.spatial(T.int64(28), ax3) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(1)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_2"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 
in range(T.int64(28)): - for ax3 in range(T.int64(28)): - with T.block("T_multiply"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(28), ax2) - v_ax3 = T.axis.spatial(T.int64(28), ax3) - T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(1)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_3"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(28)): - for ax3 in range(T.int64(28)): - with T.block("T_add_1"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(28), ax2) - v_ax3 = T.axis.spatial(T.int64(28), ax3) - T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_1"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(moving_mean[v_ax0]) - T.writes(T_multiply_1[v_ax0]) - T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] - for ax0 in range(T.int64(3)): - for k0 in range(T.int64(2)): - for k2 in range(T.int64(28)): - for k3 in range(T.int64(28)): - with T.block("x_red"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - v_k0 = T.axis.reduce(T.int64(2), k0) - v_k2 = T.axis.reduce(T.int64(28), k2) - v_k3 = T.axis.reduce(T.int64(28), k3) - T.reads(x[v_k0, v_ax0, v_k2, v_k3]) - T.writes(x_red[v_ax0]) - with T.init(): - x_red[v_ax0] = T.float32(0.0) - x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(x_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = x_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0 in range(T.int64(3)): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide_1[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(moving_var[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] - for ax0 in range(T.int64(1)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_4"): - v_ax0 = 
T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(28)): - for ax3 in range(T.int64(28)): - with T.block("T_subtract_1"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(28), ax2) - v_ax3 = T.axis.spatial(T.int64(28), ax3) - T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(28)): - for ax3 in range(T.int64(28)): - with T.block("T_subtract_2"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(28), ax2) - v_ax3 = T.axis.spatial(T.int64(28), ax3) - T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(2)): - for ax1 in range(T.int64(3)): - for ax2 in range(T.int64(28)): - for ax3 in range(T.int64(28)): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(T.int64(2), ax0) - v_ax1 = T.axis.spatial(T.int64(3), ax1) - v_ax2 = T.axis.spatial(T.int64(28), ax2) - v_ax3 = T.axis.spatial(T.int64(28), ax3) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0 in range(T.int64(3)): - for k0 in range(T.int64(2)): - for k2 in range(T.int64(28)): - for k3 in range(T.int64(28)): - with T.block("T_multiply_red"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - v_k0 = T.axis.reduce(T.int64(2), k0) - v_k2 = T.axis.reduce(T.int64(28), k2) - v_k3 = T.axis.reduce(T.int64(28), k3) - T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0.0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_2[v_ax0]) - T_divide_2[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0 in range(T.int64(3)): - with T.block("T_multiply_5"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide_2[v_ax0]) - T.writes(T_multiply_5[v_ax0]) - T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] - - @R.function - def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), 
gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")): - cls = Expected - gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) - return gv - # fmt: on - - mod = LegalizeOps()(BatchNorm) - tvm.ir.assert_structural_equal(mod, Expected) - - - -def test_batch_norm_symbolic(): - # fmt: off - @tvm.script.ir_module - class BatchNorm: - @R.function - def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), "float32"), beta: R.Tensor(("c",), "float32"), moving_mean: R.Tensor(("c",), "float32"), moving_var: R.Tensor(("c",), "float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), "float32"), R.Tensor(("c",), "float32"), R.Tensor(("c",), "float32")): - n = T.int64() - h = T.int64() - w = T.int64() - c = T.int64() - gv: R.Tuple(R.Tensor((n, h, w, c), "float32"), R.Tensor((c,), "float32"), R.Tensor((c,), "float32")) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1) - return gv - - @tvm.script.ir_module - class Expected: - @T.prim_func(private=True) - def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64() - x = T.match_buffer(var_x, (n, h, w, c)) - gamma = T.match_buffer(var_gamma, (c,)) - beta = T.match_buffer(var_beta, (c,)) - moving_mean = T.match_buffer(var_moving_mean, (c,)) - moving_var = T.match_buffer(var_moving_var, (c,)) - T_add = T.match_buffer(var_T_add, (n, h, w, c)) - T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) - T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) - with T.block("root"): - T.reads() - T.writes() - T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_subtract = T.alloc_buffer((n, h, w, c)) - T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_divide = T.alloc_buffer((n, h, w, c)) - T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_multiply = T.alloc_buffer((n, h, w, c)) - T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((c,)) - x_red = T.alloc_buffer((h,)) - T_divide_1 = T.alloc_buffer((h,)) - T_multiply_2 = T.alloc_buffer((h,)) - T_multiply_3 = T.alloc_buffer((c,)) - T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_subtract_1 = T.alloc_buffer((n, h, w, c)) - T_subtract_2 = T.alloc_buffer((n, h, w, c)) - T_multiply_4 = T.alloc_buffer((n, h, w, c)) - T_multiply_red = T.alloc_buffer((h,)) - T_divide_2 = T.alloc_buffer((h,)) - T_multiply_5 = T.alloc_buffer((h,)) - for ax0 in range(T.int64(1)): - for ax1 in range(h): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - 
T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0 in range(n): - for ax1 in range(h): - for ax2 in range(w): - for ax3 in range(c): - with T.block("T_subtract"): - v_ax0 = T.axis.spatial(n, ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(w, ax2) - v_ax3 = T.axis.spatial(c, ax3) - T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(1)): - for ax1 in range(h): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_1"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0 in range(T.int64(1)): - for ax1 in range(h): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_add"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0 in range(T.int64(1)): - for i1 in range(h): - for i2 in range(T.int64(1)): - for i3 in range(T.int64(1)): - with T.block("compute"): - v_i0 = T.axis.spatial(T.int64(1), i0) - v_i1 = T.axis.spatial(h, i1) - v_i2 = T.axis.spatial(T.int64(1), i2) - v_i3 = T.axis.spatial(T.int64(1), i3) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for ax0 in range(n): - for ax1 in range(h): - for ax2 in range(w): - for ax3 in range(c): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(n, ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(w, ax2) - v_ax3 = T.axis.spatial(c, ax3) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(1)): - for ax1 in range(h): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_2"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0 in range(n): - for ax1 in range(h): - for ax2 in range(w): - for ax3 in range(c): - with T.block("T_multiply"): - v_ax0 = T.axis.spatial(n, ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(w, ax2) - v_ax3 = T.axis.spatial(c, ax3) - T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, 
T.int64(0), T.int64(0)]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(1)): - for ax1 in range(h): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_3"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0 in range(n): - for ax1 in range(h): - for ax2 in range(w): - for ax3 in range(c): - with T.block("T_add_1"): - v_ax0 = T.axis.spatial(n, ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(w, ax2) - v_ax3 = T.axis.spatial(c, ax3) - T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(c): - with T.block("T_multiply_1"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(moving_mean[v_ax0]) - T.writes(T_multiply_1[v_ax0]) - T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] - for ax0 in range(h): - for k0 in range(n): - for k2 in range(w): - for k3 in range(c): - with T.block("x_red"): - v_ax0 = T.axis.spatial(h, ax0) - v_k0 = T.axis.reduce(n, k0) - v_k2 = T.axis.reduce(w, k2) - v_k3 = T.axis.reduce(c, k3) - T.reads(x[v_k0, v_ax0, v_k2, v_k3]) - T.writes(x_red[v_ax0]) - with T.init(): - x_red[v_ax0] = T.float32(0.0) - x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(x_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0 in range(h): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_divide_1[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] - for ax0 in range(c): - with T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(moving_var[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] - for ax0 in range(T.int64(1)): - for ax1 in range(h): - for ax2 in range(T.int64(1)): - for ax3 in range(T.int64(1)): - with T.block("T_reshape_4"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(T.int64(1), ax2) - v_ax3 = T.axis.spatial(T.int64(1), ax3) - T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) - T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] - for ax0 in range(n): - for ax1 in range(h): - for ax2 in range(w): - for ax3 in range(c): - with T.block("T_subtract_1"): - v_ax0 = T.axis.spatial(n, ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(w, ax2) - v_ax3 = 
T.axis.spatial(c, ax3) - T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(n): - for ax1 in range(h): - for ax2 in range(w): - for ax3 in range(c): - with T.block("T_subtract_2"): - v_ax0 = T.axis.spatial(n, ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(w, ax2) - v_ax3 = T.axis.spatial(c, ax3) - T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(n): - for ax1 in range(h): - for ax2 in range(w): - for ax3 in range(c): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(n, ax0) - v_ax1 = T.axis.spatial(h, ax1) - v_ax2 = T.axis.spatial(w, ax2) - v_ax3 = T.axis.spatial(c, ax3) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0 in range(h): - for k0 in range(n): - for k2 in range(w): - for k3 in range(c): - with T.block("T_multiply_red"): - v_ax0 = T.axis.spatial(h, ax0) - v_k0 = T.axis.reduce(n, k0) - v_k2 = T.axis.reduce(w, k2) - v_k3 = T.axis.reduce(c, k3) - T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0.0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_2"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_2[v_ax0]) - T_divide_2[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0 in range(h): - with T.block("T_multiply_5"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_divide_2[v_ax0]) - T.writes(T_multiply_5[v_ax0]) - T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] - - @R.function - def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32")): - n = T.int64() - h = T.int64() - w = T.int64() - c = T.int64() - cls = Expected - gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) - return gv - - mod = LegalizeOps()(BatchNorm) - tvm.ir.assert_structural_equal(mod, Expected) - -if __name__ == "__main__": - tvm.testing.main() From 6458a641b77d664234012ca51fcae7b5c550c13f Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 20:28:12 -0400 Subject: [PATCH 41/47] linting 
--- .../torch/exported_program_translator.py | 62 +++++++++---------- python/tvm/topi/nn/batch_norm.py | 2 - .../test_from_exported_batch_norm_only.py | 24 ++++--- .../relax/test_from_exported_to_cuda.py | 12 +++- .../relax/test_transform_legalize_ops_nn.py | 6 +- 5 files changed, 58 insertions(+), 48 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index bc440fcf11b0..e119d0bf32ee 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -60,43 +60,43 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var: print("running mean", running_mean) running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) print("running var", running_var) - ignore_running_stats = node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) + ignore_running_stats = ( + node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) + ) track_running_stats = not ignore_running_stats print("_batch_norm found a track_running_stats =", track_running_stats) momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) - print("momentum", momentum) # TODO is this affine? + print("momentum", momentum) # TODO is this affine? eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) - print("eps", node.args[7]) # TODO that's eps !!!!! - print("node.args[8]", node.args[8]) # TODO remove + print("eps", node.args[7]) # TODO that's eps !!!!! + print("node.args[8]", node.args[8]) # TODO remove if track_running_stats: training = True # TODO restore inside = relax.op.nn.batch_norm( - data=x, - gamma=weight, - beta=bias, - moving_mean=running_mean, - moving_var=running_var, - axis=1, # Always over channel - epsilon=eps, - momentum=momentum, - training=training, - )[0] - print("type of inside", type(inside)) # - - outside = self.block_builder.emit( - inside - ) - print("type of outside", type(outside)) # + data=x, + gamma=weight, + beta=bias, + moving_mean=running_mean, + moving_var=running_var, + axis=1, # Always over channel + epsilon=eps, + momentum=momentum, + training=training, + )[0] + print("type of inside", type(inside)) # + + outside = self.block_builder.emit(inside) + print("type of outside", type(outside)) # return outside def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: print("Inside batch norm functional") # This method is called for batch_norm in training mode # TODO does not have correctness! 
- # TODO we need to store the running mean and variance returned by the + # TODO we need to store the running mean and variance returned by the # previous call to batch_norm and pass it again training = True return self._batch_norm(node, training) @@ -156,24 +156,20 @@ def _upsample_impl( # ) # ) - inside = relax.op.image.resize2d( - x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans - ) - - print("type of inside", type(inside)) # - - - outside = self.block_builder.emit( - inside + inside = relax.op.image.resize2d( + x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans ) - print("type of outside", type(outside)) # + print("type of inside", type(inside)) # - return outside + outside = self.block_builder.emit(inside) + print("type of outside", type(outside)) # + + return outside def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: - print("Inside upsample bilinear 2d") + print("Inside upsample bilinear 2d") x = self.env[node.args[0]] size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) align_corners = ( diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py index 6225b5aa0852..f8ed0de042e4 100644 --- a/python/tvm/topi/nn/batch_norm.py +++ b/python/tvm/topi/nn/batch_norm.py @@ -122,7 +122,6 @@ def batch_norm( ) data_var_rs = topi.reshape(data_var, shape) - if training: moving_mean_rs = topi.reshape(moving_mean, shape) moving_var_rs = topi.reshape(moving_var, shape) @@ -139,7 +138,6 @@ def batch_norm( out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon) - if scale: out = out * topi.reshape(gamma, shape) if center: diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py index c1867c8eec7d..bc5729455552 100644 --- a/tests/python/relax/test_from_exported_batch_norm_only.py +++ b/tests/python/relax/test_from_exported_batch_norm_only.py @@ -17,7 +17,8 @@ # TODO remove import sys -sys.path.append('/ssd1/htalendr/tvm/python') # Refer to local TVM build + +sys.path.append("/ssd1/htalendr/tvm/python") # Refer to local TVM build import tvm from tvm import relax @@ -61,15 +62,18 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar print("type of pytorch_out", type(pytorch_out)) print("pytorch output shape", pytorch_out.shape) - print("len of gpu_out", len(gpu_out)) # 1 for all tests - print("type of gpu_out[0]", type(gpu_out[0])) # tvm.ir.container.Array for batch norm, tvm.runtime.ndarray.NDArray for both existing tests - print("gpu_out[0] shape", gpu_out[0].shape) # defined for tests that work + print("len of gpu_out", len(gpu_out)) # 1 for all tests + print( + "type of gpu_out[0]", type(gpu_out[0]) + ) # tvm.ir.container.Array for batch norm, tvm.runtime.ndarray.NDArray for both existing tests + print("gpu_out[0] shape", gpu_out[0].shape) # defined for tests that work actual = gpu_out[0].numpy() desired = pytorch_out np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) + @tvm.testing.parametrize_targets("cuda") def test_batch_norm_prog(target, dev): # Default args, in a pytorch program (to ensure output is in proper type and format) @@ -79,10 +83,12 @@ class BatchNormWrapper(nn.Module): def __init__(self): super(BatchNormWrapper, self).__init__() self.bn = nn.BatchNorm2d(3) + def forward(self, x): x = self.bn(x) x = x + 1 - return x + return x + torch_module = BatchNormWrapper().eval() 
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -97,6 +103,7 @@ def test_batch_norm0(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_batch_norm1(target, dev): # Eval, with momentum, no affine, with running stats @@ -106,12 +113,14 @@ def test_batch_norm1(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_batch_norm2(target, dev): # Eval, with momentum, affine, no running stats raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False).eval() + 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False + ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @@ -120,7 +129,8 @@ def test_batch_norm3(target, dev): # Eval, no momentum, affine, with running stats raw_data = np.random.randn(1, 3, 3, 3).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True).eval() + 3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True + ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index f83d38273126..f659162e945d 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -290,10 +290,12 @@ class BatchNormWrapper(nn.Module): def __init__(self): super(BatchNormWrapper, self).__init__() self.bn = nn.BatchNorm2d(3) + def forward(self, x): x = self.bn(x) x = x + 1 - return x + return x + torch_module = BatchNormWrapper().eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -308,6 +310,7 @@ def test_batch_norm0(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_batch_norm1(target, dev): # Eval, with momentum, no affine, with running stats @@ -317,12 +320,14 @@ def test_batch_norm1(target, dev): ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) + @tvm.testing.parametrize_targets("cuda") def test_batch_norm2(target, dev): # Eval, with momentum, affine, no running stats raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False).eval() + 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False + ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) @@ -331,7 +336,8 @@ def test_batch_norm3(target, dev): # Eval, no momentum, affine, with running stats raw_data = np.random.randn(1, 3, 3, 3).astype(np.float32) torch_module0 = nn.BatchNorm2d( - 3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True).eval() + 3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True + ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 45205c4f2ae4..4ac4b57b91d4 100644 --- 
a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -1942,6 +1942,7 @@ def cross_entropy_with_logits(var_rxplaceholder: T.handle, var_rxplaceholder_1: mod = LegalizeOps()(CrossEntropyWithLogits) tvm.ir.assert_structural_equal(mod, Expected) + def test_batch_norm(): # fmt: off @tvm.script.ir_module @@ -2232,7 +2233,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) T.writes(T_add_2[v_ax0]) T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] - + @R.function def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")): cls = Expected @@ -2244,7 +2245,6 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dty tvm.ir.assert_structural_equal(mod, Expected) - def test_batch_norm_symbolic(): # fmt: off @tvm.script.ir_module @@ -2540,7 +2540,7 @@ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_mov T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) T.writes(T_add_2[v_ax0]) T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] - + @R.function def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32")): n = T.int64() From cd8fa7bb95148b0fb5af21eb8a0978fb8285fbae Mon Sep 17 00:00:00 2001 From: Hugo Latendresse Date: Sun, 23 Mar 2025 22:50:33 -0400 Subject: [PATCH 42/47] cleanup --- .../torch/exported_program_translator.py | 45 +----- .../test_from_exported_batch_norm_only.py | 138 ------------------ 2 files changed, 7 insertions(+), 176 deletions(-) delete mode 100644 tests/python/relax/test_from_exported_batch_norm_only.py diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index e119d0bf32ee..370c6283140d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -46,36 +46,27 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: ########## Neural Network ########## def _batch_norm(self, node: fx.Node, training) -> relax.Var: - print("Inside batch norm") import numpy as np x = self.env[node.args[0]] channel = int(self.shape_of(x)[1]) dtype = x.struct_info.dtype weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) - print("weight", weight) bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) - print("bias", bias) running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) - print("running mean", running_mean) running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) - print("running var", running_var) ignore_running_stats = ( node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) ) track_running_stats = not ignore_running_stats - print("_batch_norm found a track_running_stats =", 
track_running_stats)
         momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1)
-        print("momentum", momentum)  # TODO is this affine?
         eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05)
-        print("eps", node.args[7])  # TODO that's eps !!!!!
-        print("node.args[8]", node.args[8])  # TODO remove

         if track_running_stats:
             training = True

-        # TODO restore
-        inside = relax.op.nn.batch_norm(
+        return self.block_builder.emit(
+            relax.op.nn.batch_norm(
             data=x,
             gamma=weight,
             beta=bias,
@@ -86,14 +77,9 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var:
             momentum=momentum,
             training=training,
         )[0]
-        print("type of inside", type(inside))  #
-
-        outside = self.block_builder.emit(inside)
-        print("type of outside", type(outside))  #
-        return outside
+        )

     def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
-        print("Inside batch norm functional")
         # This method is called for batch_norm in training mode
         # TODO does not have correctness!
         # TODO we need to store the running mean and variance returned by the
@@ -102,7 +88,6 @@ def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
         return self._batch_norm(node, training)

     def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
-        print("Inside batch norm no training")
         # This method is called for batch_norm in eval mode
         training = False
         return self._batch_norm(node, training)
@@ -135,7 +120,6 @@ def _upsample_impl(
         method: str,
         align_corners: bool,
     ) -> relax.Var:
-        print("Inside upsample impl")
         coord_trans = "align_corners" if align_corners else "half_pixel"

         if size is None:
@@ -149,27 +133,13 @@ def _upsample_impl(
         else:
             size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))

-        # TODO restore
-        # return self.block_builder.emit(
-        #     relax.op.image.resize2d(
-        #         x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
-        #     )
-        # )
-
-        inside = relax.op.image.resize2d(
-            x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
+        return self.block_builder.emit(
+            relax.op.image.resize2d(
+                x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
+            )
         )
-        print("type of inside", type(inside))  #
-
-        outside = self.block_builder.emit(inside)
-
-        print("type of outside", type(outside))  #
-
-        return outside
-
     def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
-        print("Inside upsample bilinear 2d")
         x = self.env[node.args[0]]
         size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
         align_corners = (
@@ -181,7 +151,6 @@ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
         )

     def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
-        print("Inside upsample nearest 2d")
         x = self.env[node.args[0]]
         size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)

diff --git a/tests/python/relax/test_from_exported_batch_norm_only.py b/tests/python/relax/test_from_exported_batch_norm_only.py
deleted file mode 100644
index bc5729455552..000000000000
--- a/tests/python/relax/test_from_exported_batch_norm_only.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# TODO remove
-import sys
-
-sys.path.append("/ssd1/htalendr/tvm/python")  # Refer to local TVM build
-
-import tvm
-from tvm import relax
-import tvm.testing
-import numpy as np
-import torch
-from torch import nn
-from torch.export import export
-from tvm.relax.frontend.torch import from_exported_program
-from torch.nn import Softmax, Upsample
-
-
-def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
-    """
-    This util ensures that a torch module can successfully be exported to TVM
-    using torch.export and that the resuling IR program gives the same result
-    as PyTorch when ran on CUDA.
-    """
-    raw_data_for_tvm = raw_data.copy()  # In case the data is modified
-    torch_data = torch.from_numpy(raw_data)
-    example_args = (torch_data,)
-
-    with torch.no_grad():
-        exported_program = export(torch_module, example_args)
-        mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)
-
-    # mod_from_torch.show() # TODO remove
-
-    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
-
-    relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
-    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
-    vm = relax.VirtualMachine(ex, dev)
-
-    gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
-    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
-    gpu_out = vm["main"](gpu_data, *gpu_params)
-
-    pytorch_out = torch_module(torch_data).detach().numpy()
-
-    print("type of pytorch_out", type(pytorch_out))
-    print("pytorch output shape", pytorch_out.shape)
-
-    print("len of gpu_out", len(gpu_out))  # 1 for all tests
-    print(
-        "type of gpu_out[0]", type(gpu_out[0])
-    )  # tvm.ir.container.Array for batch norm, tvm.runtime.ndarray.NDArray for both existing tests
-    print("gpu_out[0] shape", gpu_out[0].shape)  # defined for tests that work
-
-    actual = gpu_out[0].numpy()
-    desired = pytorch_out
-
-    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
-
-
-@tvm.testing.parametrize_targets("cuda")
-def test_batch_norm_prog(target, dev):
-    # Default args, in a pytorch program (to ensure output is in proper type and format)
-    raw_data = np.random.randn(2, 3, 2, 2).astype(np.float32)
-
-    class BatchNormWrapper(nn.Module):
-        def __init__(self):
-            super(BatchNormWrapper, self).__init__()
-            self.bn = nn.BatchNorm2d(3)
-
-        def forward(self, x):
-            x = self.bn(x)
-            x = x + 1
-            return x
-
-    torch_module = BatchNormWrapper().eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
-
-
-# # TODO can combine the tests together (they are separete to know which test fails)
-@tvm.testing.parametrize_targets("cuda")
-def test_batch_norm0(target, dev):
-    # Eval, no momentum, no affine, no running stats
-    raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
-        3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None
-    ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
-
-
-@tvm.testing.parametrize_targets("cuda")
-def test_batch_norm1(target, dev):
-    # Eval, with momentum, no affine, with running stats
-    raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
-        4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True, device=None, dtype=None
-    ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
-
-
-@tvm.testing.parametrize_targets("cuda")
-def test_batch_norm2(target, dev):
-    # Eval, with momentum, affine, no running stats
-    raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
-        4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False
-    ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
-
-
-@tvm.testing.parametrize_targets("cuda")
-def test_batch_norm3(target, dev):
-    # Eval, no momentum, affine, with running stats
-    raw_data = np.random.randn(1, 3, 3, 3).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
-        3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True
-    ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
-
-
-if __name__ == "__main__":
-    tvm.testing.main()

From 73ef53e938bf4026192a27b37d605f71c73b7d37 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 23 Mar 2025 22:51:31 -0400
Subject: [PATCH 43/47] cleanup batchnorm

---
 python/tvm/topi/nn/batch_norm.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py
index f8ed0de042e4..8308c93eae4f 100644
--- a/python/tvm/topi/nn/batch_norm.py
+++ b/python/tvm/topi/nn/batch_norm.py
@@ -129,13 +129,6 @@ def batch_norm(
         out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
     else:
-
-        print("data is", data)
-        print("data_mean_rs is", data_mean_rs)
-        print("data_var_rs is", data_var_rs)
-        print("epsilon is", epsilon)
-        print("sqrt of data_var_rs + epsilon is", topi.math.sqrt(data_var_rs + epsilon))
-
         out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)

     if scale:

From 1578ae2aa355b736728539ef66c8d3258bb86712 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 23 Mar 2025 22:52:10 -0400
Subject: [PATCH 44/47] linting

---
 .../torch/exported_program_translator.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 370c6283140d..70ca4a524cd0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -67,16 +67,16 @@ def _batch_norm(self, node: fx.Node, training) -> relax.Var:
         return self.block_builder.emit(
             relax.op.nn.batch_norm(
-            data=x,
-            gamma=weight,
-            beta=bias,
-            moving_mean=running_mean,
-            moving_var=running_var,
-            axis=1,  # Always over channel
-            epsilon=eps,
-            momentum=momentum,
-            training=training,
-        )[0]
+                data=x,
+                gamma=weight,
+                beta=bias,
+                moving_mean=running_mean,
+                moving_var=running_var,
+                axis=1,  # Always over channel
+                epsilon=eps,
+                momentum=momentum,
+                training=training,
+            )[0]
         )

     def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:

From 95254bc47e072319b71178add078a3012d83f45e Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 23 Mar 2025 22:53:14 -0400
Subject: [PATCH 45/47] smaller third test

---
 tests/python/relax/test_from_exported_to_cuda.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index f659162e945d..d60576271434 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -334,9 +334,9 @@ def test_batch_norm2(target, dev):
 @tvm.testing.parametrize_targets("cuda")
 def test_batch_norm3(target, dev):
     # Eval, no momentum, affine, with running stats
-    raw_data = np.random.randn(1, 3, 3, 3).astype(np.float32)
+    raw_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
     torch_module0 = nn.BatchNorm2d(
-        3, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True
+        2, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True
     ).eval()
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)

From 698460956aeff14d7548d3a468e37b2d82ae6373 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 23 Mar 2025 22:53:53 -0400
Subject: [PATCH 46/47] formatting

---
 tests/python/relax/test_from_exported_to_cuda.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index d60576271434..88e79019d6d6 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -300,7 +300,6 @@ def forward(self, x):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


-# # TODO can combine the tests together (they are separete to know which test fails)
 @tvm.testing.parametrize_targets("cuda")
 def test_batch_norm0(target, dev):
     # Eval, no momentum, no affine, no running stats

From 8233013cbea68a57d29b6f09235fe1e0e4e2db93 Mon Sep 17 00:00:00 2001
From: Hugo Latendresse
Date: Sun, 23 Mar 2025 22:56:12 -0400
Subject: [PATCH 47/47] renaming

---
 tests/python/relax/test_from_exported_to_cuda.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 88e79019d6d6..70b8503d6e4a 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -304,40 +304,40 @@ def forward(self, x):
 def test_batch_norm0(target, dev):
     # Eval, no momentum, no affine, no running stats
     raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
+    torch_module = nn.BatchNorm2d(
         3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None
     ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


 @tvm.testing.parametrize_targets("cuda")
 def test_batch_norm1(target, dev):
     # Eval, with momentum, no affine, with running stats
     raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
+    torch_module = nn.BatchNorm2d(
         4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True, device=None, dtype=None
     ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


 @tvm.testing.parametrize_targets("cuda")
 def test_batch_norm2(target, dev):
     # Eval, with momentum, affine, no running stats
     raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
+    torch_module = nn.BatchNorm2d(
         4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False
     ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


 @tvm.testing.parametrize_targets("cuda")
 def test_batch_norm3(target, dev):
     # Eval, no momentum, affine, with running stats
     raw_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
-    torch_module0 = nn.BatchNorm2d(
+    torch_module = nn.BatchNorm2d(
         2, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True
     ).eval()
-    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)

 if __name__ == "__main__":