diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 87e55593e943..363812fd562b 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -40,7 +40,7 @@ def _darknet_not_support(attr, op="relay"): def _get_params_prefix(opname, layer_num): """Makes the params prefix name from opname and layer number.""" - return str(opname) + str(layer_num) + return str(opname).replace(".", "_") + str(layer_num) def _get_params_name(prefix, item): diff --git a/tests/python/frontend/darknet/test_forward.py b/tests/python/frontend/darknet/test_forward.py index 77c72e770fef..1535c3a1b88f 100644 --- a/tests/python/frontend/darknet/test_forward.py +++ b/tests/python/frontend/darknet/test_forward.py @@ -45,6 +45,17 @@ ) +def astext(program, unify_free_vars=False): + """check that program is parsable in text format""" + text = program.astext() + if isinstance(program, relay.Expr): + roundtrip_program = tvm.parser.parse_expr(text) + else: + roundtrip_program = tvm.parser.fromtext(text) + + tvm.ir.assert_structural_equal(roundtrip_program, program, map_free_vars=True) + + def _read_memory_buffer(shape, data, dtype="float32"): length = 1 for x in shape: @@ -59,6 +70,10 @@ def _get_tvm_output(net, data, build_dtype="float32", states=None): """Compute TVM output""" dtype = "float32" mod, params = relay.frontend.from_darknet(net, data.shape, dtype) + # verify that from_darknet creates a valid, parsable relay program + mod = relay.transform.InferType()(mod) + astext(mod) + target = "llvm" shape_dict = {"data": data.shape} lib = relay.build(mod, target, params=params) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 4a3569aca2ec..72a243dbbb67 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -21,6 +21,7 @@ import numpy as np from tvm.relay import Expr from tvm.relay.analysis import free_vars +import pytest DEBUG_PRINT = False @@ -269,6 +270,4 @@ def test_span(): if __name__ == "__main__": - import sys - - pytext.argv(sys.argv) + pytest.main([__file__])