Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 126 additions & 2 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tarfile

import pytest

import tvm
from tvm.ir.module import IRModule

from tvm.driver import tvmc
Expand Down Expand Up @@ -229,3 +228,128 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
model_format="pytorch",
shape_dict={"input": [1, 3, 224, 224]},
)


def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
before = tvmc_model.mod

expected_layout = "NCHW"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NHWC"
and node.attrs.dst_layout == "NCHW"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found"


def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50):
# some CI environments wont offer ONNX, so skip in case it is not present
pytest.importorskip("onnx")

tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NCHW"
and node.attrs.dst_layout == "NHWC"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found"


def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50):
# some CI environments wont offer Paddle, so skip in case it is not present
pytest.importorskip("paddle")

tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle")
before = tvmc_model.mod

expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NCHW"
and node.attrs.dst_layout == "NHWC"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert any(layout_transform_calls), "Expected 'layout_transform NCWH->NHWC' not found"


def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
before = tvmc_model.mod

expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NHWC"
and node.attrs.dst_layout == "NHWC"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"


def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50):
# some CI environments wont offer ONNX, so skip in case it is not present
pytest.importorskip("onnx")

tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NCHW"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NCHW"
and node.attrs.dst_layout == "NCHW"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"
73 changes: 73 additions & 0 deletions tests/python/driver/tvmc/test_pass_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

from tvm.contrib.target.vitis_ai import vitis_ai_available
from tvm.driver import tvmc

from tvm.driver.tvmc.common import TVMCException


def test_config_invalid_format():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"])


def test_config_missing_from_tvm():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"])


def test_config_unsupported_tvmc_config():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["tir.LoopPartition=value"])


def test_config_empty():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs([""])


def test_config_valid_config_bool():
configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"])

assert len(configs) == 1
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == True


@pytest.mark.skipif(
not vitis_ai_available(),
reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'",
)
def test_config_valid_multiple_configs():
configs = tvmc.common.parse_configs(
[
"relay.backend.use_auto_scheduler=false",
"tir.detect_global_barrier=10",
"relay.ext.vitis_ai.options.build_dir=mystring",
]
)

assert len(configs) == 3
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == False
assert "tir.detect_global_barrier" in configs.keys()
assert configs["tir.detect_global_barrier"] == 10
assert "relay.ext.vitis_ai.options.build_dir" in configs.keys()
assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import pytest
from tvm.driver import tvmc


def test_common_parse_pass_list_str():
def test_parse_pass_list_str():
assert [""] == tvmc.common.parse_pass_list_str("")
assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps")

Expand Down
96 changes: 96 additions & 0 deletions tests/python/driver/tvmc/test_shape_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse

import pytest

from tvm.driver import tvmc


def test_shape_parser():
# Check that a valid input is parsed correctly
shape_string = "input:[10,10,10]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10]}


def test_alternate_syntax():
shape_string = "input:0:[10,10,10] input2:[20,20,20,20]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]}


@pytest.mark.parametrize(
"shape_string",
[
"input:[10,10,10] input2:[20,20,20,20]",
"input: [10, 10, 10] input2: [20, 20, 20, 20]",
"input:[10,10,10],input2:[20,20,20,20]",
],
)
def test_alternate_syntaxes(shape_string):
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}


def test_negative_dimensions():
# Check that negative dimensions parse to Any correctly.
shape_string = "input:[-1,3,224,224]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
# Convert to strings to allow comparison with Any.
assert str(shape_dict) == "{'input': [?, 3, 224, 224]}"


def test_multiple_valid_gpu_inputs():
# Check that multiple valid gpu inputs are parsed correctly.
shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}"
assert str(shape_dict) == expected


def test_invalid_pattern():
shape_string = "input:[a,10]"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)


def test_invalid_separators():
shape_string = "input:5,10 input2:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)


def test_invalid_colon():
shape_string = "gpu_0/data_0:5,10 :test:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)


@pytest.mark.parametrize(
"shape_string",
[
"gpu_0/data_0:5,10 /:10,10",
"gpu_0/data_0:5,10 data/:10,10",
"gpu_0/data_0:5,10 /data:10,10",
"gpu_0/invalid/data_0:5,10 data_1:10,10",
],
)
def test_invalid_slashes(shape_string):
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
Loading