Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions python/tvm/driver/tvmc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,103 @@ def parse_shape_string(inputs_string):
shape_dict[name] = shape

return shape_dict


def get_pass_config_value(name, value, config_type):
"""Get a PassContext configuration value, based on its config data type.

Parameters
----------
name: str
config identifier name.
value: str
value assigned to the config, provided via command line.
config_type: str
data type defined to the config, as string.

Returns
-------
parsed_value: bool, int or str
a representation of the input value, converted to the type
specified by config_type.
"""

if config_type == "IntImm":
# "Bool" configurations in the PassContext are recognized as
# IntImm, so deal with this case here
mapping_values = {
"false": False,
"true": True,
}

if value.isdigit():
parsed_value = int(value)
else:
# if not an int, accept only values on the mapping table, case insensitive
parsed_value = mapping_values.get(value.lower(), None)

if parsed_value is None:
raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ")

if config_type == "runtime.String":
parsed_value = value
Comment on lines +439 to +457
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if config_type == "IntImm":
# "Bool" configurations in the PassContext are recognized as
# IntImm, so deal with this case here
mapping_values = {
"false": False,
"true": True,
}
if value.isdigit():
parsed_value = int(value)
else:
# if not an int, accept only values on the mapping table, case insensitive
parsed_value = mapping_values.get(value.lower(), None)
if parsed_value is None:
raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ")
if config_type == "runtime.String":
parsed_value = value
parsed_value = value
if config_type == "IntImm":
if value.isdigit():
parsed_value = int(value)
else:
# must be boolean values if not an int
try:
parsed_value = bool(distutils.util.strtobool(value))
except ValueError as err:
raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ")

If the dependency of distuilts is a concen, the following also works:

try:
    parsed_value = json.loads(value)
except json.decoder.JSONDecodeError as err:
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some investigation on this, and wrt to the distutils I didn't want to add that dependency here, because I think it would be a a bit misplaced.

Also wrt to the json approach, I think it would still require more validation because the allowed values for that option are "int numbers", "true" or "false", and opening that to "json.loads" would add all sorts of json passing, then requiring more validation. That's why I added my own mapping table.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm you already processed int so I don't think it would be an issue here. Anyways, I don't have a strong preference of using json so I'll commit.


return parsed_value


def parse_configs(input_configs):
"""Parse configuration values set via command line.

Parameters
----------
input_configs: list of str
list of configurations provided via command line.

Returns
-------
pass_context_configs: dict
a dict containing key-value configs to be used in the PassContext.
"""
if not input_configs:
return {}

all_configs = tvm.ir.transform.PassContext.list_configs()
supported_config_types = ("IntImm", "runtime.String")
supported_configs = [
name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types
]

pass_context_configs = {}

for config in input_configs:
if not config:
raise TVMCException(
f"Invalid format for configuration '{config}', use <config>=<value>"
)

# Each config is expected to be provided as "name=value"
try:
name, value = config.split("=")
name = name.strip()
value = value.strip()
except ValueError:
raise TVMCException(
f"Invalid format for configuration '{config}', use <config>=<value>"
)

if name not in all_configs:
raise TVMCException(
f"Configuration '{name}' is not defined in TVM. "
f"These are the existing configurations: {', '.join(all_configs)}"
)

if name not in supported_configs:
raise TVMCException(
f"Configuration '{name}' uses a data type not supported by TVMC. "
f"The following configurations are supported: {', '.join(supported_configs)}"
)

parsed_value = get_pass_config_value(name, value, all_configs[name]["type"])
pass_context_configs[name] = parsed_value

return pass_context_configs
15 changes: 14 additions & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def add_compile_parser(subparsers):
help="output format. Use 'so' for shared object or 'mlf' for Model Library Format "
"(only for µTVM targets). Defaults to 'so'.",
)
parser.add_argument(
"--pass-config",
action="append",
metavar=("name=value"),
help="configurations to be used at compile time. This option can be provided multiple "
"times, each one to set one configuration value, "
"e.g. '--pass-config relay.backend.use_auto_scheduler=0'.",
)
parser.add_argument(
"--target",
help="compilation targets as comma separated string, inline JSON or path to a JSON file.",
Expand Down Expand Up @@ -145,6 +153,7 @@ def drive_compile(args):
target_host=None,
desired_layout=args.desired_layout,
disabled_pass=args.disabled_pass,
pass_context_configs=args.pass_config,
)

return 0
Expand All @@ -162,6 +171,7 @@ def compile_model(
target_host: Optional[str] = None,
desired_layout: Optional[str] = None,
disabled_pass: Optional[str] = None,
pass_context_configs: Optional[str] = None,
):
"""Compile a model from a supported framework into a TVM module.

Expand Down Expand Up @@ -202,6 +212,9 @@ def compile_model(
disabled_pass: str, optional
Comma-separated list of passes which needs to be disabled
during compilation
pass_context_configs: str, optional
String containing a set of configurations to be passed to the
PassContext.


Returns
Expand All @@ -212,7 +225,7 @@ def compile_model(
"""
mod, params = tvmc_model.mod, tvmc_model.params

config = {}
config = common.parse_configs(pass_context_configs)

if desired_layout:
mod = common.convert_graph_layout(mod, desired_layout)
Expand Down
51 changes: 51 additions & 0 deletions tests/python/driver/tvmc/test_tvmc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

import tvm
from tvm.contrib.target.vitis_ai import vitis_ai_available
from tvm.driver import tvmc

from tvm.driver.tvmc.common import TVMCException
Expand Down Expand Up @@ -306,3 +307,53 @@ def test_parse_quotes_and_separators_on_options():

assert len(targets_double_quote) == 1
assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"]


def test_config_invalid_format():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"])


def test_config_missing_from_tvm():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"])


def test_config_unsupported_tvmc_config():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["tir.LoopPartition=value"])


def test_config_empty():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs([""])


def test_config_valid_config_bool():
configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"])

assert len(configs) == 1
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == True


@pytest.mark.skipif(
not vitis_ai_available(),
reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'",
)
def test_config_valid_multiple_configs():
configs = tvmc.common.parse_configs(
[
"relay.backend.use_auto_scheduler=false",
"tir.detect_global_barrier=10",
"relay.ext.vitis_ai.options.build_dir=mystring",
]
)

assert len(configs) == 3
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == False
assert "tir.detect_global_barrier" in configs.keys()
assert configs["tir.detect_global_barrier"] == 10
assert "relay.ext.vitis_ai.options.build_dir" in configs.keys()
assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"