From 6fee77aac6f3121bb24bfd61070b76bdc86ed7db Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 4 Jan 2024 15:56:17 -0600 Subject: [PATCH 1/5] [Unity][MSC] Avoid depending on trivial bindings in Relax intermediate The conversion from tensorflow to MSC is done by first converting from tensorflow to relay, then converting from relay to executable python code, executing that python code to generate relax, and finally converting from relax to MSC. During the relax phase of this conversion, several relax transforms are applied to the `IRModule`, including `FuseOpsByPattern`. The test cases in `test_msc/test_translate_tensorflow.py` rely on `FuseOpsByPattern` preserving trivial bindings (e.g. `var_1 = var_2`; a minimal sketch appears after the final patch) in the relax IRModule. If these trivial bindings are removed by `CanonicalizeBindings`, then the test cases in this file fail. The presence or absence of trivial bindings in the output of `FuseOpsByPattern` should be considered an implementation detail, and relax passes should not be required to preserve trivial bindings. This PR updates the relay to executable python step of the tensorflow to MSC conversion, to remove trivial bindings and output a variable name that matches the expected value in the test case. While not an ideal resolution, as other variable name changes could still reintroduce the same test failures, it ensures that `FuseOpsByPattern` may canonicalize bindings as an internal pre- or post-processing step without breaking these unit tests. --- python/tvm/contrib/msc/core/codegen/codegen.py | 10 +++++++++- src/contrib/msc/framework/tvm/codegen.cc | 1 + src/contrib/msc/framework/tvm/relax_opcode.cc | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py index f0884bf2d64e..f61432ec2150 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -172,4 +172,12 @@ def relay_to_relax( def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule: return BindParams("main", weights)(mod) - return codegen.load(inputs, post_load=_bind_weights) + mod = codegen.load(inputs, post_load=_bind_weights) + + # The canonicalization of relax variable bindings is not required + # for correctness. It does, however, remove trivial `x = y` + # bindings, preventing test cases from depending on their + # presence.
+ mod = tvm.relax.transform.CanonicalizeBindings()(mod) + + return mod diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index c8956ca399ff..c9fb359559e8 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -105,6 +105,7 @@ void RelaxCodeGen::CodeGenGraph() { } } stack_.func_call("block_builder.emit_output", idx_exit).call_arg(idx_exit); + stack_.call_arg(DocUtils::ToStrDoc(e->name), "name_hint"); idx_exits.push_back(idx_exit); } stack_.scope_end(); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 3bd40cbfd79f..225d4617c029 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -49,7 +49,7 @@ const Array RelaxOpCode::GetDocs() { void RelaxOpCode::BuilderEmit(const String& ret, const String& name) { stack_.func_call("block_builder.emit", ret).call_arg(ret); if (name.size() > 0) { - stack_.call_arg(DocUtils::ToStrDoc(name), "name_hint"); + stack_.call_arg(DocUtils::ToStrDoc(name + "_compute"), "name_hint"); } } From 552c8b6d8889f10b0cdcafabac1665e41309aeb3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 Jan 2024 09:14:04 -0600 Subject: [PATCH 2/5] Update implementation to remove dataflow block in MSC codegen The potential for duplicate variable names was introduced by the `block_builder.emit_output` call, which is only required to export values from a dataflow block. The dataflow block is not used in any later MSC conversion, and its removal avoids this re-export of variables. If the dataflow block is required in the future, it can be generated using `tvm.relax.transform.ConvertToDataflow`. --- src/contrib/msc/framework/tvm/codegen.cc | 12 +++--------- src/contrib/msc/framework/tvm/relax_opcode.cc | 5 +---- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index ac0a052b5f87..20c47d929125 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -71,9 +71,7 @@ void RelaxCodeGen::CodeGenGraph() { continue; } int scope_level = CompareScope(node); - if (scope_level == 1) { - stack_.scope_start("block_builder.dataflow()"); - } else if (scope_level == -1) { + if (scope_level == -1) { stack_.scope_end(); } CodeGenNode(node, config()->use_tools); @@ -83,13 +81,11 @@ void RelaxCodeGen::CodeGenGraph() { for (size_t i = 0; i < scopes().size() - 1; i++) { stack_.scope_end(); } - } else if (scopes().size() == 0) { - // start dataflow scope for non-scope graph - stack_.scope_start("block_builder.dataflow()"); } // mark outputs stack_.comment("Emit the outputs"); Array idx_exits; + for (const auto& e : graph()->GetExits()) { const auto& idx_exit = IdxNodeBase(e) + (config()->use_tools ? 
"_exit" : ""); if (config()->use_tools) { @@ -104,11 +100,9 @@ void RelaxCodeGen::CodeGenGraph() { stack_.call_arg(DocUtils::ToStr(e->name + "_exit"), "name_hint"); } } - stack_.func_call("emit_output", idx_exit, "block_builder").call_arg(idx_exit); - stack_.call_arg(DocUtils::ToStr(e->name), "name_hint"); idx_exits.push_back(idx_exit); } - stack_.scope_end(); + if (config()->use_tools) { stack_.func_call("msc_tools.execute_step", "output").call_arg(DocUtils::ToStr("after_build")); if (idx_exits.size() == 1) { diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 7e5b4fe1cc1a..16b78193ae6a 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -36,9 +36,6 @@ const Array RelaxOpCode::GetDocs() { if (node()->optype == "input" || node()->optype == "constant" || node()->optype == "shape") { emit_var = false; } - if (node()->optype == "tuple" && node()->children.size() == 0) { - emit_var = false; - } if (emit_var) { const auto& name = config()->explicit_name ? node()->name : ""; BuilderEmit(IdxNode(), name); @@ -49,7 +46,7 @@ const Array RelaxOpCode::GetDocs() { void RelaxOpCode::BuilderEmit(const String& ret, const String& name) { stack_.func_call("block_builder.emit", ret).call_arg(ret); if (name.size() > 0) { - stack_.call_arg(DocUtils::ToDoc(name + "_compute"), "name_hint"); + stack_.call_arg(DocUtils::ToStr(name), "name_hint"); } } From 63c185c60a17855ae10a7634aa403ade6d6b1c66 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 Jan 2024 15:06:35 -0600 Subject: [PATCH 3/5] Make failing test cases be close to the same structural form --- python/tvm/contrib/msc/core/codegen/codegen.py | 16 +++++++++++----- .../contrib/msc/framework/tvm/codegen/codegen.py | 16 +++++++++++++++- .../contrib/test_msc/test_translate_relax.py | 2 ++ .../contrib/test_msc/test_translate_relay.py | 2 ++ 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/msc/core/codegen/codegen.py b/python/tvm/contrib/msc/core/codegen/codegen.py index f61432ec2150..e9013ba66153 100644 --- a/python/tvm/contrib/msc/core/codegen/codegen.py +++ b/python/tvm/contrib/msc/core/codegen/codegen.py @@ -174,10 +174,16 @@ def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRMo mod = codegen.load(inputs, post_load=_bind_weights) - # The canonicalization of relax variable bindings is not required - # for correctness. It does, however, remove trivial `x = y` - # bindings, preventing test cases from depending on their - # presence. - mod = tvm.relax.transform.CanonicalizeBindings()(mod) + mod = tvm.ir.transform.Sequential( + [ + # The canonicalization of relax variable bindings is not required + # for correctness. It does, however, remove trivial `x = y` + # bindings, preventing test cases from depending on their + # presence. 
+ tvm.relax.transform.CanonicalizeBindings(), + tvm.relax.transform.ConvertToDataflow(min_size=1), + ], + name="tvm.contrib.msc.core.codegen.relay_to_relax_postproc", + )(mod) return mod diff --git a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py index c30e05ed989c..489ca0a2b528 100644 --- a/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tvm/codegen/codegen.py @@ -71,4 +71,18 @@ def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRMo return mod codegen = CodeGen(graph, _ffi_api.GetRelaxSources, codegen_config, print_config, build_folder) - return codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights) + mod = codegen.load(inputs, pre_load=_save_weights, post_load=_bind_weights) + + mod = tvm.ir.transform.Sequential( + [ + # The canonicalization of relax variable bindings is not required + # for correctness. It does, however, remove trivial `x = y` + # bindings, preventing test cases from depending on their + # presence. + tvm.relax.transform.CanonicalizeBindings(), + tvm.relax.transform.ConvertToDataflow(min_size=1), + ], + name="tvm.contrib.msc.framework.tvm.codegen.to_relax_postproc", + )(mod) + + return mod diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index cff2ee18ca6c..4824a8f25764 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -31,6 +31,8 @@ def verify_model(torch_model, input_info, opt_config=None): graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): expected = from_fx(graph_model, input_info) + expected = tvm.relax.transform.CanonicalizeBindings()(expected) + graph, weights = translate.from_relax(expected, opt_config=opt_config) mod = tvm_codegen.to_relax(graph, weights, codegen_config={"explicit_name": False}) tvm.ir.assert_structural_equal(mod, expected) diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py index aca0f2689003..39a45035a5b2 100644 --- a/tests/python/contrib/test_msc/test_translate_relay.py +++ b/tests/python/contrib/test_msc/test_translate_relay.py @@ -60,6 +60,8 @@ def verify_model(torch_model, input_info, opt_config=None, codegen_config=None, graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): expected = from_fx(graph_model, input_info) + expected = tvm.relax.transform.CanonicalizeBindings()(expected) + # graph from relay datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] torch_datas = [torch.from_numpy(i) for i in datas] From 95713bc6f15b89659de72bd2a7ec2a8e7076aeef Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Jan 2024 16:00:22 +0000 Subject: [PATCH 4/5] Updated tests to validate output after compilation --- .../contrib/test_msc/test_translate_relax.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 4824a8f25764..16e5dbc2447a 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -26,16 +26,46 @@ from tvm.contrib.msc.core.frontend import translate from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen +import numpy as np + def verify_model(torch_model, input_info, opt_config=None): 
graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): - expected = from_fx(graph_model, input_info) - expected = tvm.relax.transform.CanonicalizeBindings()(expected) + orig_mod = from_fx(graph_model, input_info) + + target = "llvm" + dev = tvm.cpu() + args = [tvm.nd.array(np.random.random(size=shape).astype(dtype)) for shape, dtype in input_info] + + def _tvm_runtime_to_np(obj): + if isinstance(obj, tvm.runtime.NDArray): + return obj.numpy() + elif isinstance(obj, tvm.runtime.ShapeTuple): + return np.array(obj, dtype="int64") + elif isinstance(obj, (list, tvm.ir.container.Array)): + return [_tvm_runtime_to_np(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(_tvm_runtime_to_np(item) for item in obj) + else: + return obj + + def _run_relax(relax_mod): + relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod) + relax_exec = tvm.relax.build(relax_mod, target) + vm = tvm.relax.VirtualMachine(relax_exec, dev) + res = vm["main"](*args) + + return _tvm_runtime_to_np(res) + + rt_mod = tvm_codegen.to_relax( + *translate.from_relax(orig_mod, opt_config=opt_config), + codegen_config={"explicit_name": False}, + ) - graph, weights = translate.from_relax(expected, opt_config=opt_config) - mod = tvm_codegen.to_relax(graph, weights, codegen_config={"explicit_name": False}) - tvm.ir.assert_structural_equal(mod, expected) + orig_output = _run_relax(orig_mod) + rt_output = _run_relax(rt_mod) + tvm.testing.assert_allclose(orig_output, rt_output) def test_conv1d(): From 0682cc74043f61d819fac54b254c09af15773fad Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 9 Jan 2024 17:06:42 +0000 Subject: [PATCH 5/5] Lint fixes --- .../contrib/test_msc/test_translate_relax.py | 230 +++++++++--------- 1 file changed, 115 insertions(+), 115 deletions(-) diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py index 16e5dbc2447a..fdc15777152b 100644 --- a/tests/python/contrib/test_msc/test_translate_relax.py +++ b/tests/python/contrib/test_msc/test_translate_relax.py @@ -21,15 +21,15 @@ from torch import fx from torch.nn import Module +import numpy as np + import tvm.testing from tvm.relax.frontend.torch import from_fx from tvm.contrib.msc.core.frontend import translate from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen -import numpy as np - -def verify_model(torch_model, input_info, opt_config=None): +def _verify_model(torch_model, input_info, opt_config=None): graph_model = fx.symbolic_trace(torch_model) with torch.no_grad(): orig_mod = from_fx(graph_model, input_info) @@ -53,8 +53,8 @@ def _tvm_runtime_to_np(obj): def _run_relax(relax_mod): relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod) relax_exec = tvm.relax.build(relax_mod, target) - vm = tvm.relax.VirtualMachine(relax_exec, dev) - res = vm["main"](*args) + vm_runner = tvm.relax.VirtualMachine(relax_exec, dev) + res = vm_runner["main"](*args) return _tvm_runtime_to_np(res) @@ -88,8 +88,8 @@ def forward(self, data): return self.conv(data) input_info = [([1, 3, 10], "float32")] - verify_model(Conv1D1(), input_info) - verify_model(Conv1D2(), input_info) + _verify_model(Conv1D1(), input_info) + _verify_model(Conv1D2(), input_info) def test_conv2d(): @@ -112,8 +112,8 @@ def forward(self, data): return self.conv(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Conv2D1(), input_info) - verify_model(Conv2D2(), input_info) + _verify_model(Conv2D1(), input_info) + _verify_model(Conv2D2(), input_info) def test_linear(): @@ -140,9 +140,9 
@@ def forward(self, x, y): return torch.matmul(x, y) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Dense1(), input_info) - verify_model(Dense2(), input_info) - verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) + _verify_model(Dense1(), input_info) + _verify_model(Dense2(), input_info) + _verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")]) def test_bmm(): @@ -153,7 +153,7 @@ def forward(self, x, y): return torch.bmm(x, y) input_info = [((4, 128, 256), "float32"), ((4, 256, 512), "float32")] - verify_model(BMM(), input_info) + _verify_model(BMM(), input_info) def test_baddbmm(): @@ -172,8 +172,8 @@ def forward(self, c, x, y): ((4, 128, 256), "float32"), ((4, 256, 512), "float32"), ] - verify_model(BAddBMM1(), input_info) - verify_model(BAddBMM2(), input_info) + _verify_model(BAddBMM1(), input_info) + _verify_model(BAddBMM2(), input_info) def test_relu(): @@ -192,8 +192,8 @@ def forward(self, data): return torch.nn.functional.relu(data) input_info = [([10, 10], "float32")] - verify_model(ReLU(), input_info) - verify_model(ReLU1(), input_info) + _verify_model(ReLU(), input_info) + _verify_model(ReLU1(), input_info) def test_relu6(): @@ -208,7 +208,7 @@ def forward(self, data): return self.relu6(data) input_info = [([10, 10], "float32")] - verify_model(ReLU6(), input_info) + _verify_model(ReLU6(), input_info) def test_maxpool2d(): @@ -239,9 +239,9 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(MaxPool2d(), input_info) - verify_model(MaxPool2d2(), input_info) - verify_model(MaxPool2d3(), input_info) + _verify_model(MaxPool2d(), input_info) + _verify_model(MaxPool2d2(), input_info) + _verify_model(MaxPool2d3(), input_info) def test_avgpool2d(): @@ -264,8 +264,8 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(AvgPool2d(), input_info) - verify_model(AvgPool2d2(), input_info) + _verify_model(AvgPool2d(), input_info) + _verify_model(AvgPool2d2(), input_info) def test_adaptive_avgpool2d(): @@ -280,7 +280,7 @@ def forward(self, data): return self.pool(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(AdaptiveAvgPool2d0(), input_info) + _verify_model(AdaptiveAvgPool2d0(), input_info) def test_flatten(): @@ -295,8 +295,8 @@ def forward(self, data): return self.f(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Flatten(), input_info) - verify_model(torch.nn.Flatten(2, -1), input_info) + _verify_model(Flatten(), input_info) + _verify_model(torch.nn.Flatten(2, -1), input_info) def test_batchnorm2d(): @@ -311,7 +311,7 @@ def forward(self, data): return self.batchnorm(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(BatchNorm2d(), input_info) + _verify_model(BatchNorm2d(), input_info) def test_embedding(): @@ -325,8 +325,8 @@ def __init__(self): def forward(self, data): return self.embedding(data) - verify_model(Embedding(), [([4], "int64")]) - verify_model(Embedding(), [([4, 5], "int64")]) + _verify_model(Embedding(), [([4], "int64")]) + _verify_model(Embedding(), [([4, 5], "int64")]) def test_dropout(): @@ -345,8 +345,8 @@ def forward(self, data): return torch.dropout(data, 0.5, train=True) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Dropout1(), input_info) - verify_model(Dropout2(), input_info) + _verify_model(Dropout1(), input_info) + _verify_model(Dropout2(), input_info) def test_layernorm(): @@ -361,7 +361,7 @@ def forward(self, data): return self.layernorm(data) 
input_info = [([1, 3, 10, 10], "float32")] - verify_model(LayerNorm(), input_info) + _verify_model(LayerNorm(), input_info) def test_functional_layernorm(): @@ -379,7 +379,7 @@ def forward(self, data): ) input_info = [([1, 3, 10, 10], "float32")] - verify_model(LayerNorm((10, 10)), input_info) + _verify_model(LayerNorm((10, 10)), input_info) def test_cross_entropy(): @@ -411,9 +411,9 @@ def forward(self, logits, targets): return self.loss(logits, targets) input_info = [([3, 2], "float32"), ([3], "int32")] - verify_model(CrossEntropy1(), input_info) - verify_model(CrossEntropy2(), input_info) - verify_model(CrossEntropy3(), input_info) + _verify_model(CrossEntropy1(), input_info) + _verify_model(CrossEntropy2(), input_info) + _verify_model(CrossEntropy3(), input_info) def test_functional_cross_entropy(): @@ -424,7 +424,7 @@ def forward(self, logits, targets): return torch.nn.functional.cross_entropy(logits, targets) input_info = [([3, 10], "float32"), ([3], "int32")] - verify_model(CrossEntropy(), input_info) + _verify_model(CrossEntropy(), input_info) def test_silu(): @@ -443,8 +443,8 @@ def forward(self, data): return torch.nn.functional.silu(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(SiLU(), input_info) - verify_model(SiLU2(), input_info) + _verify_model(SiLU(), input_info) + _verify_model(SiLU2(), input_info) def test_groupnorm(): @@ -459,7 +459,7 @@ def forward(self, data): return self.groupnorm(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(GroupNorm(), input_info) + _verify_model(GroupNorm(), input_info) def test_softmax(): @@ -474,7 +474,7 @@ def forward(self, data): return self.softmax(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Softmax(), input_info) + _verify_model(Softmax(), input_info) def test_binary(): @@ -492,8 +492,8 @@ class Add2(Module): def forward(self, lhs): return lhs + 1.0 - verify_model(Add1(), input_info1) - verify_model(Add2(), input_info2) + _verify_model(Add1(), input_info1) + _verify_model(Add2(), input_info2) # Sub class Sub1(Module): @@ -504,8 +504,8 @@ class Sub2(Module): def forward(self, lhs): return lhs - 1.0 - verify_model(Sub1(), input_info1) - verify_model(Sub2(), input_info2) + _verify_model(Sub1(), input_info1) + _verify_model(Sub2(), input_info2) # Mul class Mul1(Module): @@ -516,8 +516,8 @@ class Mul2(Module): def forward(self, lhs): return lhs * 1.0 - verify_model(Mul1(), input_info1) - verify_model(Mul2(), input_info2) + _verify_model(Mul1(), input_info1) + _verify_model(Mul2(), input_info2) # True div class TrueDiv1(Module): @@ -528,8 +528,8 @@ class TrueDiv2(Module): def forward(self, lhs): return lhs / 1.0 - verify_model(TrueDiv1(), input_info1) - verify_model(TrueDiv2(), input_info2) + _verify_model(TrueDiv1(), input_info1) + _verify_model(TrueDiv2(), input_info2) # Floor div class FloorDiv1(Module): @@ -540,8 +540,8 @@ class FloorDiv2(Module): def forward(self, lhs): return lhs // 1.0 - verify_model(FloorDiv1(), input_info1) - verify_model(FloorDiv2(), input_info2) + _verify_model(FloorDiv1(), input_info1) + _verify_model(FloorDiv2(), input_info2) # Power class Power1(Module): @@ -552,8 +552,8 @@ class Power2(Module): def forward(self, lhs): return lhs**1.0 - verify_model(Power1(), input_info1) - verify_model(Power2(), input_info2) + _verify_model(Power1(), input_info1) + _verify_model(Power2(), input_info2) # LT class LT1(Module): @@ -564,8 +564,8 @@ class LT2(Module): def forward(self, lhs): return lhs < 1.0 - verify_model(LT1(), input_info1) - verify_model(LT2(), input_info2) + 
_verify_model(LT1(), input_info1) + _verify_model(LT2(), input_info2) def test_size(): @@ -576,7 +576,7 @@ def forward(self, data): return data.size() input_info = [([1, 3, 10, 10], "float32")] - verify_model(Size(), input_info) + _verify_model(Size(), input_info) def test_squeeze(): @@ -591,8 +591,8 @@ def forward(self, data): return data.squeeze() input_info = [([3, 1, 4, 1], "float32")] - verify_model(Squeeze1(), input_info) - verify_model(Squeeze2(), input_info) + _verify_model(Squeeze1(), input_info) + _verify_model(Squeeze2(), input_info) def test_unsqueeze(): @@ -607,8 +607,8 @@ def forward(self, data): return data.unsqueeze(-1) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Unsqueeze1(), input_info) - verify_model(Unsqueeze2(), input_info) + _verify_model(Unsqueeze1(), input_info) + _verify_model(Unsqueeze2(), input_info) def test_getattr(): @@ -619,7 +619,7 @@ def forward(self, data): return data.shape input_info = [([1, 3, 10, 10], "float32")] - verify_model(GetAttr1(), input_info) + _verify_model(GetAttr1(), input_info) def test_getitem(): @@ -633,8 +633,8 @@ class Slice2(Module): def forward(self, x): return x[:, None, None, :, None] - verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) - verify_model(Slice2(), [([8, 16], "float32")]) + _verify_model(Slice1(), [([1, 3, 10, 10], "float32")]) + _verify_model(Slice2(), [([8, 16], "float32")]) def test_unary(): @@ -647,42 +647,42 @@ class Sin(Module): def forward(self, data): return torch.sin(data) - verify_model(Sin(), input_info) + _verify_model(Sin(), input_info) # cos class Cos(Module): def forward(self, data): return torch.cos(data) - verify_model(Cos(), input_info) + _verify_model(Cos(), input_info) # exp class Exp(Module): def forward(self, data): return torch.exp(data) - verify_model(Exp(), input_info) + _verify_model(Exp(), input_info) # sqrt class Sqrt(Module): def forward(self, data): return torch.sqrt(data) - verify_model(Sqrt(), input_info) + _verify_model(Sqrt(), input_info) # sigmoid class Sigmoid(Module): def forward(self, data): return torch.sigmoid(data) - verify_model(Sigmoid(), input_info) + _verify_model(Sigmoid(), input_info) # round class Round(Module): def forward(self, data): return torch.round(data) - verify_model(Round(), input_info) + _verify_model(Round(), input_info) def test_gelu(): @@ -693,7 +693,7 @@ def forward(self, data): return torch.nn.functional.gelu(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Gelu(), input_info) + _verify_model(Gelu(), input_info) def test_tanh(): @@ -704,7 +704,7 @@ def forward(self, data): return torch.tanh(data) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Tanh(), input_info) + _verify_model(Tanh(), input_info) def test_clamp(): @@ -715,7 +715,7 @@ def forward(self, data): return torch.clamp(data, min=0.1, max=0.5) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Clamp(), input_info) + _verify_model(Clamp(), input_info) def test_interpolate(): @@ -726,7 +726,7 @@ def forward(self, data): return torch.nn.functional.interpolate(data, (5, 5)) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Interpolate(), input_info) + _verify_model(Interpolate(), input_info) def test_addmm(): @@ -741,7 +741,7 @@ def forward(self, x_1, x_2, x_3): ([10, 10], "float32"), ([10, 10], "float32"), ] - verify_model(Addmm(), input_info) + _verify_model(Addmm(), input_info) def test_split(): @@ -752,7 +752,7 @@ def forward(self, data): return torch.split(data, 1, dim=1) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Split(), 
input_info) + _verify_model(Split(), input_info) def test_cumsum(): @@ -763,7 +763,7 @@ def forward(self, data): return torch.cumsum(data, dim=1, dtype=torch.int32) input_info = [([1, 2, 3, 4], "float32")] - verify_model(Cumsum(), input_info) + _verify_model(Cumsum(), input_info) def test_chunk(): @@ -774,7 +774,7 @@ def forward(self, data): return torch.chunk(data, 3, dim=1) input_info = [([1, 3, 10, 10], "float32")] - verify_model(Chunk(), input_info) + _verify_model(Chunk(), input_info) def test_inplace_fill(): @@ -785,7 +785,7 @@ def forward(self, data): data.fill_(1.5) return data - verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0}) + _verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0}) def test_arange(): @@ -795,7 +795,7 @@ class Arange(Module): def forward(self): return torch.arange(0, 20, dtype=torch.int32) - verify_model(Arange(), [([10, 10], "float32")]) + _verify_model(Arange(), [([10, 10], "float32")]) def test_empty(): @@ -805,7 +805,7 @@ class Empty(Module): def forward(self): return torch.empty((10, 10), dtype=torch.float32) - verify_model(Empty(), [([10, 10], "float32")]) + _verify_model(Empty(), [([10, 10], "float32")]) def test_tensor(): @@ -819,8 +819,8 @@ class Empty2(Module): def forward(self): return torch.tensor(3) - verify_model(Empty1(), [([10, 10], "float32")]) - verify_model(Empty2(), [([10, 10], "float32")]) + _verify_model(Empty1(), [([10, 10], "float32")]) + _verify_model(Empty2(), [([10, 10], "float32")]) def test_tril(): @@ -836,8 +836,8 @@ def forward(self, data): return data input_info = [([10, 10], "float32")] - verify_model(Tril(), input_info) - verify_model(InplaceTril(), input_info) + _verify_model(Tril(), input_info) + _verify_model(InplaceTril(), input_info) def test_triu(): @@ -853,8 +853,8 @@ def forward(self, data): return data input_info = [([10, 10], "float32")] - verify_model(Triu(), input_info) - verify_model(InplaceTriu(), input_info) + _verify_model(Triu(), input_info) + _verify_model(InplaceTriu(), input_info) def test_new_ones(): @@ -865,7 +865,7 @@ def forward(self, x): return x.new_ones(1, 2, 3) input_info = [([1, 2, 3], "float32")] - verify_model(NewOnes(), input_info, opt_config={"opt_level": 0}) + _verify_model(NewOnes(), input_info, opt_config={"opt_level": 0}) def test_expand(): @@ -876,7 +876,7 @@ def forward(self, x): return x.expand(4, 2, 3, 4) input_info = [([1, 2, 3, 4], "float32")] - verify_model(Expand(), input_info) + _verify_model(Expand(), input_info) def test_reduce(): @@ -888,7 +888,7 @@ def forward(self, x): return torch.sum(x, (2, 1)) input_info = [([1, 2, 3, 4], "float32")] - verify_model(Sum(), input_info) + _verify_model(Sum(), input_info) def test_datatype(): @@ -901,14 +901,14 @@ class ToFloat(Module): def forward(self, x): return x.float() - verify_model(ToFloat(), input_info) + _verify_model(ToFloat(), input_info) # half class ToHalf(Module): def forward(self, x): return x.half() - verify_model(ToHalf(), input_info) + _verify_model(ToHalf(), input_info) # type class Type(Module): @@ -925,9 +925,9 @@ class AsType(Module): def forward(self, x): return x.astype(torch.float32) - verify_model(Type(), input_info) - verify_model(TypeFromAttr(), input_info) - verify_model(AsType(), input_info) + _verify_model(Type(), input_info) + _verify_model(TypeFromAttr(), input_info) + _verify_model(AsType(), input_info) def test_permute(): @@ -938,7 +938,7 @@ def forward(self, x): return x.permute(0, 3, 2, 1) input_info = [([1, 2, 3, 4], "float32")] - 
verify_model(Permute(), input_info) + _verify_model(Permute(), input_info) def test_reshape(): @@ -949,7 +949,7 @@ def forward(self, x): return x.reshape(2, 12) input_info = [([1, 2, 3, 4], "float32")] - verify_model(Reshape(), input_info) + _verify_model(Reshape(), input_info) def test_transpose(): @@ -960,7 +960,7 @@ def forward(self, x): return x.transpose(1, 3) input_info = [([1, 2, 3, 4], "float32")] - verify_model(Transpose(), input_info) + _verify_model(Transpose(), input_info) def test_view(): @@ -971,7 +971,7 @@ def forward(self, x): return x.view(2, 12) input_info = [([1, 2, 3, 4], "float32")] - verify_model(View(), input_info) + _verify_model(View(), input_info) def test_keep_params(): @@ -985,7 +985,7 @@ def __init__(self): def forward(self, data): return self.conv(data) - verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")]) + _verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")]) def test_unwrap_unit_return_tuple(): @@ -995,7 +995,7 @@ class Identity(Module): def forward(self, x): return (x,) - verify_model(Identity(), [([256, 256], "float32")]) + _verify_model(Identity(), [([256, 256], "float32")]) def test_no_bind_return_tuple(): @@ -1006,7 +1006,7 @@ def forward(self, x, y): return (x, y) input_info = [([256, 256], "float32"), ([256, 256], "float32")] - verify_model(Identity(), input_info) + _verify_model(Identity(), input_info) def test_argmax(): @@ -1020,8 +1020,8 @@ class Argmax2(Module): def forward(self, data): return torch.argmax(data, dim=-1, keepdim=True) - verify_model(Argmax1(), [([256, 256], "float32")]) - verify_model(Argmax2(), [([256, 256], "float32")]) + _verify_model(Argmax1(), [([256, 256], "float32")]) + _verify_model(Argmax2(), [([256, 256], "float32")]) def test_argmin(): @@ -1035,8 +1035,8 @@ class Argmin2(Module): def forward(self, data): return torch.argmin(data, keepdim=True) - verify_model(Argmin1(), [([256, 256], "float32")]) - verify_model(Argmin2(), [([256, 256], "float32")]) + _verify_model(Argmin1(), [([256, 256], "float32")]) + _verify_model(Argmin2(), [([256, 256], "float32")]) def test_to(): @@ -1050,8 +1050,8 @@ class To2(Module): def forward(self, data): return data.to("cpu") - verify_model(To1(), [([256, 256], "float32")]) - verify_model(To2(), [([256, 256], "float32")]) + _verify_model(To1(), [([256, 256], "float32")]) + _verify_model(To2(), [([256, 256], "float32")]) def test_mean(): @@ -1065,8 +1065,8 @@ class MeanKeepDim(Module): def forward(self, data): return data.mean(-1, keepdim=True) - verify_model(Mean(), [([256, 256], "float32")]) - verify_model(MeanKeepDim(), [([256, 256], "float32")]) + _verify_model(Mean(), [([256, 256], "float32")]) + _verify_model(MeanKeepDim(), [([256, 256], "float32")]) def test_rsqrt(): @@ -1076,7 +1076,7 @@ class Rsqrt(Module): def forward(self, data): return torch.rsqrt(data) - verify_model(Rsqrt(), [([256, 256], "float32")]) + _verify_model(Rsqrt(), [([256, 256], "float32")]) def test_neg(): @@ -1086,7 +1086,7 @@ class Neg(Module): def forward(self, data): return -data - verify_model(Neg(), [([256, 256], "float32")]) + _verify_model(Neg(), [([256, 256], "float32")]) def test_max(): @@ -1096,7 +1096,7 @@ class Max(Module): def forward(self, x, y): return torch.max(x, y) - verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) + _verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) def test_attention(): @@ -1118,14 +1118,14 @@ def forward(self, q_data, k_data, v_data): ([32, 8, 128, 64], "float32"), ([32, 8, 128, 64], "float32"), ] - verify_model(Attention1(), 
input_info) - verify_model(Attention2(), input_info) + _verify_model(Attention1(), input_info) + _verify_model(Attention2(), input_info) class Attention3(Module): def forward(self, q_data, k_data, v_data, mask): return F.scaled_dot_product_attention(q_data, k_data, v_data, mask) - verify_model( + _verify_model( Attention3(), [ ([32, 8, 128, 64], "float32"),
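
For reference, here is a minimal standalone sketch (not part of the patches above; the module and variable names are illustrative only) of the kind of trivial `x = y` binding that `CanonicalizeBindings` folds away, and which the MSC test cases should no longer depend on:

import tvm
from tvm.script import ir as I, relax as R


@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
        y = R.add(x, x)
        z = y  # trivial binding: `z` is only an alias of `y`
        return z


# CanonicalizeBindings rewrites uses of `z` to `y` and drops the alias, so
# passes such as FuseOpsByPattern remain free to introduce or remove trivial
# bindings without changing observable behavior.
After = tvm.relax.transform.CanonicalizeBindings()(Before)

# If a dataflow block is wanted afterwards, it can be recreated with the same
# pass used in the MSC post-processing above.
After = tvm.relax.transform.ConvertToDataflow(min_size=1)(After)
print(After)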