diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py
index f0884bf2d64e..e9013ba66153 100644
--- a/python/tvm/contrib/msc/core/codegen/codegen.py
+++ b/python/tvm/contrib/msc/core/codegen/codegen.py
@@ -172,4 +172,18 @@ def relay_to_relax(
     def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
         return BindParams("main", weights)(mod)
 
-    return codegen.load(inputs, post_load=_bind_weights)
+    mod = codegen.load(inputs, post_load=_bind_weights)
+
+    mod = tvm.ir.transform.Sequential(
+        [
+            # The canonicalization of relax variable bindings is not required
+            # for correctness. It does, however, remove trivial `x = y`
+            # bindings, preventing test cases from depending on their
+            # presence.
+            tvm.relax.transform.CanonicalizeBindings(),
+            tvm.relax.transform.ConvertToDataflow(min_size=1),
+        ],
+        name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc",
+    )(mod)
+
+    return mod
diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
index c30e05ed989c..489ca0a2b528 100644
--- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
+++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py
@@ -71,4 +71,18 @@ def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRMo
         return mod
 
     codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder)
-    return codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights)
+    mod = codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights)
+
+    mod = tvm.ir.transform.Sequential(
+        [
+            # The canonicalization of relax variable bindings is not required
+            # for correctness. It does, however, remove trivial `x = y`
+            # bindings, preventing test cases from depending on their
+            # presence.
+            tvm.relax.transform.CanonicalizeBindings(),
+            tvm.relax.transform.ConvertToDataflow(min_size=1),
+        ],
+        name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc",
+    )(mod)
+
+    return mod
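The post-processing pipeline added in the two hunks above normalizes whatever the
codegen loader produced before the module is handed back to callers. Below is a
minimal sketch of its effect, assuming a hand-written `Before` module
(illustrative only, not part of the patch):

    import tvm
    from tvm.script import ir_module
    from tvm.script import relax as R

    @ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")):
            y = x  # trivial rebinding, removed by CanonicalizeBindings
            z = R.add(y, y)  # plain binding, wrapped by ConvertToDataflow
            return z

    # Same pipeline as the patch: canonicalize first, then wrap the surviving
    # bindings into a dataflow block (min_size=1 wraps even a single binding).
    After = tvm.ir.transform.Sequential(
        [
            tvm.relax.transform.CanonicalizeBindings(),
            tvm.relax.transform.ConvertToDataflow(min_size=1),
        ]
    )(Before)
    After.show()  # main now holds one dataflow block binding R.add(x, x)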
"_exit" : ""); if (config()->use_tools) { @@ -104,10 +100,9 @@ void RelaxCodeGen::CodeGenGraph() { stack_.call_arg(DocUtils::ToStr(e->name + "_exit"), "name_hint"); } } - stack_.func_call("emit_output", idx_exit, "block_builder").call_arg(idx_exit); idx_exits.push_back(idx_exit); } - stack_.scope_end(); + if (config()->use_tools) { stack_.func_call("msc_tools.execute_step", "output").call_arg(DocUtils::ToStr("after_build")); if (idx_exits.size() == 1) { diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 3867083b90ca..16b78193ae6a 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -36,9 +36,6 @@ const Array RelaxOpCode::GetDocs() { if (node()->optype == "input" || node()->optype == "constant" || node()->optype == "shape") { emit_var = false; } - if (node()->optype == "tuple" && node()->children.size() == 0) { - emit_var = false; - } if (emit_var) { const auto& name = config()->explicit_name ? node()->name : ""; BuilderEmit(IdxNode(), name); diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index cff2ee18ca6c..fdc15777152b 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -21,19 +21,51 @@ from torch import fx from torch.nn import Module +import numpy as np + import tvm.testing from tvm.relax.frontend.torch import from_fx from tvm.contrib.msc.core.frontend import translate from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen -def verify_model(torch_model, input_info, opt_config=None): +def _verify_model(torch_model, input_info, opt_config=None): graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): - expected = from_fx(graph_model, input_info) - graph, weights = translate.from_relax(expected, opt_config=opt_config) - mod = tvm_codegen.to_relax(graph, weights, codegen_config={"explicit_name": False}) - tvm.ir.assert_structural_equal(mod, expected) + orig_mod = from_fx(graph_model, input_info) + + target = "llvm" + dev = tvm.cpu() + args = [tvm.nd.array(np.random.random(size=shape).astype(dtype)) for shape, dtype in input_info] + + def _tvm_runtime_to_np(obj): + if isinstance(obj, tvm.runtime.NDArray): + return obj.numpy() + elif isinstance(obj, tvm.runtime.ShapeTuple): + return np.array(obj, dtype="int64") + elif isinstance(obj, (list, tvm.ir.container.Array)): + return [_tvm_runtime_to_np(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(_tvm_runtime_to_np(item) for item in obj) + else: + return obj + + def _run_relax(relax_mod): + relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod) + relax_exec = tvm.relax.build(relax_mod, target) + vm_runner = tvm.relax.VirtualMachine(relax_exec, dev) + res = vm_runner["main"](*args) + + return _tvm_runtime_to_np(res) + + rt_mod = tvm_codegen.to_relax( + *translate.from_relax(orig_mod, opt_config=opt_config), + codegen_config={"explicit_name": False}, + ) + + orig_output = _run_relax(orig_mod) + rt_output = _run_relax(rt_mod) + tvm.testing.assert_allclose(orig_output, rt_output) def test_conv1d(): @@ -56,8 +88,8 @@ def forward(self, data): return self.conv(data) input_info = [([1, 3, 10], "float32")] - verify_model(Conv1D1(), input_info) - verify_model(Conv1D2(), input_info) + _verify_model(Conv1D1(), input_info) + _verify_model(Conv1D2(), input_info) def test_conv2d(): @@ -80,8 +112,8 @@ def forward(self, data): return self.conv(data) 
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py
index cff2ee18ca6c..fdc15777152b 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -21,19 +21,51 @@
 from torch import fx
 from torch.nn import Module
 
+import numpy as np
+
 import tvm.testing
 from tvm.relax.frontend.torch import from_fx
 from tvm.contrib.msc.core.frontend import translate
 from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
 
 
-def verify_model(torch_model, input_info, opt_config=None):
+def _verify_model(torch_model, input_info, opt_config=None):
     graph_model = fx.symbolic_trace(torch_model)
     with torch.no_grad():
-        expected = from_fx(graph_model, input_info)
-    graph, weights = translate.from_relax(expected, opt_config=opt_config)
-    mod = tvm_codegen.to_relax(graph, weights, codegen_config={"explicit_name": False})
-    tvm.ir.assert_structural_equal(mod, expected)
+        orig_mod = from_fx(graph_model, input_info)
+
+    target = "llvm"
+    dev = tvm.cpu()
+    args = [tvm.nd.array(np.random.random(size=shape).astype(dtype)) for shape, dtype in input_info]
+
+    def _tvm_runtime_to_np(obj):
+        if isinstance(obj, tvm.runtime.NDArray):
+            return obj.numpy()
+        elif isinstance(obj, tvm.runtime.ShapeTuple):
+            return np.array(obj, dtype="int64")
+        elif isinstance(obj, (list, tvm.ir.container.Array)):
+            return [_tvm_runtime_to_np(item) for item in obj]
+        elif isinstance(obj, tuple):
+            return tuple(_tvm_runtime_to_np(item) for item in obj)
+        else:
+            return obj
+
+    def _run_relax(relax_mod):
+        relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
+        relax_exec = tvm.relax.build(relax_mod, target)
+        vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
+        res = vm_runner["main"](*args)
+
+        return _tvm_runtime_to_np(res)
+
+    rt_mod = tvm_codegen.to_relax(
+        *translate.from_relax(orig_mod, opt_config=opt_config),
+        codegen_config={"explicit_name": False},
+    )
+
+    orig_output = _run_relax(orig_mod)
+    rt_output = _run_relax(rt_mod)
+    tvm.testing.assert_allclose(orig_output, rt_output)
 
 
 def test_conv1d():
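The rewritten helper now checks numeric equivalence on the LLVM/CPU target
instead of structural equality against the from_fx module. A hypothetical call
(the `Scale` module here is invented for illustration) mirrors the tests below:

    class Scale(Module):
        def forward(self, data):
            return data * 2.0

    # Builds both the original from_fx module and the MSC round-tripped one,
    # runs each through the relax VM on random inputs, and compares outputs.
    _verify_model(Scale(), [([4, 4], "float32")])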
@@ -56,8 +88,8 @@ def forward(self, data):
             return self.conv(data)
 
     input_info = [([1, 3, 10], "float32")]
-    verify_model(Conv1D1(), input_info)
-    verify_model(Conv1D2(), input_info)
+    _verify_model(Conv1D1(), input_info)
+    _verify_model(Conv1D2(), input_info)
 
 
 def test_conv2d():
@@ -80,8 +112,8 @@ def forward(self, data):
             return self.conv(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Conv2D1(), input_info)
-    verify_model(Conv2D2(), input_info)
+    _verify_model(Conv2D1(), input_info)
+    _verify_model(Conv2D2(), input_info)
 
 
 def test_linear():
@@ -108,9 +140,9 @@ def forward(self, x, y):
             return torch.matmul(x, y)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Dense1(), input_info)
-    verify_model(Dense2(), input_info)
-    verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])
+    _verify_model(Dense1(), input_info)
+    _verify_model(Dense2(), input_info)
+    _verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])
 
 
 def test_bmm():
@@ -121,7 +153,7 @@ def forward(self, x, y):
             return torch.bmm(x, y)
 
     input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")]
-    verify_model(BMM(), input_info)
+    _verify_model(BMM(), input_info)
 
 
 def test_baddbmm():
@@ -140,8 +172,8 @@ def forward(self, c, x, y):
         ((4, 128, 256), "float32"),
         ((4, 256, 512), "float32"),
     ]
-    verify_model(BAddBMM1(), input_info)
-    verify_model(BAddBMM2(), input_info)
+    _verify_model(BAddBMM1(), input_info)
+    _verify_model(BAddBMM2(), input_info)
 
 
 def test_relu():
@@ -160,8 +192,8 @@ def forward(self, data):
             return torch.nn.functional.relu(data)
 
     input_info = [([10, 10], "float32")]
-    verify_model(ReLU(), input_info)
-    verify_model(ReLU1(), input_info)
+    _verify_model(ReLU(), input_info)
+    _verify_model(ReLU1(), input_info)
 
 
 def test_relu6():
@@ -176,7 +208,7 @@ def forward(self, data):
             return self.relu6(data)
 
     input_info = [([10, 10], "float32")]
-    verify_model(ReLU6(), input_info)
+    _verify_model(ReLU6(), input_info)
 
 
 def test_maxpool2d():
@@ -207,9 +239,9 @@ def forward(self, data):
             return self.pool(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(MaxPool2d(), input_info)
-    verify_model(MaxPool2d2(), input_info)
-    verify_model(MaxPool2d3(), input_info)
+    _verify_model(MaxPool2d(), input_info)
+    _verify_model(MaxPool2d2(), input_info)
+    _verify_model(MaxPool2d3(), input_info)
 
 
 def test_avgpool2d():
@@ -232,8 +264,8 @@ def forward(self, data):
             return self.pool(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(AvgPool2d(), input_info)
-    verify_model(AvgPool2d2(), input_info)
+    _verify_model(AvgPool2d(), input_info)
+    _verify_model(AvgPool2d2(), input_info)
 
 
 def test_adaptive_avgpool2d():
@@ -248,7 +280,7 @@ def forward(self, data):
             return self.pool(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(AdaptiveAvgPool2d0(), input_info)
+    _verify_model(AdaptiveAvgPool2d0(), input_info)
 
 
 def test_flatten():
@@ -263,8 +295,8 @@ def forward(self, data):
             return self.f(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Flatten(), input_info)
-    verify_model(torch.nn.Flatten(2, -1), input_info)
+    _verify_model(Flatten(), input_info)
+    _verify_model(torch.nn.Flatten(2, -1), input_info)
 
 
 def test_batchnorm2d():
@@ -279,7 +311,7 @@ def forward(self, data):
             return self.batchnorm(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(BatchNorm2d(), input_info)
+    _verify_model(BatchNorm2d(), input_info)
 
 
 def test_embedding():
@@ -293,8 +325,8 @@ def __init__(self):
         def forward(self, data):
             return self.embedding(data)
 
-    verify_model(Embedding(), [([4], "int64")])
-    verify_model(Embedding(), [([4, 5], "int64")])
+    _verify_model(Embedding(), [([4], "int64")])
+    _verify_model(Embedding(), [([4, 5], "int64")])
 
 
 def test_dropout():
@@ -313,8 +345,8 @@ def forward(self, data):
             return torch.dropout(data, 0.5, train=True)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Dropout1(), input_info)
-    verify_model(Dropout2(), input_info)
+    _verify_model(Dropout1(), input_info)
+    _verify_model(Dropout2(), input_info)
 
 
 def test_layernorm():
@@ -329,7 +361,7 @@ def forward(self, data):
             return self.layernorm(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(LayerNorm(), input_info)
+    _verify_model(LayerNorm(), input_info)
 
 
 def test_functional_layernorm():
@@ -347,7 +379,7 @@ def forward(self, data):
         )
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(LayerNorm((10, 10)), input_info)
+    _verify_model(LayerNorm((10, 10)), input_info)
 
 
 def test_cross_entropy():
@@ -379,9 +411,9 @@ def forward(self, logits, targets):
             return self.loss(logits, targets)
 
     input_info = [([3, 2], "float32"), ([3], "int32")]
-    verify_model(CrossEntropy1(), input_info)
-    verify_model(CrossEntropy2(), input_info)
-    verify_model(CrossEntropy3(), input_info)
+    _verify_model(CrossEntropy1(), input_info)
+    _verify_model(CrossEntropy2(), input_info)
+    _verify_model(CrossEntropy3(), input_info)
 
 
 def test_functional_cross_entropy():
@@ -392,7 +424,7 @@ def forward(self, logits, targets):
             return torch.nn.functional.cross_entropy(logits, targets)
 
     input_info = [([3, 10], "float32"), ([3], "int32")]
-    verify_model(CrossEntropy(), input_info)
+    _verify_model(CrossEntropy(), input_info)
 
 
 def test_silu():
@@ -411,8 +443,8 @@ def forward(self, data):
             return torch.nn.functional.silu(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(SiLU(), input_info)
-    verify_model(SiLU2(), input_info)
+    _verify_model(SiLU(), input_info)
+    _verify_model(SiLU2(), input_info)
 
 
 def test_groupnorm():
@@ -427,7 +459,7 @@ def forward(self, data):
             return self.groupnorm(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(GroupNorm(), input_info)
+    _verify_model(GroupNorm(), input_info)
 
 
 def test_softmax():
@@ -442,7 +474,7 @@ def forward(self, data):
             return self.softmax(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Softmax(), input_info)
+    _verify_model(Softmax(), input_info)
 
 
 def test_binary():
@@ -460,8 +492,8 @@ class Add2(Module):
         def forward(self, lhs):
             return lhs + 1.0
 
-    verify_model(Add1(), input_info1)
-    verify_model(Add2(), input_info2)
+    _verify_model(Add1(), input_info1)
+    _verify_model(Add2(), input_info2)
 
     # Sub
     class Sub1(Module):
@@ -472,8 +504,8 @@ class Sub2(Module):
         def forward(self, lhs):
             return lhs - 1.0
 
-    verify_model(Sub1(), input_info1)
-    verify_model(Sub2(), input_info2)
+    _verify_model(Sub1(), input_info1)
+    _verify_model(Sub2(), input_info2)
 
     # Mul
    class Mul1(Module):
@@ -484,8 +516,8 @@ class Mul2(Module):
         def forward(self, lhs):
             return lhs * 1.0
 
-    verify_model(Mul1(), input_info1)
-    verify_model(Mul2(), input_info2)
+    _verify_model(Mul1(), input_info1)
+    _verify_model(Mul2(), input_info2)
 
     # True div
     class TrueDiv1(Module):
@@ -496,8 +528,8 @@ class TrueDiv2(Module):
         def forward(self, lhs):
             return lhs / 1.0
 
-    verify_model(TrueDiv1(), input_info1)
-    verify_model(TrueDiv2(), input_info2)
+    _verify_model(TrueDiv1(), input_info1)
+    _verify_model(TrueDiv2(), input_info2)
 
     # Floor div
     class FloorDiv1(Module):
@@ -508,8 +540,8 @@ class FloorDiv2(Module):
         def forward(self, lhs):
             return lhs // 1.0
 
-    verify_model(FloorDiv1(), input_info1)
-    verify_model(FloorDiv2(), input_info2)
+    _verify_model(FloorDiv1(), input_info1)
+    _verify_model(FloorDiv2(), input_info2)
 
     # Power
     class Power1(Module):
@@ -520,8 +552,8 @@ class Power2(Module):
         def forward(self, lhs):
             return lhs**1.0
 
-    verify_model(Power1(), input_info1)
-    verify_model(Power2(), input_info2)
+    _verify_model(Power1(), input_info1)
+    _verify_model(Power2(), input_info2)
 
     # LT
     class LT1(Module):
@@ -532,8 +564,8 @@ class LT2(Module):
         def forward(self, lhs):
             return lhs < 1.0
 
-    verify_model(LT1(), input_info1)
-    verify_model(LT2(), input_info2)
+    _verify_model(LT1(), input_info1)
+    _verify_model(LT2(), input_info2)
 
 
 def test_size():
@@ -544,7 +576,7 @@ def forward(self, data):
             return data.size()
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Size(), input_info)
+    _verify_model(Size(), input_info)
 
 
 def test_squeeze():
@@ -559,8 +591,8 @@ def forward(self, data):
             return data.squeeze()
 
     input_info = [([3, 1, 4, 1], "float32")]
-    verify_model(Squeeze1(), input_info)
-    verify_model(Squeeze2(), input_info)
+    _verify_model(Squeeze1(), input_info)
+    _verify_model(Squeeze2(), input_info)
 
 
 def test_unsqueeze():
@@ -575,8 +607,8 @@ def forward(self, data):
             return data.unsqueeze(-1)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Unsqueeze1(), input_info)
-    verify_model(Unsqueeze2(), input_info)
+    _verify_model(Unsqueeze1(), input_info)
+    _verify_model(Unsqueeze2(), input_info)
 
 
 def test_getattr():
@@ -587,7 +619,7 @@ def forward(self, data):
             return data.shape
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(GetAttr1(), input_info)
+    _verify_model(GetAttr1(), input_info)
 
 
 def test_getitem():
@@ -601,8 +633,8 @@ class Slice2(Module):
         def forward(self, x):
             return x[:, None, None, :, None]
 
-    verify_model(Slice1(), [([1, 3, 10, 10], "float32")])
-    verify_model(Slice2(), [([8, 16], "float32")])
+    _verify_model(Slice1(), [([1, 3, 10, 10], "float32")])
+    _verify_model(Slice2(), [([8, 16], "float32")])
 
 
 def test_unary():
@@ -615,42 +647,42 @@ class Sin(Module):
         def forward(self, data):
             return torch.sin(data)
 
-    verify_model(Sin(), input_info)
+    _verify_model(Sin(), input_info)
 
     # cos
     class Cos(Module):
         def forward(self, data):
             return torch.cos(data)
 
-    verify_model(Cos(), input_info)
+    _verify_model(Cos(), input_info)
 
     # exp
     class Exp(Module):
         def forward(self, data):
             return torch.exp(data)
 
-    verify_model(Exp(), input_info)
+    _verify_model(Exp(), input_info)
 
     # sqrt
     class Sqrt(Module):
         def forward(self, data):
             return torch.sqrt(data)
 
-    verify_model(Sqrt(), input_info)
+    _verify_model(Sqrt(), input_info)
 
     # sigmoid
     class Sigmoid(Module):
         def forward(self, data):
             return torch.sigmoid(data)
 
-    verify_model(Sigmoid(), input_info)
+    _verify_model(Sigmoid(), input_info)
 
     # round
     class Round(Module):
         def forward(self, data):
             return torch.round(data)
 
-    verify_model(Round(), input_info)
+    _verify_model(Round(), input_info)
 
 
 def test_gelu():
@@ -661,7 +693,7 @@ def forward(self, data):
             return torch.nn.functional.gelu(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Gelu(), input_info)
+    _verify_model(Gelu(), input_info)
 
 
 def test_tanh():
@@ -672,7 +704,7 @@ def forward(self, data):
             return torch.tanh(data)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Tanh(), input_info)
+    _verify_model(Tanh(), input_info)
 
 
 def test_clamp():
@@ -683,7 +715,7 @@ def forward(self, data):
             return torch.clamp(data, min=0.1, max=0.5)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Clamp(), input_info)
+    _verify_model(Clamp(), input_info)
 
 
 def test_interpolate():
@@ -694,7 +726,7 @@ def forward(self, data):
             return torch.nn.functional.interpolate(data, (5, 5))
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Interpolate(), input_info)
+    _verify_model(Interpolate(), input_info)
 
 
 def test_addmm():
@@ -709,7 +741,7 @@ def forward(self, x_1, x_2, x_3):
         ([10, 10], "float32"),
         ([10, 10], "float32"),
     ]
-    verify_model(Addmm(), input_info)
+    _verify_model(Addmm(), input_info)
 
 
 def test_split():
@@ -720,7 +752,7 @@ def forward(self, data):
             return torch.split(data, 1, dim=1)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Split(), input_info)
+    _verify_model(Split(), input_info)
 
 
 def test_cumsum():
@@ -731,7 +763,7 @@ def forward(self, data):
             return torch.cumsum(data, dim=1, dtype=torch.int32)
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Cumsum(), input_info)
+    _verify_model(Cumsum(), input_info)
 
 
 def test_chunk():
@@ -742,7 +774,7 @@ def forward(self, data):
             return torch.chunk(data, 3, dim=1)
 
     input_info = [([1, 3, 10, 10], "float32")]
-    verify_model(Chunk(), input_info)
+    _verify_model(Chunk(), input_info)
 
 
 def test_inplace_fill():
@@ -753,7 +785,7 @@ def forward(self, data):
             data.fill_(1.5)
             return data
 
-    verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0})
+    _verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0})
 
 
 def test_arange():
@@ -763,7 +795,7 @@ class Arange(Module):
         def forward(self):
             return torch.arange(0, 20, dtype=torch.int32)
 
-    verify_model(Arange(), [([10, 10], "float32")])
+    _verify_model(Arange(), [([10, 10], "float32")])
 
 
 def test_empty():
@@ -773,7 +805,7 @@ class Empty(Module):
         def forward(self):
             return torch.empty((10, 10), dtype=torch.float32)
 
-    verify_model(Empty(), [([10, 10], "float32")])
+    _verify_model(Empty(), [([10, 10], "float32")])
 
 
 def test_tensor():
@@ -787,8 +819,8 @@ class Empty2(Module):
         def forward(self):
             return torch.tensor(3)
 
-    verify_model(Empty1(), [([10, 10], "float32")])
-    verify_model(Empty2(), [([10, 10], "float32")])
+    _verify_model(Empty1(), [([10, 10], "float32")])
+    _verify_model(Empty2(), [([10, 10], "float32")])
 
 
 def test_tril():
@@ -804,8 +836,8 @@ def forward(self, data):
             return data
 
     input_info = [([10, 10], "float32")]
-    verify_model(Tril(), input_info)
-    verify_model(InplaceTril(), input_info)
+    _verify_model(Tril(), input_info)
+    _verify_model(InplaceTril(), input_info)
 
 
 def test_triu():
@@ -821,8 +853,8 @@ def forward(self, data):
             return data
 
     input_info = [([10, 10], "float32")]
-    verify_model(Triu(), input_info)
-    verify_model(InplaceTriu(), input_info)
+    _verify_model(Triu(), input_info)
+    _verify_model(InplaceTriu(), input_info)
 
 
 def test_new_ones():
@@ -833,7 +865,7 @@ def forward(self, x):
             return x.new_ones(1, 2, 3)
 
     input_info = [([1, 2, 3], "float32")]
-    verify_model(NewOnes(), input_info, opt_config={"opt_level": 0})
+    _verify_model(NewOnes(), input_info, opt_config={"opt_level": 0})
 
 
 def test_expand():
@@ -844,7 +876,7 @@ def forward(self, x):
             return x.expand(4, 2, 3, 4)
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Expand(), input_info)
+    _verify_model(Expand(), input_info)
 
 
 def test_reduce():
@@ -856,7 +888,7 @@ def forward(self, x):
             return torch.sum(x, (2, 1))
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Sum(), input_info)
+    _verify_model(Sum(), input_info)
 
 
 def test_datatype():
@@ -869,14 +901,14 @@ class ToFloat(Module):
         def forward(self, x):
             return x.float()
 
-    verify_model(ToFloat(), input_info)
+    _verify_model(ToFloat(), input_info)
 
     # half
     class ToHalf(Module):
         def forward(self, x):
             return x.half()
 
-    verify_model(ToHalf(), input_info)
+    _verify_model(ToHalf(), input_info)
 
     # type
     class Type(Module):
@@ -893,9 +925,9 @@ class AsType(Module):
         def forward(self, x):
             return x.astype(torch.float32)
 
-    verify_model(Type(), input_info)
-    verify_model(TypeFromAttr(), input_info)
-    verify_model(AsType(), input_info)
+    _verify_model(Type(), input_info)
+    _verify_model(TypeFromAttr(), input_info)
+    _verify_model(AsType(), input_info)
 
 
 def test_permute():
@@ -906,7 +938,7 @@ def forward(self, x):
             return x.permute(0, 3, 2, 1)
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Permute(), input_info)
+    _verify_model(Permute(), input_info)
 
 
 def test_reshape():
@@ -917,7 +949,7 @@ def forward(self, x):
             return x.reshape(2, 12)
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Reshape(), input_info)
+    _verify_model(Reshape(), input_info)
 
 
 def test_transpose():
@@ -928,7 +960,7 @@ def forward(self, x):
             return x.transpose(1, 3)
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(Transpose(), input_info)
+    _verify_model(Transpose(), input_info)
 
 
 def test_view():
@@ -939,7 +971,7 @@ def forward(self, x):
             return x.view(2, 12)
 
     input_info = [([1, 2, 3, 4], "float32")]
-    verify_model(View(), input_info)
+    _verify_model(View(), input_info)
 
 
 def test_keep_params():
@@ -953,7 +985,7 @@ def __init__(self):
         def forward(self, data):
             return self.conv(data)
 
-    verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")])
+    _verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")])
 
 
 def test_unwrap_unit_return_tuple():
@@ -963,7 +995,7 @@ class Identity(Module):
         def forward(self, x):
             return (x,)
 
-    verify_model(Identity(), [([256, 256], "float32")])
+    _verify_model(Identity(), [([256, 256], "float32")])
 
 
 def test_no_bind_return_tuple():
@@ -974,7 +1006,7 @@ def forward(self, x, y):
             return (x, y)
 
     input_info = [([256, 256], "float32"), ([256, 256], "float32")]
-    verify_model(Identity(), input_info)
+    _verify_model(Identity(), input_info)
 
 
 def test_argmax():
@@ -988,8 +1020,8 @@ class Argmax2(Module):
         def forward(self, data):
             return torch.argmax(data, dim=-1, keepdim=True)
 
-    verify_model(Argmax1(), [([256, 256], "float32")])
-    verify_model(Argmax2(), [([256, 256], "float32")])
+    _verify_model(Argmax1(), [([256, 256], "float32")])
+    _verify_model(Argmax2(), [([256, 256], "float32")])
 
 
 def test_argmin():
@@ -1003,8 +1035,8 @@ class Argmin2(Module):
         def forward(self, data):
             return torch.argmin(data, keepdim=True)
 
-    verify_model(Argmin1(), [([256, 256], "float32")])
-    verify_model(Argmin2(), [([256, 256], "float32")])
+    _verify_model(Argmin1(), [([256, 256], "float32")])
+    _verify_model(Argmin2(), [([256, 256], "float32")])
 
 
 def test_to():
@@ -1018,8 +1050,8 @@ class To2(Module):
         def forward(self, data):
             return data.to("cpu")
 
-    verify_model(To1(), [([256, 256], "float32")])
-    verify_model(To2(), [([256, 256], "float32")])
+    _verify_model(To1(), [([256, 256], "float32")])
+    _verify_model(To2(), [([256, 256], "float32")])
 
 
 def test_mean():
@@ -1033,8 +1065,8 @@ class MeanKeepDim(Module):
         def forward(self, data):
             return data.mean(-1, keepdim=True)
 
-    verify_model(Mean(), [([256, 256], "float32")])
-    verify_model(MeanKeepDim(), [([256, 256], "float32")])
+    _verify_model(Mean(), [([256, 256], "float32")])
+    _verify_model(MeanKeepDim(), [([256, 256], "float32")])
 
 
 def test_rsqrt():
@@ -1044,7 +1076,7 @@ class Rsqrt(Module):
         def forward(self, data):
             return torch.rsqrt(data)
 
-    verify_model(Rsqrt(), [([256, 256], "float32")])
+    _verify_model(Rsqrt(), [([256, 256], "float32")])
 
 
 def test_neg():
@@ -1054,7 +1086,7 @@ class Neg(Module):
         def forward(self, data):
             return -data
 
-    verify_model(Neg(), [([256, 256], "float32")])
+    _verify_model(Neg(), [([256, 256], "float32")])
 
 
 def test_max():
@@ -1064,7 +1096,7 @@ class Max(Module):
         def forward(self, x, y):
             return torch.max(x, y)
 
-    verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
+    _verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
 
 
 def test_attention():
@@ -1086,14 +1118,14 @@ def forward(self, q_data, k_data, v_data):
         ([32, 8, 128, 64], "float32"),
         ([32, 8, 128, 64], "float32"),
     ]
-    verify_model(Attention1(), input_info)
-    verify_model(Attention2(), input_info)
+    _verify_model(Attention1(), input_info)
+    _verify_model(Attention2(), input_info)
 
     class Attention3(Module):
         def forward(self, q_data, k_data, v_data, mask):
             return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)
 
-    verify_model(
+    _verify_model(
         Attention3(),
         [
             ([32, 8, 128, 64], "float32"),
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py
index aca0f2689003..39a45035a5b2 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -60,6 +60,8 @@ def verify_model(torch_model, input_info, opt_config=None, codegen_config=None,
     graph_model = fx.symbolic_trace(torch_model)
     with torch.no_grad():
         expected = from_fx(graph_model, input_info)
+    expected = tvm.relax.transform.CanonicalizeBindings()(expected)
+
     # graph from relay
     datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
     torch_datas = [torch.from_numpy(i) for i in datas]