Skip to content

Commit

Permalink
[TVMC] Add configuration tir.add_lower_pass to option `--pass-confi…
Browse files Browse the repository at this point in the history
…g` (apache#9817)
  • Loading branch information
leeexyz authored and ylc committed Feb 16, 2022
1 parent ebafa06 commit 3e11289
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 6 deletions.
3 changes: 2 additions & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def add_compile_parser(subparsers, _):
metavar=("name=value"),
help="configurations to be used at compile time. This option can be provided multiple "
"times, each one to set one configuration value, "
"e.g. '--pass-config relay.backend.use_auto_scheduler=0'.",
"e.g. '--pass-config relay.backend.use_auto_scheduler=0', "
"e.g. '--pass-config tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'.",
)

generate_target_args(parser)
Expand Down
82 changes: 77 additions & 5 deletions python/tvm/driver/tvmc/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,41 @@
TVMC PassContext Interface
"""

import importlib

import tvm
from tvm.driver.tvmc import TVMCException


def load_function(full_name):
"""Dynamic loading a function by the full name.
Parameters
----------
full_name: str
The name of a PackedFunc or a string of the form "path.to.module.func"
that indicates the module that can be imported.
You must be aware of the load order here, it first tries to find it via
TVM global function, if not find, try to import it by "importlib.import_module".
Returns
-------
func: function or PackedFunc
The loaded fucntion.
"""
global_func = tvm.get_global_func(full_name, allow_missing=True)
if global_func is not None:
return global_func

# split full name "path.to.module.func" into two parts ["path.to.module", "func"]
module_name, func_name = full_name.rsplit(".", 1)

# import module and find the function
module = importlib.import_module(module_name)
if hasattr(module, func_name):
return getattr(module, func_name)

raise TVMCException(f"No function '{func_name}' found in module '{module_name}'.")


def get_pass_config_value(name, value, config_type):
"""Get a PassContext configuration value, based on its config data type.
Expand All @@ -41,6 +72,8 @@ def get_pass_config_value(name, value, config_type):
specified by config_type.
"""

parsed_value = None

if config_type == "IntImm":
# "Bool" configurations in the PassContext are recognized as
# IntImm, so deal with this case here
Expand All @@ -56,11 +89,44 @@ def get_pass_config_value(name, value, config_type):
parsed_value = mapping_values.get(value.lower(), None)

if parsed_value is None:
raise TVMCException(f"Invalid value '{value}' for configuration '{name}'. ")
raise TVMCException(f"Invalid value '{value}' for configuration '{name}'.")

if config_type == "runtime.String":
elif config_type == "runtime.String":
parsed_value = value

elif config_type == "Array":
if name == "tir.add_lower_pass":
pass_list = value.split(",")
if len(pass_list) % 2 != 0:
raise TVMCException(
f"The configuration of '{name}' must be of the form "
"'tir.add_lower_pass=opt_level1,pass1,opt_evel2,pass2'"
)

parsed_value = []
for i in range(0, len(pass_list), 2):
level, pass_func = pass_list[i].strip(), pass_list[i + 1].strip()
try:
level = int(level)
except ValueError:
raise TVMCException(f"Only integer is allow for configuration '{name}'.")

# TODO (@leeexyz) We should parse configurations of each tir Pass.
# For now, we only use the defaults. Currently, There are four config nodes:
# `tir.transform.LoopPartitionConfig`
# `tir.transform.UnrollLoopConfig`
# `tir.transform.HoistIfThenElseConfig`
# `tir.transform.InjectDoubleBufferConfig`
# loading pass func and calling it to get the Pass
pass_func = load_function(pass_func)()
parsed_value.append((level, pass_func))
else:
raise TVMCException(f"Unsupported configuration '{name}' for '{config_type}' type.")

else:
# not raise here cause we alreay checked before calling this function
pass

return parsed_value


Expand All @@ -81,7 +147,7 @@ def parse_configs(input_configs):
return {}

all_configs = tvm.ir.transform.PassContext.list_configs()
supported_config_types = ("IntImm", "runtime.String")
supported_config_types = ("IntImm", "runtime.String", "Array")
supported_configs = [
name for name in all_configs.keys() if all_configs[name]["type"] in supported_config_types
]
Expand Down Expand Up @@ -116,7 +182,13 @@ def parse_configs(input_configs):
f"The following configurations are supported: {', '.join(supported_configs)}"
)

parsed_value = get_pass_config_value(name, value, all_configs[name]["type"])
pass_context_configs[name] = parsed_value
config_type = all_configs[name]["type"]
parsed_value = get_pass_config_value(name, value, config_type)

if config_type == "Array" and name in pass_context_configs:
# merge configs if the configuration exists
pass_context_configs[name].extend(parsed_value)
else:
pass_context_configs[name] = parsed_value

return pass_context_configs
88 changes: 88 additions & 0 deletions tests/python/driver/tvmc/test_pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
# under the License.

import pytest
from unittest import mock

from tvm.contrib.target.vitis_ai import vitis_ai_available

from tvm.driver.tvmc import TVMCException
from tvm.driver.tvmc.pass_config import parse_configs
from tvm.tir.transform import PrimFuncPass


def test_config_invalid_format():
Expand Down Expand Up @@ -71,3 +73,89 @@ def test_config_valid_multiple_configs():
assert configs["tir.detect_global_barrier"] == 10
assert "relay.ext.vitis_ai.options.build_dir" in configs.keys()
assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"


def test_add_lower_pass_multi_built_in_pass():
configs = parse_configs(
[
"tir.add_lower_pass=1,tir.transform.UnrollLoop",
"tir.add_lower_pass=1,tir.transform.HoistIfThenElse,2,tir.transform.LoopPartition",
]
)

assert len(configs["tir.add_lower_pass"]) == 3
# opt_level: 1, pass: tir.transform.UnrollLoop
assert configs["tir.add_lower_pass"][0][0] == 1
assert isinstance(configs["tir.add_lower_pass"][0][1], PrimFuncPass)
# opt_level: 1, pass: tir.transform.HoistIfThenElse
assert configs["tir.add_lower_pass"][1][0] == 1
assert isinstance(configs["tir.add_lower_pass"][1][1], PrimFuncPass)
# opt_level: 2, pass: tir.transform.LoopPartition
assert configs["tir.add_lower_pass"][2][0] == 2
assert isinstance(configs["tir.add_lower_pass"][2][1], PrimFuncPass)


def test_add_lower_pass_multi_external_pass():
fake_pass_1 = mock.MagicMock()
fake_pass_2 = mock.MagicMock()
fake_pass_3 = mock.MagicMock()
with mock.patch.dict(
"sys.modules",
{"fake_module": fake_pass_1, "fake_module": fake_pass_2, "fake_module": fake_pass_3},
):
configs = parse_configs(
[
"tir.add_lower_pass=1,fake_module.fake_pass_1,2,fake_module.fake_pass2",
"tir.add_lower_pass=3,fake_module.fake_pass_3",
]
)
assert len(configs["tir.add_lower_pass"]) == 3
# opt_level: 1, pass: fake_module.fake_pass_1
assert configs["tir.add_lower_pass"][0][0] == 1
# opt_level: 2, pass: fake_module.fake_pass_2
assert configs["tir.add_lower_pass"][1][0] == 2
# opt_level: 3, pass: fake_module.fake_pass_3
assert configs["tir.add_lower_pass"][2][0] == 3


def test_add_lower_pass_multi_mix_pass():
fake_pass_1 = mock.MagicMock()
fake_pass_2 = mock.MagicMock()
with mock.patch.dict("sys.modules", {"fake_module": fake_pass_1, "fake_module": fake_pass_2}):
configs = parse_configs(
[
"tir.add_lower_pass=1,fake_module.fake_pass_1,1,tir.transform.UnrollLoop",
"tir.add_lower_pass=2,fake_module.fake_pass_2,2,tir.transform.LoopPartition",
]
)
assert len(configs["tir.add_lower_pass"]) == 4
# opt_level: 1, pass: fake_module.fake_pass_1
assert configs["tir.add_lower_pass"][0][0] == 1
# opt_level: 1, pass: tir.transform.UnrollLoop
assert configs["tir.add_lower_pass"][1][0] == 1
assert isinstance(configs["tir.add_lower_pass"][1][1], PrimFuncPass)
# opt_level: 2, pass: fake_module.fake_pass_2
assert configs["tir.add_lower_pass"][2][0] == 2
# opt_level: 2, pass: tir.transform.LoopPartition
assert configs["tir.add_lower_pass"][3][0] == 2
assert isinstance(configs["tir.add_lower_pass"][3][1], PrimFuncPass)


def test_add_lower_pass_invalid_format():
# wrong format
with pytest.raises(TVMCException):
_ = parse_configs(["tir.add_lower_pass=tir.transform.UnrollLoop,1"])
# missing pass name
with pytest.raises(TVMCException):
_ = parse_configs(["tir.add_lower_pass=1,tir.transform.UnrollLoop,3"])
# wrong opt level
with pytest.raises(TVMCException):
_ = parse_configs(["tir.add_lower_pass=a,tir.transform.UnrollLoop"])
# fake module
with pytest.raises(ModuleNotFoundError):
_ = parse_configs(
["tir.add_lower_pass=1,tir.transform.UnrollLoop,2,path.to.module.fake_func"]
)
# real module and fake func
with pytest.raises(TVMCException):
_ = parse_configs(["tir.add_lower_pass=1,tir.transform.UnrollLoop,2,tvm.tir.fake_func"])

0 comments on commit 3e11289

Please sign in to comment.